Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions source/source_esolver/esolver_ks_lcao_tddft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ ESolver_KS_LCAO_TDDFT<TR, Device>::~ESolver_KS_LCAO_TDDFT()
delete td_p;
}
TD_info::td_vel_op = nullptr;

if (td_mg_ != nullptr)
{
delete td_mg_;
td_mg_ = nullptr;
}
}

template <typename TR, typename Device>
Expand Down Expand Up @@ -94,6 +100,16 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::runner(UnitCell& ucell, const int istep)
// 1) before_scf (electronic iteration loops)
//----------------------------------------------------------------
this->before_scf(ucell, istep); // From ESolver_KS_LCAO

// Initialize the moving spatial gauge
if (use_td_moving_gauge && this->td_mg_ == nullptr)
{
this->td_mg_ = new module_rt::TD_MovingGauge();
auto* hamilt_lcao = dynamic_cast<hamilt::HamiltLCAO<std::complex<double>, TR>*>(this->p_hamilt);
const hamilt::HContainer<TR>* sR_template = hamilt_lcao->getSR();
this->td_mg_->init_DR(sR_template, &ucell, &this->pv, this->two_center_bundle_.overlap_orb.get());
}

if (PARAM.inp.td_stype == 2)
{
this->dmat.dm->cal_DMR_td(ucell, TD_info::cart_At);
Expand Down Expand Up @@ -242,6 +258,14 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::hamilt2rho_single(UnitCell& ucell,
const int iter,
const double ethr)
{
// Update the moving spatial gauge
if (use_td_moving_gauge)
{
auto* hamilt_lcao = dynamic_cast<hamilt::HamiltLCAO<std::complex<double>, TR>*>(this->p_hamilt);
const hamilt::HContainer<TR>* sR_template = hamilt_lcao->getSR();
this->td_mg_->update_DR(sR_template, &ucell, &this->pv, this->two_center_bundle_.overlap_orb.get());
}

if (PARAM.inp.init_wfc == "file")
{
if (istep >= TD_info::estep_shift + 1)
Expand All @@ -261,7 +285,11 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::hamilt2rho_single(UnitCell& ucell,
GlobalV::ofs_running,
PARAM.inp.propagator,
use_tensor,
use_lapack);
use_lapack,
this->td_mg_,
&ucell,
this->kv.kvec_d,
use_td_moving_gauge);
}
this->weight_dm_rho(ucell);
}
Expand All @@ -281,7 +309,11 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::hamilt2rho_single(UnitCell& ucell,
GlobalV::ofs_running,
PARAM.inp.propagator,
use_tensor,
use_lapack);
use_lapack,
this->td_mg_,
&ucell,
this->kv.kvec_d,
use_td_moving_gauge);
this->weight_dm_rho(ucell);
}
else
Expand Down
5 changes: 5 additions & 0 deletions source/source_esolver/esolver_ks_lcao_tddft.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "source_lcao/module_rt/gather_mat.h" // MPI gathering and distributing functions
#include "source_lcao/module_rt/kernels/cublasmp_context.h"
#include "source_lcao/module_rt/td_info.h"
#include "source_lcao/module_rt/td_moving_gauge.h"
#include "source_lcao/module_rt/velocity_op.h"

namespace ModuleESolver
Expand Down Expand Up @@ -66,6 +67,10 @@ class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO<std::complex<double>, TR>

TD_info* td_p = nullptr;

//! Moving spatial gauge for Ehrenfest dynamics, to calculate the correction term arising from the movement of basis
bool use_td_moving_gauge = false;
module_rt::TD_MovingGauge* td_mg_ = nullptr;

//! Restart flag
bool restart_done = false;

Expand Down
1 change: 1 addition & 0 deletions source/source_lcao/module_rt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ if(ENABLE_LCAO)
td_folding.cpp
solve_propagation.cpp
boundary_fix.cpp
td_moving_gauge.cpp
)

if(USE_CUDA)
Expand Down
19 changes: 16 additions & 3 deletions source/source_lcao/module_rt/evolve_elec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
namespace module_rt
{
template <typename Device>
Evolve_elec<Device>::Evolve_elec(){};
Evolve_elec<Device>::Evolve_elec() {};
template <typename Device>
Evolve_elec<Device>::~Evolve_elec(){};
Evolve_elec<Device>::~Evolve_elec() {};

template <typename Device>
ct::DeviceType Evolve_elec<Device>::ct_device_type = ct::DeviceTypeToEnum<Device>::value;
Expand All @@ -33,7 +33,11 @@ void Evolve_elec<Device>::solve_psi(const int& istep,
std::ofstream& ofs_running,
const int propagator,
const bool use_tensor,
const bool use_lapack)
const bool use_lapack,
module_rt::TD_MovingGauge* td_mg,
const UnitCell* ucell,
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
const bool use_td_moving_gauge)
{
ModuleBase::TITLE("Evolve_elec", "solve_psi");
ModuleBase::timer::start("Evolve_elec", "solve_psi");
Expand All @@ -57,6 +61,13 @@ void Evolve_elec<Device>::solve_psi(const int& istep,

if (!use_tensor)
{
// Construct the local P_k matrix for moving spatial gauge, CPU only for now
std::vector<std::complex<double>> P_k_local(para_orb.nloc, {0.0, 0.0});
if (use_td_moving_gauge && td_mg != nullptr)
{
td_mg->get_P_k(ucell, kvec_d[ik], P_k_local.data(), para_orb.nloc, para_orb.ncol);
}

const int len_HS_laststep = use_lapack ? nlocal * nlocal : para_orb.nloc;
evolve_psi(nband,
nlocal,
Expand All @@ -66,6 +77,8 @@ void Evolve_elec<Device>::solve_psi(const int& istep,
psi_laststep[0].get_pointer(),
Hk_laststep.data<std::complex<double>>() + ik * len_HS_laststep,
Sk_laststep.data<std::complex<double>>() + ik * len_HS_laststep,
P_k_local.data(),
use_td_moving_gauge,
&(ekb(ik, 0)),
propagator,
ofs_running,
Expand Down
7 changes: 6 additions & 1 deletion source/source_lcao/module_rt/evolve_elec.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "source_lcao/hamilt_lcao.h"
#include "source_lcao/module_rt/gather_mat.h" // MPI gathering and distributing functions
#include "source_lcao/module_rt/kernels/cublasmp_context.h"
#include "source_lcao/module_rt/td_moving_gauge.h"
#include "source_psi/psi.h"

//-----------------------------------------------------------
Expand Down Expand Up @@ -158,7 +159,11 @@ class Evolve_elec
std::ofstream& ofs_running,
const int propagator,
const bool use_tensor,
const bool use_lapack);
const bool use_lapack,
module_rt::TD_MovingGauge* td_mg,
const UnitCell* ucell,
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
const bool use_td_moving_gauge);

// ct_device_type = ct::DeviceType::CpuDevice or ct::DeviceType::GpuDevice
static ct::DeviceType ct_device_type;
Expand Down
13 changes: 11 additions & 2 deletions source/source_lcao/module_rt/evolve_psi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ void evolve_psi(const int nband,
std::complex<double>* psi_k_laststep,
std::complex<double>* H_laststep,
std::complex<double>* S_laststep,
std::complex<double>* P_k,
const bool use_td_moving_gauge,
double* ekb,
int propagator,
std::ofstream& ofs_running,
Expand Down Expand Up @@ -85,8 +87,15 @@ void evolve_psi(const int nband,
{
/// @brief solve the propagation equation
/// @input Stmp, Htmp, psi_k_laststep
/// @output psi_k
solve_propagation(pv, nband, nlocal, PARAM.inp.td_dt, Stmp, Htmp, psi_k_laststep, psi_k);
/// @output psi_k
if (use_td_moving_gauge)
{
solve_propagation(pv, nband, nlocal, PARAM.inp.td_dt, Stmp, Htmp, P_k, psi_k_laststep, psi_k);
}
else
{
solve_propagation(pv, nband, nlocal, PARAM.inp.td_dt, Stmp, Htmp, psi_k_laststep, psi_k);
}
}

// (4)->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
Expand Down
2 changes: 2 additions & 0 deletions source/source_lcao/module_rt/evolve_psi.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ void evolve_psi(const int nband,
std::complex<double>* psi_k_laststep,
std::complex<double>* H_laststep,
std::complex<double>* S_laststep,
std::complex<double>* P_k,
const bool use_td_moving_gauge,
double* ekb,
int propagator,
std::ofstream& ofs_running,
Expand Down
Loading
Loading