diff --git a/source/source_base/module_container/ATen/kernels/cuda/lapack.cu b/source/source_base/module_container/ATen/kernels/cuda/lapack.cu
index 4c69973b6be..bfdd32c8a28 100644
--- a/source/source_base/module_container/ATen/kernels/cuda/lapack.cu
+++ b/source/source_base/module_container/ATen/kernels/cuda/lapack.cu
@@ -325,14 +325,22 @@ struct lapack_hegvx<T, DEVICE_GPU> {
         const char uplo = 'U';
         int meig = 0;
-        // this hegvdx will protect the input A, B from being overwritten
-        // and write the eigenvectors into eigen_vec.
+        // cuSOLVER hegvdx overwrites A and B on exit, so copy A to
+        // eigen_vec and back up B to protect the original matrices.
+        CHECK_CUDA(cudaMemcpy(eigen_vec, A, sizeof(T) * n * lda, cudaMemcpyDeviceToDevice));
+
+        T* d_B_backup = nullptr;
+        CHECK_CUDA(cudaMalloc(&d_B_backup, sizeof(T) * n * lda));
+        CHECK_CUDA(cudaMemcpy(d_B_backup, B, sizeof(T) * n * lda, cudaMemcpyDeviceToDevice));
+
         cuSolverConnector::hegvdx(cusolver_handle, itype, jobz, range, uplo,
-            n, lda, A, B,
+            n, lda, eigen_vec, d_B_backup,
             Real(0), Real(0), 1, m, &meig, eigen_val, eigen_vec);
+
+        CHECK_CUDA(cudaFree(d_B_backup));
     }
 };
diff --git a/source/source_base/module_container/ATen/kernels/lapack.cpp b/source/source_base/module_container/ATen/kernels/lapack.cpp
index 2ab02f35c81..d294b899075 100644
--- a/source/source_base/module_container/ATen/kernels/lapack.cpp
+++ b/source/source_base/module_container/ATen/kernels/lapack.cpp
@@ -382,7 +382,7 @@ struct lapack_hegvx<T, DEVICE_CPU> {
     const int itype = 1; // ITYPE = 1: A*x = (lambda)*B*x
     const char jobz = 'V';// JOBZ = 'V': Compute eigenvalues and eigenvectors.
     const char range = 'I'; // RANGE = 'I': the IL-th through IU-th eigenvalues will be found.
-    const char uplo = 'L'; // UPLO = 'L': Lower triangles of A and B are stored.
+    const char uplo = 'U'; // UPLO = 'U': Upper triangles of A and B are stored.
     const int il = 1;
     const int iu = m;
@@ -394,6 +394,13 @@ struct lapack_hegvx<T, DEVICE_CPU> {
     T work_query;
     Real rwork_query;
 
+    // dummy arrays for workspace query (some LAPACK implementations
+    // require valid pointers even during query)
+    const int liwork_query = 5 * n;
+    const int lrwork_query = 7 * n;
+    std::vector<int> iwork_query(liwork_query);
+    std::vector<int> ifail_query(n);
+
     // set lwork = -1 to query optimal work size
     lapackConnector::hegvx(
         itype, jobz, range, uplo,
@@ -409,8 +416,8 @@ struct lapack_hegvx<T, DEVICE_CPU> {
         &work_query, // WORK (query)
         lwork,
         &rwork_query, // RWORK (query)
-        static_cast<int*>(nullptr), // IWORK (query)
-        static_cast<int*>(nullptr), // IFAIL (query)
+        iwork_query.data(), // IWORK (query)
+        ifail_query.data(), // IFAIL (query)
         info);
 
     // !> If LWORK = -1, then a workspace query is assumed; the routine
diff --git a/source/source_base/module_container/ATen/kernels/rocm/lapack.hip.cu b/source/source_base/module_container/ATen/kernels/rocm/lapack.hip.cu
index 07572a657ab..219d2366723 100644
--- a/source/source_base/module_container/ATen/kernels/rocm/lapack.hip.cu
+++ b/source/source_base/module_container/ATen/kernels/rocm/lapack.hip.cu
@@ -133,6 +133,34 @@ struct lapack_hegvd<T, DEVICE_GPU> {
     }
 };
 
+template <typename T>
+struct lapack_hegvx<T, DEVICE_GPU> {
+    using Real = typename GetTypeReal<T>::type;
+    void operator()(
+        const int n,
+        const int lda,
+        T* A,
+        T* B,
+        const int m,
+        Real* eigen_val,
+        T* eigen_vec)
+    {
+        // Fallback to CPU: copy matrices to host, call CPU lapack_hegvx, copy results back
+        std::vector<T> H_A(n * lda);
+        std::vector<T> H_B(n * lda);
+        std::vector<Real> H_eigen_val(n);
+        std::vector<T> H_eigen_vec(n * lda);
+
+        hipErrcheck(hipMemcpy(H_A.data(), A, sizeof(T) * n * lda, hipMemcpyDeviceToHost));
+        hipErrcheck(hipMemcpy(H_B.data(), B, sizeof(T) * n * lda, hipMemcpyDeviceToHost));
+
+        lapack_hegvx<T, DEVICE_CPU>()(n, lda, H_A.data(), H_B.data(), m, H_eigen_val.data(), H_eigen_vec.data());
+
+        hipErrcheck(hipMemcpy(eigen_val, H_eigen_val.data(), sizeof(Real) * n, hipMemcpyHostToDevice));
+        hipErrcheck(hipMemcpy(eigen_vec,
+            H_eigen_vec.data(), sizeof(T) * n * lda, hipMemcpyHostToDevice));
+    }
+};
+
 template struct set_matrix<float, DEVICE_GPU>;
 template struct set_matrix<double, DEVICE_GPU>;
 template struct set_matrix<std::complex<float>, DEVICE_GPU>;
@@ -158,5 +186,10 @@ template struct lapack_hegvd<float, DEVICE_GPU>;
 template struct lapack_hegvd<double, DEVICE_GPU>;
 template struct lapack_hegvd<std::complex<float>, DEVICE_GPU>;
 template struct lapack_hegvd<std::complex<double>, DEVICE_GPU>;
 
+template struct lapack_hegvx<float, DEVICE_GPU>;
+template struct lapack_hegvx<double, DEVICE_GPU>;
+template struct lapack_hegvx<std::complex<float>, DEVICE_GPU>;
+template struct lapack_hegvx<std::complex<double>, DEVICE_GPU>;
+
 } // namespace kernels
 } // namespace container
diff --git a/source/source_hsolver/diago_dav_subspace.cpp b/source/source_hsolver/diago_dav_subspace.cpp
index 048653c2a87..3d077059d54 100644
--- a/source/source_hsolver/diago_dav_subspace.cpp
+++ b/source/source_hsolver/diago_dav_subspace.cpp
@@ -673,52 +673,17 @@ void Diago_DavSubspace<T, Device>::diag_zhegvx(const int& nbase,
         }
 #endif
     }
-    else
+    else if (this->diag_subspace == 0)
     {
-        if (this->diag_subspace == 0)
+        if (this->diag_comm.rank == 0)
         {
-            if (this->diag_comm.rank == 0)
-            {
-                std::vector<std::vector<T>> h_diag(nbase, std::vector<T>(nbase, *this->zero));
-                std::vector<std::vector<T>> s_diag(nbase, std::vector<T>(nbase, *this->zero));
-
-                for (size_t i = 0; i < nbase; i++)
-                {
-                    for (size_t j = 0; j < nbase; j++)
-                    {
-                        h_diag[i][j] = hcc[i * this->nbase_x + j];
-                        s_diag[i][j] = scc[i * this->nbase_x + j];
-                    }
-                }
-                hegvx_op<T, Device>()(this->ctx,
-                                      nbase,
-                                      this->nbase_x,
-                                      this->hcc,
-                                      this->scc,
-                                      nband,
-                                      (*eigenvalue_iter).data(),
-                                      this->vcc);
-                // reset:
-                for (size_t i = 0; i < nbase; i++)
-                {
-                    for (size_t j = 0; j < nbase; j++)
-                    {
-                        hcc[i * this->nbase_x + j] = h_diag[i][j];
-                        scc[i * this->nbase_x + j] = s_diag[i][j];
-                    }
-
-                    for (size_t j = nbase; j < this->nbase_x; j++)
-                    {
-                        hcc[i * this->nbase_x + j] = *this->zero;
-                        hcc[j * this->nbase_x + i] = *this->zero;
-                        scc[i * this->nbase_x + j] = *this->zero;
-                        scc[j * this->nbase_x + i] = *this->zero;
-                    }
-                }
-            }
+            ct::kernels::lapack_hegvx<T, ct_Device>()(
+                nbase, this->nbase_x, this->hcc, this->scc, nband,
+                (*eigenvalue_iter).data(), this->vcc);
         }
-        else
-        {
+    }
+    else
+    {
 #ifdef __MPI
             std::vector<T> h_diag;
             std::vector<T> s_diag;
@@ -760,7 +725,6 @@ void Diago_DavSubspace<T, Device>::diag_zhegvx(const int& nbase,
             std::cout << "Error: parallel diagonalization is not supported in serial mode." << std::endl;
             exit(1);
 #endif
-        }
     }
 #ifdef __MPI