#*  **************************************************************************
#*
#*  CDTK, Chemical Dynamics Toolkit
#*  A modular system for chemical dynamics applications and more
#*
#*  Copyright (C) 2011, 2012, 2013, 2014, 2015, 2016
#*  Oriol Vendrell, DESY, <oriol.vendrell@desy.de>
#*
#*  Copyright (C) 2017, 2018, 2019
#*  Ralph Welsch, DESY, <ralph.welsch@desy.de>
#*
#*  Copyright (C) 2020, 2021, 2022, 2023
#*  Ludger Inhester, DESY, ludger.inhester@cfel.de
#*
#*  This file is part of CDTK.
#*
#*  CDTK is free software: you can redistribute it and/or modify
#*  it under the terms of the GNU General Public License as published by
#*  the Free Software Foundation, either version 3 of the License, or
#*  (at your option) any later version.
#*
#*  This program is distributed in the hope that it will be useful,
#*  but WITHOUT ANY WARRANTY; without even the implied warranty of
#*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#*  GNU General Public License for more details.
#*
#*  You should have received a copy of the GNU General Public License
#*  along with this program.  If not, see <http://www.gnu.org/licenses/>.
#*
#*  **************************************************************************
#! /usr/bin/env python
import sys
import string
import os
import random
import math
 
import numpy as np
import CDTK.Tools.Conversion as conv
from . import MDIntegrators as mdi
from scipy.integrate import ode
from CDTK.Tools.Mathematics import ovrMatrix2
HUGE = 999999999999.9
TINY = 1.e-12
SMALL = 1.e-3
[docs]
class EhrenfestDyn(object):
    """
    Control an Ehrenfest-dynamical evolution
    """
    def __init__(self,**opts):
        """
        Init an EhrenfestDyn object
        
        optional arguments:
        - initial_state: initially populated electronic state
                default = 0
        - nstates: number of electronic states
                default = 1
        - N_q : quantum time steps inside a classical step
                default = 10
        - external_E  : function returning array of adiabatic energies
                default = None
        - external_G  : function returning array of adiabatic gradients
                default = None
        - external_D  : function returning array of non-adiabatic spatial derivative couplings
                default = None
        - external_VD : function returning array of non-adiabatic couplings
                default = None
        - external_DPL: function returning array of electronic dipole vectors
                default = None
        - gradient_method : method for adiabatic gradients
                default = numerical
                optional : analytical
        - cplVD_method : method for the non-adiabatic couplings
                default = wavefunction_overlap
                optional : explicit_d_matrix
        - ehrenfest_method : method for the ehrenfest-dynamics
                default = classical
                optional :  ndm [natural decay-of-mixing]
                           scdm [self-consistent decay-of-mixing]
        - simbox: SimulationBox object
                default = None
        - dxgrad: step for fallback numerical differenciation if no gradients
                default = 0.001, [au]
        - state: the state towards which the system decoheres to
                default = initial_state
        - _stepnum_intp_ref: classical step number reference for interpolation procedure
                        initialized from step zero
        - is_switch_field: whether turning on the electric field
                default = False
        - cplFIELD_method: method for the field coupling
                default = None 
                options : [direct_coupling]
        - MuE_MatElt: the electric field dipole couplings
        """
        self.initial_state = opts.get('initial_state',0)
        self.nstates = opts.get('nstates',1)
        self.N_q = opts.get('n_q',10) 
        self.external_E = opts.get('external_E',None)
        self.external_G = opts.get('external_G',None)
        self.external_D = opts.get('external_D',None)
        self.external_VD = opts.get('external_VD',None)
        self.external_DPL = opts.get('external_DPL',None)
        self.gradient_method = opts.get('gradient_method','numerical')
        self.cplVD_method = opts.get('cplVD_method','wavefunction_overlap')
        self.ehrenfest_method = opts.get('ehrenfest_method','classical')
        self.simbox = opts.get('simbox',None)
        self.dxgrad = opts.get('dxgrad',0.001)
        self.state = self.initial_state
        self._logstep = {} # step info is stored here for logging, etc.
        self.is_switch_field = opts.get('switch_field',False)
        self.cplFIELD_method = opts.get('cplFIELD_method',None)
        self.MuE_MatElt = np.zeros((self.nstates,self.nstates),float)
        self.delta_e = np.zeros((self.nstates,self.nstates),float)
        self._stepnum_intp_ref = 0
        self.is_hopp = False
    
[docs]
    def get_funcE(self,**opts):
        """
        Return the energy function to be used from a SimulationBox object
        """
        if self.external_E is None:
            raise ValueError('external energy function not defined')
        def f(x):
            # the diagonal energies on all electronic state surfaces
            e_surface = self.e_adiabatic.copy() 
            # overall energy scalar from density matrix
            e = np.dot(self.rho.diagonal().real, e_surface)
            return e
        return f 
[docs]
    def get_funcG(self):
        """
        Return the gradient function to be used from a SimulationBox object
        The gradient contains two components
        """
        def f(x,**opts):
            e, g = self.gradient_ED(x)
            return g
        return f 
[docs]
    def get_funcVD(self,a_x,**opts):
        """
        Return the V.d coupling matrix for positions a_x.

        The derivative coupling vectors d = external_D(a_x) are passed to
        ovrMatrix2 together with the single velocity vector simbox.V. The
        call uses nstates*nstates coupling vectors of length simbox.V.size,
        and its result is assumed to be a (1, nstates*nstates) array of
        overlaps (TODO confirm against CDTK.Tools.Mathematics). That array
        is reshaped in place to (nstates, nstates) and returned. A copy of
        d is kept in self.d_couplingmat.
        """
        n_el = self.nstates
        d_vectors = self.external_D(a_x)
        velocity = self.simbox.V
        n_pairs = n_el * n_el # one derivative coupling vector per state pair
        overlaps = np.zeros((1, n_pairs), float)
        overlaps = ovrMatrix2(velocity, d_vectors, 1, n_pairs, velocity.size, overlaps)
        overlaps.shape = (n_el, n_el)
        self.d_couplingmat = d_vectors.copy()
        return overlaps
[docs]
    def get_funcDPL(self,a_x,**opts):
        """
        Return the dipole unction to be used from an EhrenfestDyn object
        """
        if self.external_DPL is None:
            raise ValueError('external dipole function not defined')
        dipole = self.external_DPL(a_x)
        return np.array(dipole) 
[docs]
    def gradient_ED(self,a_x,**opts):
        """
        Return the overall energy scalar and gradient vector via density matrix
        for the Ehrenfest dynamics.
        The [E]hrenfest and [D]ecoherence contribution are implemented in the density matrix
        computed in the _updateDensityMat function.
        Input arguments:
            a_x -- positions of the system
        Returns:
            e_ehrenfest -- potential energy scalar for Ehrenfest dynamics
            g_ehrenfest -- energy gradient vector  for Ehrenfest dynamics
            e_adiabatic -- potential energy for all electronic states
            g_adiabatic -- energy gradient  for all electronic states
        """
        # get list of adiabatic energies  for all electronic states
        self.e_adiabatic = self.external_E(a_x)
        self.e_adiabatic -= self.energy_reference
        # compute energy differences
        for i in range(self.nstates):
            for j in range(self.nstates):
                self.delta_e[i,j] = self.e_adiabatic[i] - self.e_adiabatic[j]
        # overall energy scalar from density matrix
        self.e_ehrenfest = np.dot(self.rho.diagonal().real, self.e_adiabatic)
        # get list of adiabatic gradients for all electronic states
        if self.gradient_method == 'analytical':
            # Ehrenfest dynamics method requires adiabatic gradients for all electronic states
            if self.external_G is not None:
                nstates_g = self.nstates
                state_g = range(nstates_g)
                self.g_adiabatic = self.external_G(a_x, nstates_g=nstates_g, state_g=state_g)
                
                # correct the possible error in Molcas. the error makes the gradients for all electronic states computed as the same
                if np.linalg.norm(self.g_adiabatic[0] - self.g_adiabatic[1]) < SMALL:
                    raise ValueError('Analytical gradient calculation incorrect')
            else:
                raise ValueError('Analytical gradient calculation requires interface function external_G')
        elif self.gradient_method == 'numerical':
            if self.external_G is not None:
                self.g_adiabatic = self.external_G(a_x)
            else:
                self.g_adiabatic = self._gradients_num(a_x)
        else:
            raise ValueError('Method for adiabatic gradient not defined') 
        # formulate the [E]hrenfest coherent contribution [gradient_E]
        ncoords = self.simbox.X.size
        self.gradient_E = np.zeros((ncoords),dtype=float)
        for i in range(self.nstates):
            self.gradient_E += self.rho[i,i].real * self.g_adiabatic[i]
        # formulate the [D]ecoherent contribution [gradient_D]
        if self.ehrenfest_method == 'classical':
            self.gradient_D = np.zeros((ncoords),dtype=float)
        else:
            # decay-of-mixing methods. compute the [D]ecoherence contributions to the gradient
            natoms = len(self.simbox.X) / self.simbox.ndim
            self.gradient_D = np.zeros((ncoords),dtype=float) # the decoherence gradient in the vector form
            _g_D = np.zeros((natoms,self.simbox.ndim),dtype=float) # the decoherence gradient in the matrix form
            dVdt_decoh = 0.0
            for i in range(self.nstates):
                if self.simbox._stepnum == 0:
                    dVdt_decoh += self.DrhoDt_decoh_TIMEpDT[i,i].real * self.e_adiabatic[i]
                else:
                    dVdt_decoh += self.DrhoDt_decoh_TIMEpDT[i,i].real * self.e_TIMEpDT[i]
            vel = self.simbox.V.copy()
            masses = self.simbox.M.copy()
            vel.shape = (natoms, self.simbox.ndim)
            masses.shape = (natoms, self.simbox.ndim)
            for i in range(natoms):
                _g_D[i] = dVdt_decoh / np.dot(vel[i],self.decoh_direction[i]) * self.decoh_direction[i]
                _g_D[i] /= masses[i,0] 
                _g_D[i] *= self.simbox.reduced_mass 
            
            self.gradient_D = _g_D.reshape(natoms*self.simbox.ndim)   
        if self.ehrenfest_method == 'classical':
            # classical method contains the [E]hrenfest contribution
            self.g_ehrenfest = self.gradient_E
        else:
            # decay-of-mixing methods contain the [E]hrenfest and [D]ecoherence contributions
            self.g_ehrenfest = self.gradient_E + self.gradient_D
        return self.e_ehrenfest, self.g_ehrenfest 
    def _updateDensityMat(self,a_x0,**opts):
        """
        Determine the electronic density matrix
        A function for the Ehrenfest dynamics method
        Input arguments:
            a_x0 -- positions of th system
        
        Returns
            rho -- electronic density matrix - global
        optional arguments
            integrator -- ode integrator for quantum electronic SE
                     default: zvode
            is_first -- whether to be the initialization step
                     default: false
            is_restart -- whether to be the initialization step for [restart] calculation
                     default: false
        Implements Eq. (11) in Tully ; JCP 93, 1061 (1990)
        """
        self.is_hopp = False # reset hopp switch
        integrator = opts.get('integrator', 'zvode')
        is_first = opts.get('init', False)
        is_restart =  opts.get('restart', False)
        if self.cplVD_method == 'wavefunction_overlap':
            self._updateDensityMat_wavefunction_overlap(a_x0,
                                                        init=is_first,
                                                        restart=is_restart)
        elif self.cplVD_method == 'explicit_d_matrix':
            self._updateDensityMat_explicit_d_matrix   (a_x0,
                                                        init=is_first,
                                                        restart=is_restart)
        else:
            raise ValueError('Method for non-adiabatic coupling matrix calculation not defined')
    def _updateDensityMat_wavefunction_overlap(self,a_x0,**opts):
        """
        Performs _updateDensityMat procedure for wavefunction_overlap method of
        non-adiabatic coupling matrix
        """
        integrator = opts.get('integrator', 'zvode')
        is_first = opts.get('init', False)
        is_restart =  opts.get('restart', False)
        if is_first:
            # electronic density matrix
            self.rho = np.zeros((self.nstates,self.nstates),dtype=complex)
            # non-adiabatic couplings
            self.Vd = np.zeros((self.nstates,self.nstates),dtype=float)
            # parameters to start SE integration from TIME=0
            self.rho[self.initial_state,self.initial_state] = 1.0
            if self.external_VD is None:
                raise ValueError('Wavefunction overlap calculation requires interface function external_VD')
            self.a_x0_TIME = a_x0.copy() # X at TIME=0
             
            self.e_adiabatic = self.external_E(a_x0) # E at TIME=0 computed for initializing density matrix evolution
            # global energy shift to avoid fast phase rotation in the integration
            if not is_restart:
                self.energy_reference = np.mean(self.e_adiabatic)
            elif is_restart:
                self.energy_reference = self.simbox.energy_reference
            self.e_adiabatic -= self.energy_reference 
            self.energy = self.e_adiabatic.copy() # E at TIME=0. adiabatic energy for all electronic states
            # energy difference between total energy and reactant energy. identical to initial total kinetic energy while initial electronic state K is occupied. 
            # global T0 parameter for NDM method
            if not is_restart:
                ekintot = self.simbox.getKineticEnergy()
                self.T0 = ekintot
            elif is_restart:
                self.T0 = self.simbox.T0
            # the quantum step
            self.DT_q = self.simbox.DT / float(self.N_q)
            # for restarting calculation
            if is_restart:
                self.state = self.simbox.state_save
                self.rho = self.simbox.rho_densitymat_save.copy()
            
            # initializing parameters for decay-of-mixing method
            self.decoh_direction = np.zeros((len(self.simbox.X)/self.simbox.ndim,self.simbox.ndim),dtype=float)
            self.decoh_time_tau = np.zeros((self.nstates,self.nstates),dtype=float)
            # perform decay-of-mixing procedure at TIME=0
            self.DrhoDt_decoh_TIMEpDT = np.zeros((self.nstates,self.nstates),dtype=complex)
            self.decay_of_mixing()
            # for the purpose to compute the [D]ecoherence force for the total gradient 
            # determine the time derivative of the [D]ecoherence density matrix [DrhoDt_decoh] at TIMEpDT
            rho_mat = self.rho.copy() # the current density matrix in the matrix form
            self.DrhoDt_decoh_TIMEpDT = self._get_DrhoDt_decoh(rho_mat) # DrhoDt_decoh at TIMEpDT
            # initialize the ode integrator for one quantum step of electronic SE
            self.TIME_quantum = self.simbox.TIME
            if integrator == 'zvode':
                f = self._get_func()
                self.q = ode(f)
                self.q.set_integrator('zvode', method='bdf')
                # reshape density matrix [rho] for ode integration
                rho_ode = self.rho.reshape(self.nstates*self.nstates)
                self.q.set_initial_value(rho_ode, self.TIME_quantum)
            else:
                raise ValueError('ODE integrator not defined')
            return
        # An integration step of classical nuclei motion for classical interval DT is done in SimulationBox
        # x(t) -> x(t+dt)
        # v(t) -> v(t+dt)
        # Do integration of quantum electronic density matrix over DT with quantum interval DT_q
        self.TIME_quantum       = self.simbox.TIME
        self.TIME_quantum_final = self.simbox.TIME + self.simbox.DT
        # quantities to be interpolated for integration
        # e  -- potentel energy
        self.e_TIME    = self.energy.copy() # E at TIME
        self.energy    = self.e_adiabatic.copy() # E at TIME+DT updated in MDIntegrators. record for t+dt->(t+dt)+dt
        self.e_TIMEpDT = self.energy.copy() # E at TIME+DT
        # Vd -- non-adiabatic coupling V.d
        # The first classical step serves preparing VD at DT-DT/2 of quantum integration dt->(dt)+dt
        # The quantum integration starts from the second classical step
        # Assume the non-adiabatic coupling to be constant for the first classical step from 0->(0)+dt 
        if self.simbox._stepnum == 1:
            self.a_x0_TIMEpDT = a_x0.copy() # X at 0+DT
            self.Vd = self.external_VD(self.a_x0_TIME, self.a_x0_TIMEpDT, self.simbox.DT, self.simbox.phase_track, is_init=True, nstates=self.nstates, unit=self.simbox.unit) # V.D at DT-DT/2 [wavefunction_overlap]
            self.a_x0_TIME = self.a_x0_TIMEpDT.copy() # X updated at 0+DT for classical step (0+dt)->(0+dt)+dt
        elif self.simbox._stepnum > 1:
            self.Vd_TIME = self.Vd.copy() # V.D at TIME-DT/2 [wavefunction_overlap]
            self.a_x0_TIMEpDT = a_x0.copy() # X at TIME+DT
            self.Vd = self.external_VD(self.a_x0_TIME, self.a_x0_TIMEpDT, self.simbox.DT, self.simbox.phase_track, nstates=self.nstates, unit=self.simbox.unit) # V.D at (TIME+DT)-DT/2
            self.a_x0_TIME = self.a_x0_TIMEpDT.copy() # X updated at TIME+DT for classical step (t+dt)->(t+dt)+dt
        else:
            raise ValueError('Error: incorrect step number')
        
        self.Vd_TIMEpDT = self.Vd.copy() # V.D at TIME+DT
        # forced correctin of wave function phase
        # the absolute phase is unknown
        if self.simbox._stepnum > 1:
            for i in range(self.nstates):
                for j in range(i+1,self.nstates):
                    r_ph = self.Vd_TIMEpDT[i,j] / self.Vd_TIME[i,j]
                    if (r_ph < 0.0) and (0.5 < abs(r_ph) < 2.0):
                        # assume to be a matrix element with flipped sign
                        self.Vd_TIMEpDT[i,j] *= -1.0
                        self.Vd_TIMEpDT[j,i] *= -1.0
                        self.Vd[i,j] *= -1.0
                        self.Vd[j,i] *= -1.0
        
        # integrate the quantum SE within one classical step DT
        prob_FS = np.zeros((self.nstates,self.nstates),float)
        while self.TIME_quantum < (self.TIME_quantum_final - SMALL):
            # compute electronic density matrix rho via integrating SE
            self._step_quantum_rho()
            self.TIME_quantum = self.TIME_quantum + self.DT_q
            # compute FS hopping probabilities from the current decoherence electronic state K. via B matrix from electronic density matrix and integrate over quantum steps within DT
            time = self.TIME_quantum
            prob_FS = prob_FS + self.get_prob_FS(time)
        prob_FS = prob_FS * self.DT_q
        prob_FS[self.state,:] = prob_FS[self.state,:] / (self.rho[self.state,self.state].real + TINY)
        # adjus unphysical value due to num err
        for j in range(self.nstates):
            if prob_FS[self.state,j] > 1.0:
                prob_FS[self.state,j] = 1.0
 
        randnum = np.random.random()
        
        # update logfile
        norm_el = self._chk_norm_el()
        self._logstep['prob_FS'] = prob_FS
        self._logstep['randnum'] = randnum
        self._logstep['norm_el'] = norm_el
        f = self.simbox.logfile
        f.write('## Quantum electronic density matrix ')
        f.write('%8.2f a.u. to %8.2f a.u. \n' % (
                                  (self.TIME_quantum-self.simbox.DT),
                                            self.TIME_quantum_final))
        f.write('## Hopping probabilities \n')
        for i, prob_FS_s in enumerate(self._logstep['prob_FS'][self.state]):
            f.write(' %i  %8.6e\n' % (
                i, prob_FS_s))
        f.write('## rand\n')
        f.write(' %8.6f\n' % (
            self._logstep['randnum']))
        f.write('## Norm of electronic wavefunction \n')
        f.write(str(self._logstep['norm_el'])+'\n')
        # determine the switching using FS formula
        ipmax = -1
        Sigma_prob_FS = 0.0
        for i in range(self.nstates):
            if i != self.state: # avoid selfjumps
                Sigma_prob_FS = Sigma_prob_FS + prob_FS[self.state,i]
                if Sigma_prob_FS > randnum:
                    self.is_hopp = True
                    ipmax = i
                    break
        
        if self.is_hopp:
            self.state = ipmax # direct switching
        # Decay-of-mixing methods for Ehrenfest dynamics
        # prepare decoherence parameters for quantum steps within next classical step TIMEpDT -> TIMEpDT + DT
        # perform decay-of-mixing procedure at TIMEpDT. the current classical time
        self.decay_of_mixing()
        # for the purpose to compute the [D]ecoherence force for the total gradient 
        # determine the time derivative of the [D]ecoherence density matrix [DrhoDt_decoh] at TIMEpDT
        rho_mat = self.rho.copy() # the current density matrix in the matrix form
        self.DrhoDt_decoh_TIMEpDT = self._get_DrhoDt_decoh(rho_mat) # DrhoDt_decoh at TIME
    def _updateDensityMat_explicit_d_matrix(self,a_x0,**opts):
        """
        Perform _updateDensityMat procedure for explicit_d_matrix method of
        non-adiabatic coupling matrix
        """
        integrator = opts.get('integrator', 'zvode')
        is_first = opts.get('init', False)
        is_restart =  opts.get('restart', False)
        # coupling matrix element from electric field
        if self.is_switch_field:
            dipole = self.get_funcDPL(a_x0)
            efield = self.simbox._fc.getElectricField(self.simbox.TIME)
            # inner product. dipole and electric field
            # The dipole elements [numpy array] are structured as 
            # | [0.0] [0.1] ... [0.n] |
            # | ...   [k.l] ...       |
            # | ...   ...   ...       |
            # each dipole element [k.l] is n-dimensional vector for the dipole of k- and l-th electronic state.
            # The field coupling matrix element [numpy array] are identically structured.
            # each field coupling matrix element [k.l] is scalar for the coupling of k- and l-th electronic state.
            self.MuE_MatElt = np.dot(dipole,efield)
               
        if is_first:
            # electronic density matrix
            self.rho = np.zeros((self.nstates,self.nstates),dtype=complex)
            # non-adiabatic couplings
            self.Vd = np.zeros((self.nstates,self.nstates),dtype=float)
            # parameters to start SE integration from TIME=0
            self.rho[self.initial_state,self.initial_state] = 1.0
            if self.external_D is None:
                raise ValueError('Explicit d matrix calculation requires interface function external_D')
            self.Vd = self.get_funcVD(a_x0) # V.D at TIME=0
            
            self.e_adiabatic = self.external_E(a_x0) # E at TIME=0 computed for initializing density matrix evolution
            # global energy shift to avoid fast phase rotation in the integration
            if not is_restart:
                self.energy_reference = np.mean(self.e_adiabatic)
            elif is_restart:
                self.energy_reference = self.simbox.energy_reference
            self.e_adiabatic -= self.energy_reference
            self.energy = self.e_adiabatic.copy() # E at TIME=0. adiabatic energy for all electronic states
            
            # energy difference between total energy and reactant energy. identical to initial total kinetic energy while initial electronic state K is occupied. 
            # global T0 parameter for NDM method
            if not is_restart:
                ekintot = self.simbox.getKineticEnergy()
                self.T0 = ekintot
            elif is_restart:
                self.T0 = self.simbox.T0
            # the quantum step
            self.DT_q = self.simbox.DT / float(self.N_q)
            # for restarting calculation
            if is_restart:
                self.state = self.simbox.state_save
                self.rho = self.simbox.rho_densitymat_save.copy()
            
            # interpolated electric field couplings
            if self.is_switch_field:
                if self.cplFIELD_method == 'direct_coupling':
                    self.VD = self.Vd - 1.0j * self.MuE_MatElt
                else:
                    raise ValueError('Method for electric field coupling not defined')
            # initializing parameters for decay-of-mixing method
            self.decoh_direction = np.zeros((len(self.simbox.X)/self.simbox.ndim,self.simbox.ndim),dtype=float)
            self.decoh_time_tau = np.zeros((self.nstates,self.nstates),dtype=float)
            # perform decay-of-mixing procedure at TIME=0
            self.DrhoDt_decoh_TIMEpDT = np.zeros((self.nstates,self.nstates),dtype=complex)
            self.decay_of_mixing()
            # for the purpose to compute the [D]ecoherence force for the total gradient 
            # determine the time derivative of the [D]ecoherence density matrix [DrhoDt_decoh] at TIMEpDT
            rho_mat = self.rho.copy() # the current density matrix in the matrix form
            self.DrhoDt_decoh_TIMEpDT = self._get_DrhoDt_decoh(rho_mat) # DrhoDt_decoh at TIMEpDT
            # initialize the ode integrator for one quantum step of electronic SE
            self.TIME_quantum = self.simbox.TIME
            if integrator == 'zvode':
                f = self._get_func()
                self.q = ode(f)
                self.q.set_integrator('zvode', method='bdf')
                # reshape density matrix [rho] for ode integration
                rho_ode = self.rho.reshape(self.nstates*self.nstates)
                self.q.set_initial_value(rho_ode, self.TIME_quantum)
            else:
                raise ValueError('ODE integrator not defined')
            return
        # An integration step of classical nuclei motion for classical interval DT is done in SimulationBox
        # x(t) -> x(t+dt)
        # v(t) -> v(t+dt)
        # Do integration of quantum electronic amplitude over DT with quantum interval DT_q
        self.TIME_quantum       = self.simbox.TIME
        self.TIME_quantum_final = self.simbox.TIME + self.simbox.DT
        # quantities to be interpolated for integration
        # E  -- potential energy for all adiabatic electronic states
        self.e_TIME       = self.energy.copy() # E at TIME
        self.energy       = self.e_adiabatic.copy() # E at TIME+DT updated in external_G. record for t+dt->(t+dt)+dt
        self.e_TIMEpDT    = self.energy.copy() # E at TIME+DT
        # Vd -- non-adiabatic coupling V.d
        self.Vd_TIME = self.Vd.copy() # V.D at TIME [explicit_d_matrix]
        self.Vd = self.get_funcVD(a_x0) # V.D at TIME+Dt. explicitly mult velocity V
        # interpolated electric field coupling
        if self.is_switch_field:
            if self.cplFIELD_method == 'direct_coupling':
                self.Vd = self.Vd - 1.0j * self.MuE_MatElt
        self.Vd_TIMEpDT = self.Vd.copy() # V.D at TIME+DT
        # integrate within one classical step DT
        prob_FS = np.zeros((self.nstates,self.nstates),float)
        while self.TIME_quantum < (self.TIME_quantum_final - SMALL):
            # compute electronic density matrix rho via integrating SE
            self._step_quantum_rho()
            self.TIME_quantum = self.TIME_quantum + self.DT_q
            # compute FS hopping probabilities from the current decoherence electronic state K. via B matrix from electronic density matrix and integrate over quantum steps within DT
            time = self.TIME_quantum
            prob_FS = prob_FS + self.get_prob_FS(time)
        prob_FS = prob_FS * self.DT_q
        prob_FS[self.state,:] = prob_FS[self.state,:] / (self.rho[self.state,self.state].real + TINY)
        # adjus unphysical value due to num err
        for j in range(self.nstates):
            if prob_FS[self.state,j] > 1.0:
                prob_FS[self.state,j] = 1.0
        
        randnum = np.random.random() 
         
        # determine the switching using FS formula
        ipmax = -1
        Sigma_prob_FS = 0.0
        for i in range(self.nstates):
            if i != self.state: # avoid selfjumps
                Sigma_prob_FS = Sigma_prob_FS + prob_FS[self.state,i]
                if Sigma_prob_FS > randnum:
                    self.is_hopp = True
                    ipmax = i
                    break
        
        if self.is_hopp:
            self.state = ipmax # direct switching
        # Decay-of-mixing methods for Ehrenfest dynamics
        # perform decay-of-mixing procedure at TIMEpDT. the current classical time
        self.decay_of_mixing()
        # for the purpose to compute the [D]ecoherence force for the total gradient 
        # determine the time derivative of the [D]ecoherence density matrix [DrhoDt_decoh] at TIMEpDT
        rho_mat = self.rho.copy() # the current density matrix in the matrix form
        self.DrhoDt_decoh_TIMEpDT = self._get_DrhoDt_decoh(rho_mat) # DrhoDt_decoh at TIME
        # update logfile
        norm_el = self._chk_norm_el()
        self._logstep['prob_FS'] = prob_FS
        self._logstep['randnum'] = randnum
        self._logstep['norm_el'] = norm_el
        f = self.simbox.logfile
        f.write('## Quantum electronic density matrix ')
        f.write('%8.2f a.u. to %8.2f a.u. \n' % (
                                  (self.TIME_quantum-self.simbox.DT),
                                            self.TIME_quantum_final))
        f.write('## Hopping probabilities \n')
        for i, prob_FS_s in enumerate(self._logstep['prob_FS'][self.state]):
            f.write(' %i  %8.6e\n' % (
                i, prob_FS_s))
        f.write('## rand\n')
        f.write(' %8.6f\n' % (
            self._logstep['randnum']))
        f.write('## Norm of electronic wavefunction \n')
        f.write(str(self._logstep['norm_el'])+'\n')
[docs]
    def decay_of_mixing(self,**opts):
        """
        Perform the decay-of-mixing procedure for decoherence
            compute the decoherence direction                      [s]
            compute the decoherence time                           [tau]

        Nothing is returned; results are stored in place:
            self.decoh_direction - per-atom normalised decoherence direction
                                   (only updated for 'ndm' and 'scdm')
            self.decoh_time_tau  - symmetric decoherence times between the
                                   current state K = self.state and every
                                   other state i (entries [i,K] and [K,i])
        Behaviour depends on self.ehrenfest_method:
            'classical' - no decoherence; tau is set to HUGE
            'ndm'       - tau = |2 * T0 * reduced_mass / (E_i - E_K) / E_vib|
            'scdm'      - tau = (|1 / (E_i - E_K)| + 0.25 / E_vib) * C_coh * 2
                          with C_coh = 60.0
        where E = self.e_adiabatic and E_vib is the (regularised) kinetic
        energy along the decoherence direction.
        Raises
            ValueError - if self.ehrenfest_method is not recognised
        """
        ndim = self.simbox.ndim
        # integer division: natoms is used as an array dimension and in range();
        # true division would give a float and make the reshape below fail
        natoms = len(self.simbox.X) // ndim
        masses       = self.simbox.M.copy()
        vel          = self.simbox.V.copy()
        masses.shape = (natoms,ndim)
        vel.shape    = (natoms,ndim)
        # vibrational momentum defines the decoherence direction vector s
        pvib = self.simbox.getVibrationalMomentum()
        pvib += TINY # regularization (avoids a zero norm below)
        K = self.state
        method = self.ehrenfest_method
        if method == 'classical':
            # classical Ehrenfest dynamics method: no decoherence
            for i in range(self.nstates):
                if i != K:
                    self.decoh_time_tau[i,K] = HUGE
                    self.decoh_time_tau[K,i] = self.decoh_time_tau[i,K]
        elif method in ('ndm', 'scdm'):
            evib = self._decoh_direction_and_evib(pvib, vel, masses, natoms)
            C_coh = 60.0 # the coherence parameter (used by 'scdm' only)
            # decoherence time towards the current decoherence electronic state K
            for i in range(self.nstates):
                if i == K:
                    continue
                dE = self.e_adiabatic[i] - self.e_adiabatic[K]
                if method == 'ndm':
                    tau = abs(2.0 * self.T0 * self.simbox.reduced_mass / dE / evib)
                else:
                    tau = (abs(1.0 / dE) + 0.25 / evib) * C_coh * 2.0
                self.decoh_time_tau[i,K] = tau
                self.decoh_time_tau[K,i] = self.decoh_time_tau[i,K]
        else:
            raise ValueError('Method for decay-of-mixing procedure not defined')

    def _decoh_direction_and_evib(self, pvib, vel, masses, natoms):
        """
        Store the normalised per-atom decoherence direction in
        self.decoh_direction and return the regularised effective vibrational
        energy  sum_a 0.5 * m_a * (v_a . s_a)^2 + TINY
        (the TINY term avoids a singular decoherence time)
        """
        for i in range(natoms):
            self.decoh_direction[i] = pvib[i] / np.linalg.norm(pvib[i])
        evib = 0.0
        for i in range(natoms):
            _v_s  = np.dot(vel[i],self.decoh_direction[i])
            evib += 0.5 * masses[i,0] * _v_s * _v_s
        return evib + TINY # regularization
                         
    def _step_quantum_rho(self,**opts):
        """
        Updates the electronic density matrix [rho] from integrating SE
        
        The RHS functions of the ODE are computed from _get_func
        Input arguments:
            self.e_TIME     -- potential energies of each state at TIME
            self.e_TIMEpDT  -- potential energies of each state at TIME+DT
            self.Vd_TIME    -- non-adiabatic couplings of each state at TIME
            self.Vd_TIMEpDT -- non-adiabatic couplings of each state at TIME+DT
            self.rho -- electronic density matrix within the integration procedure at TIME
        
        Returns:
            self.rho -- updated electronic density matrix in the matrix form at each quantum steps within TIME to TIME+DT
        """
        while self.q.successful() and self.q.t < (self.TIME_quantum + self.DT_q):
            self.q.integrate(self.q.t + self.DT_q)
            rho_ode = self.q.y
            self.rho = rho_ode.reshape(self.nstates,self.nstates)
            if not self.q.successful():
                raise ValueError('The integration over quantum step fails')
                # dump current status
   
    def _funcf(self,rho_ode,time):
        """
        This function is a thin wrapper over RHS of SE
        Inputs
            rho_ode - the current density matrix in the vector form
            time    - the current time
        """
        # [rho_ode] is the density matrix in the vector form adapted to the ode integrator
        # [rho_mat] is the density matrix in the matrix form
        rho_mat = rho_ode.copy()
        rho_mat.shape = (self.nstates,self.nstates)
        # interpolate E and V.D
        e_intm, Vd_intm = self.linear_interpolate(time,retEVd='returnAll')
        
        Hmat  = -1.0j * np.diag(e_intm) - Vd_intm
        # determine the [E]hrenfest contribution of the density matrix for the RHS function of SE in the matrix form
        RHS_E = np.dot(Hmat,rho_mat) - np.dot(rho_mat,Hmat)      
        
        if self.ehrenfest_method == 'classical':
            RHS = RHS_E
        else:
            # determine the [D]ecoherence contribution of the density matrix for the RHS function of SE in the matrix form
            RHS_D = self._get_DrhoDt_decoh(rho_mat)            
            RHS = RHS_E + RHS_D
        # RHS function in the vector form adapted to the ode integrator
        RHS.shape = (self.nstates*self.nstates)
        return RHS
    def _get_func(self):
        """
        Return the RHS function of the SE ODE
        """
        def f(time,rho_ode):
            return self._funcf(rho_ode,time)
        return f
    def _get_DrhoDt_decoh(self,rho_mat):
        """
        Returns  the [D]ecoherence contribution of the density matrix for the RHS function of SE
        Provides the time derivate of the decoherence density matrix
            for integration of SE
            for computation of decoherence force
        Input
            rho_mat        - the current full density matrix [rho] in the matrix form
            decoh_time_tau - the decoherence time matrix
        Returns
            DrhoDt_D - the time derivative of the decoherence density matrix in the matrix form
        """
        K = self.state # the current state K. towards which the electronic wavefunction decoheres
        DrhoDt_D = np.zeros((self.nstates,self.nstates),dtype=complex)
        if self.ehrenfest_method == 'classical':
            # classical Ehrenfest method. decoherence contribution is set to zero
            pass
        else:
            # decay-of-mixing method
            for i in range(self.nstates):
                for j in range(self.nstates):
                    # diagonal     matrix elements of Drho[D]Dt
                    if i == j:
                        if i != K:
                            DrhoDt_D[i,i] = - rho_mat[i,i] / self.decoh_time_tau[i,K]
                        else: # i == K
                            DrDtMatElt_s = 0.0
                            for l in range(self.nstates):
                                if l != K:
                                    DrDtMatElt_s += rho_mat[l,l] / self.decoh_time_tau[K,l]
                            DrhoDt_D[i,i] = DrDtMatElt_s
                    # non-diagonal matrix elements of Drho[D]Dt
                    else:
                        if i != K  and j != K:
                            DrhoDt_D[i,j] = - 0.5 * (1.0 / self.decoh_time_tau[i,K] + 1.0 / self.decoh_time_tau[j,K]) * rho_mat[i,j]
                        elif i == K and j != K:
                            DrDtMatElt_s = 0.0
                            for l in range(self.nstates):
                                if l != K:
                                    DrDtMatElt_s += rho_mat[l,l] / self.decoh_time_tau[K,l]
                            DrhoDt_D[i,j] = 0.5 * (1.0 * DrDtMatElt_s / rho_mat[K,K] - 1.0 / self.decoh_time_tau[j,K]) * rho_mat[i,j]
                        elif i != K and j == K:
                            DrDtMatElt_s = 0.0
                            for l in range(self.nstates):
                                if l != K:
                                    DrDtMatElt_s += rho_mat[l,l] / self.decoh_time_tau[K,l]
                            DrhoDt_D[i,j] = 0.5 * (1.0 * DrDtMatElt_s / rho_mat[K,K] - 1.0 / self.decoh_time_tau[i,K]) * rho_mat[i,j]        
        return DrhoDt_D
[docs]
    def get_prob_FS(self,time):
        """
        Returns the FS hopping probability matrix g

        Only the row of the current state K = self.state is filled. For every
        j != K:
            g[K,j] = Re(-2 * rho[K,j] * V.D[K,j]) + Re(sum_{l != K} rho[l,l] / tau[K,l])
        with V.D interpolated to [time] and tau = self.decoh_time_tau.
        Negative values are clipped to zero. All other entries, including
        g[K,K] (no self-jumps), are zero.
        """
        K = self.state
        Vd_intp = self.linear_interpolate(time,retEVd='returnVd')
        rho = self.rho
        Gmat = np.zeros((self.nstates,self.nstates),float)
        targets = [l for l in range(self.nstates) if l != K]
        # diagonal element of Drho[D]Dt for the current decoherence state K
        gain_K = 0.0
        for l in targets:
            gain_K += rho[l,l] / self.decoh_time_tau[K,l]
        for j in targets:
            g = (- 2.0 * rho[K,j] * Vd_intp[K,j]).real + gain_K.real
            # clip unphysical negative values
            Gmat[K,j] = 0.0 if g < 0.0 else g
        return Gmat
[docs]
    def linear_interpolate(self,time,**opts):
        """
        Returns E and V.D linearly interpolated to [time]

        The energies are interpolated between self.e_TIME (at self.simbox.TIME)
        and self.e_TIMEpDT (at self.simbox.TIME + self.simbox.DT).
        How V.D is obtained depends on self.cplVD_method:
            'explicit_d_matrix'    - linear between self.Vd_TIME and
                                     self.Vd_TIMEpDT over the same interval
            'wavefunction_overlap' - delegated to self.linear_interpolate_Vd,
                                     with is_init=True on classical step 1
        optional arguments:
        - retEVd: variables to be interpolated
                'returnAll' - interpolated E and V.D (default)
                'returnVd'  - interpolated V.D (alternative)
        Returns
            (e_intm, Vd_intm) for 'returnAll', Vd_intm for 'returnVd'
        Raises
            ValueError - if retEVd or self.cplVD_method is not recognised
        """
        retEVd = opts.get('retEVd','returnAll')
        if retEVd not in ('returnAll', 'returnVd'):
            raise ValueError('Unknown retEVd option for interpolation: %r' % (retEVd,))
        delta_t = time - self.simbox.TIME
        if self.cplVD_method == 'explicit_d_matrix':
            # couplings known at TIME and TIME+DT: plain linear interpolation
            Vd_intm = self.Vd_TIME + delta_t * ( self.Vd_TIMEpDT - self.Vd_TIME ) / (self.simbox.DT)
        elif self.cplVD_method == 'wavefunction_overlap':
            # the first classical step uses the initialization branch of
            # linear_interpolate_Vd; later steps use the continuous one
            is_init = (self.simbox._stepnum == 1)
            Vd_intm = self.linear_interpolate_Vd(time,is_init=is_init)
        else:
            raise ValueError('Method for non-adiabatic coupling interpolation not defined: %r'
                             % (self.cplVD_method,))
        if retEVd == 'returnVd':
            return Vd_intm
        e_intm = self.e_TIME + delta_t * ( self.e_TIMEpDT - self.e_TIME ) / (self.simbox.DT)
        return e_intm, Vd_intm
    
[docs]
    def linear_interpolate_Vd(self,time,**opts):
        """
        Non-adiabatic coupling V.D at [time] for the wavefunction overlap method

        The interpolation references are refreshed once per classical step,
        i.e. whenever self.simbox._stepnum differs from self._stepnum_intp_ref.

        Inputs:
            time: current time of quantum integration
        optional arguments:
            is_init: initialization classical step (default False)
                True  - the coupling is constant: a copy of self.Vd is returned;
                        on a new step a copy of self.Vd is also stored in
                        self.Vd_intp_TIMEpDT
                False - on a new step self.Vd_intp_TIMEpDT is shifted into
                        self.Vd_intp_TIME and self.Vd_TIMEpDT is copied into
                        self.Vd_intp_TIMEpDT; the result is linear from
                        self.Vd_intp_TIME (time == simbox.TIME) to
                        self.Vd_TIMEpDT (time == simbox.TIME + DT/2) and a copy
                        of self.Vd_TIMEpDT beyond that
        Returns:
            Vd_intm: the interpolated non-adiabatic coupling (a new array)
        """
        sb = self.simbox
        new_step = sb._stepnum != self._stepnum_intp_ref
        if opts.get('is_init',False):
            # first classical step: keep the coupling constant
            if new_step:
                self.Vd_intp_TIMEpDT = self.Vd.copy()
                self._stepnum_intp_ref = sb._stepnum
            return self.Vd.copy()
        if new_step:
            # classical step has advanced: shift the interpolation references
            self.Vd_intp_TIME = self.Vd_intp_TIMEpDT.copy()
            self.Vd_intp_TIMEpDT = self.Vd_TIMEpDT.copy()
            self._stepnum_intp_ref = sb._stepnum
        half_dt = sb.DT / 2.0
        elapsed = time - sb.TIME
        if elapsed <= half_dt:
            start = self.Vd_intp_TIME
            return start + elapsed * ( self.Vd_TIMEpDT - start ) / half_dt
        return self.Vd_TIMEpDT.copy()
 
 
    def _gradients_num(self,x,**opts):
        """
        Fallback to numerical gradients
        """
        dx = self.dxgrad # step for numerical differentiation
        grads = np.zeros((self.nstates,x.size),float)
        for i in range(x.size):
            x1 = x.copy()
            x1[i] = x1[i] + 0.5*dx
            e1 = self.external_E(x1)
            x0 = x.copy()
            x0[i] = x0[i] - 0.5*dx
            e0 = self.external_E(x0)
            grads[:,i] = (e1-e0)/dx
        return grads
    
    def _chk_norm_el(self):
        """
        Returns norm of electronic wavefunction
        """
        rho_mat = self.rho.copy()
        norm_el = 0.0
        for pop in rho_mat.diagonal():
            norm_el = norm_el + pop.real
        return norm_el
    def _update_logfile(self):
        f_adp = self.simbox.logfile_adp
        norm_el = self._chk_norm_el()
        if not self.simbox.isrestart:
            adpstr = "%6i  %16.7f" % (
                (self.simbox._stepnum),self.simbox.TIME)
        elif self.simbox.isrestart:
            adpstr = "%6i  %16.7f" % (
                (self.simbox._stepnum_restart+self.simbox._stepnum),self.simbox.TIME)
        for adp_s in self.rho.diagonal():
            adpstr = adpstr + "  %16.7f" % (adp_s.real)
        adpstr = adpstr + "  %16.7f\n" % (norm_el,)
        f_adp.write(adpstr)