#*  **************************************************************************
#*
#*  CDTK, Chemical Dynamics Toolkit
#*  A modular system for chemical dynamics applications and more
#*
#*  Copyright (C) 2011, 2012, 2013, 2014, 2015, 2016
#*  Oriol Vendrell, DESY, <oriol.vendrell@desy.de>
#*
#*  Copyright (C) 2017, 2018, 2019
#*  Ralph Welsch, DESY, <ralph.welsch@desy.de>
#*
#*  Copyright (C) 2020, 2021, 2022, 2023
#*  Ludger Inhester, DESY, ludger.inhester@cfel.de
#*
#*  This file is part of CDTK.
#*
#*  CDTK is free software: you can redistribute it and/or modify
#*  it under the terms of the GNU General Public License as published by
#*  the Free Software Foundation, either version 3 of the License, or
#*  (at your option) any later version.
#*
#*  This program is distributed in the hope that it will be useful,
#*  but WITHOUT ANY WARRANTY; without even the implied warranty of
#*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#*  GNU General Public License for more details.
#*
#*  You should have received a copy of the GNU General Public License
#*  along with this program.  If not, see <http://www.gnu.org/licenses/>.
#*
#*  **************************************************************************
import os
import sys
import datetime
import shelve
import uuid
from time import time
import numpy as np
import CDTK.Tools.Conversion as conv
from . import MDIntegrators as mdi
from . import Tools as ntool
# Number of hex characters of a uuid4 kept as a SimulationBox ID
IDLEN = 12
# Separator line (newline-terminated) for log-file headers
SEP = '#--------------------------------------------------#\n'
[docs]
class SimulationBox(object):
    """
    Bunch of atoms given by their cartesian coordinates.
    Provides functionality to follow and analyze their classical dynamics
    """
    def __init__(self,**opts):
        """
        Create a simulation box.

        Keyword options (all optional):
            X           -- coordinates of the atoms (flat array)
            V           -- velocities of the atoms (flat array)
            DT          -- time step (default 10.0)
            M           -- mass of each Cartesian coordinate (flat array)
            atomSymbols -- list of atom symbols
            func_E      -- energy function E(X)
            func_EG     -- function returning (energy, gradient) at X
            logfile     -- produce a logfile (default True)
            constraints -- constraints passed on to the integrator
            dxgrad      -- step for the numerical gradient (default 0.001)
            label       -- text inserted into the box name (default: none)
            ndim        -- spatial dimensions per atom (default 3)
            internals   -- True if X are internal coordinates (default False)
            isrestart   -- True when continuing a previous run (default False)
        """
        self.X  = opts.get('X',None)    # 3N coordinates of the atoms
        self.V  = opts.get('V',None)    # 3N velocities of the atoms
        self.DT = opts.get('DT',10.0)   # Time-step
        self.M  = opts.get('M',None)    # 3N masses of each Cartesian coordinate
        self.atomSymbols = opts.get('atomSymbols',[])   # list of atoms
        # The key used to be read as 'func_E ' (trailing blank); keep that
        # spelling as a fallback so existing call sites continue to work.
        self.func_E      = opts.get('func_E',opts.get('func_E ',None))  # energy function
        self.func_EG     = opts.get('func_EG',None)     # energy,gradient function
        self.islogfile   = opts.get('logfile',True)     # produce a logfile
        self.constraints = opts.get('constraints',None)
        self.dxgrad      = opts.get('dxgrad',0.001)     # step for numerical energy differentiation
        # Attributes read by other methods; they must exist on every box.
        self.label       = opts.get('label',None)       # optional name label
        self.ndim        = opts.get('ndim',3)           # coordinates per atom
        self.internals   = opts.get('internals',False)  # X are internal coordinates
        self.isrestart   = opts.get('isrestart',False)  # continuing a previous run
        self.TIME = 0.0    # Current time
        self._g = None     # 3N gradient vector
        self.TrajX = []    # List of positions along trajectory
        self.TrajV = []    # List of velocities along trajectory
        self.KinE = []     # List of kinetic energies along trajectory
        self.PotE = []     # List of potential energies along trajectory
        self.TotE = []     # List of total energies along trajectory
        self._has_ob = False # set once an observable object has been added
        self._ob = []        # observable objects (see add_observable_cntrl)
        self._ID = uuid.uuid4().hex[0:IDLEN] # unique ID
        self._CTIME = datetime.datetime.today() # creation date/time
        self._BNAME = 'sb_%4i%02i%02i_' % (self._CTIME.year,self._CTIME.month,self._CTIME.day)
        if self.label: self._BNAME = self._BNAME + self.label + '_'
        self._NAME = self._BNAME + self._ID
        self._debug = False # If true provide output to screen
        self._stepnum = 0 # Classical step number
[docs]
    def integrate(self,**opts):
        """
        Integrate the equations of motion up to tfinal
        Optional arguments:
            - dt: time step; if given it replaces self.DT
            - tfinal: final time (default: one step beyond self.TIME)
        If no func_EG is set, a numerical gradient built from func_E is
        installed. Each step advances the step counter and self.TIME by
        self.DT, then calls _updateLog() and _checkpoint().
        """
        self.DT = opts.get('dt',self.DT)
        tfinal = opts.get('tfinal',self.TIME+self.DT)
        if self.func_EG is None:        # fallback to numerical gradient
            self.func_EG = self._getNumGrad()
        if self.TIME >= tfinal: return
        # The integrator needs the gradient at the current geometry: compute
        # it at the start of a trajectory and whenever none is stored.
        if len(self.TrajX) == 0 or self._g is None:
            self._e,self._g = self.func_EG(self.X)
        while self.TIME < tfinal:
            self._step(integrator='VelocityVerlet')
            self._stepnum += 1
            self.TIME += self.DT
            # NOTE(review): _updateLog is not defined in this file (its
            # definition is commented out) - confirm it is provided elsewhere.
            self._updateLog()
            self._checkpoint()
[docs]
    def add_observable_cntrl(self,ob):
        """
        Add an observable object
        OB objects provide observable properties for the molecule
        """
        if not self._has_ob or self.isrestart:
            self._has_ob = True
            # init the observable objects
            # assume to contain several single observable objects
            self._ob = []
        ob.simbox = self # "
        self._ob.append(ob) # cross link objects
        if not self.isrestart:
            self.TrajXeh = []  # List of Observable. [position of the electronic hole]
            self.TrajTAS = []  # List of Observable. [transient absorption spectra]
            self.TrajCharge = [] # List of Observable. [positive charge on each atom] 
[docs]
    def getReducedMass(self):
        """
        Return the reduced mass of the molecule
        """
        natoms = len(self.X) / self.ndim
        masses = self.M.copy()
        masses.shape = (natoms,self.ndim)
        _rm_s = 0.0
        for i in range(natoms):
            _rm_s += 1.0 / masses[i,0]
        reduced_mass = 1.0 / _rm_s
        return reduced_mass 
[docs]
    def removeCenterMassMotion(self):
        """
        Remove center of mass translational motion
        """
        if self.internals: return # do nothing for internal coordinates
        ncoords = self.X.size
        nat = ncoords/self.ndim
        mv = self.M * self.V  # 3N momenta from velocities
        mv.shape = (nat,self.ndim)
        ptot = self.getMomentum()
        mv = mv - ptot/nat # remove total momentum
        mv.shape = (ncoords,)
        self.V = mv / self.M # velocities from rescaled momenta) 
[docs]
    def getMomentum(self,**opts):
        """
        Return total momentum vector (in [au])
        Optional arguments:
        - indices: tuple with indices of a group of atoms
        - stepnum: step number to consider
        - time: time to consider
            ** use time OR stepnum
        """
        indices = opts.get('indices',None)
        stepnum = opts.get('stepnum',None)
        time = opts.get('time',None)
        ncoords = self.X.size
        if time is not None:
            stepnum = self._indexfromtime(time)
        if stepnum is None:
            v = self.V
        else:
            v = self.TrajV[stepnum]
        mv = self.M * v
        mv.shape = (ncoords/self.ndim,self.ndim)
        if indices is None:
            ptot = mv.sum(axis=0)
        else:
            ptot = np.zeros(self.ndim,float)
            for i in indices:
                ptot = ptot + mv[i,:]
        return ptot 
[docs]
    def getKineticEnergy(self,**opts):
        """
        Total kinetic energy of the system of particles
        Optional arguments:
        - stepnum: step number to consider
        - time: time to consider
            ** use time OR stepnum
        """
        stepnum = opts.get('stepnum',None)
        time = opts.get('time',None)
        if time is not None:
            stepnum = self._indexfromtime(time)
        if stepnum is None:
            v = self.V
        else:
            v = self.TrajV[stepnum]
        return 0.5 * np.dot(self.M*v,v) 
[docs]
    def getTranslationalEnergy(self,**opts):
        """
        Translational energy of the center of mass of the system of particles
        Optional arguments:
        - stepnum: step number to consider (uses self.TrajV[stepnum])
        - time: time to consider (converted to the nearest step number)
            ** use time OR stepnum
        Without stepnum/time the current velocities self.V are used.
        The value is computed by Tools.etrans from per-atom velocities and
        per-atom masses.
        """
        stepnum = opts.get('stepnum',None)
        time = opts.get('time',None)
        if time is not None:
            stepnum = self._indexfromtime(time)
        if stepnum is None:
            v = self.V.copy()
        else:
            v = self.TrajV[stepnum].copy()
        m = self.M.copy()
        # one row per atom; '//' keeps the shape integral on Python 3
        v.shape=(len(v)//self.ndim,self.ndim)
        m.shape=(len(m)//self.ndim,self.ndim)
        return ntool.etrans(v,m[:,0],ndim=self.ndim)
[docs]
    def getRotationalEnergy(self,**opts):
        """
        Rotational energy around the center of mass of the system of particles
        Optional arguments:
        - stepnum: step number to consider (uses self.TrajX/self.TrajV)
        - time: time to consider (converted to the nearest step number)
            ** use time OR stepnum
        Without stepnum/time the current self.X and self.V are used.
        The value is computed by Tools.erot from per-atom positions,
        velocities and masses.
        """
        stepnum = opts.get('stepnum',None)
        time = opts.get('time',None)
        if time is not None:
            stepnum = self._indexfromtime(time)
        if stepnum is None:
            x = self.X.copy()
            v = self.V.copy()
        else:
            x = self.TrajX[stepnum].copy()
            v = self.TrajV[stepnum].copy()
        m = self.M.copy()
        # one row per atom; '//' keeps the shape integral on Python 3
        x.shape=(len(x)//self.ndim,self.ndim)
        v.shape=(len(v)//self.ndim,self.ndim)
        m.shape=(len(m)//self.ndim,self.ndim)
        return ntool.erot(x,v,m[:,0])
[docs]
    def getVibrationalEnergy(self,**opts):
        """
        Internal energy of the system of particles
        Optional arguments:
        - stepnum: step number to consider (uses self.TrajX/self.TrajV)
        - time: time to consider (converted to the nearest step number)
            ** use time OR stepnum
        Without stepnum/time the current self.X and self.V are used.
        The value is computed by Tools.evib from per-atom positions,
        velocities and masses.
        """
        stepnum = opts.get('stepnum',None)
        time = opts.get('time',None)
        if time is not None:
            stepnum = self._indexfromtime(time)
        if stepnum is None:
            x = self.X.copy()
            v = self.V.copy()
        else:
            x = self.TrajX[stepnum].copy()
            v = self.TrajV[stepnum].copy()
        m = self.M.copy()
        # one row per atom; '//' keeps the shape integral on Python 3
        m.shape=(len(m)//self.ndim,self.ndim)
        x.shape=(len(x)//self.ndim,self.ndim)
        v.shape=(len(v)//self.ndim,self.ndim)
        return ntool.evib(x,v,m[:,0],ndim=self.ndim)
[docs]
    def getVibrationalMomentum(self,**opts):
        """
        Internal momentum of the system of particles
        Optional arguments:
        - stepnum: step number to consider (uses self.TrajX/self.TrajV)
        - time: time to consider (converted to the nearest step number)
            ** use time OR stepnum
        Without stepnum/time the current self.X and self.V are used.
        The value is computed by Tools.pvib from per-atom positions,
        velocities and masses.
        """
        stepnum = opts.get('stepnum',None)
        time = opts.get('time',None)
        if time is not None:
            stepnum = self._indexfromtime(time)
        if stepnum is None:
            x = self.X.copy()
            v = self.V.copy()
        else:
            x = self.TrajX[stepnum].copy()
            v = self.TrajV[stepnum].copy()
        m = self.M.copy()
        # one row per atom; '//' keeps the shape integral on Python 3
        m.shape=(len(m)//self.ndim,self.ndim)
        x.shape=(len(x)//self.ndim,self.ndim)
        v.shape=(len(v)//self.ndim,self.ndim)
        return ntool.pvib(x,v,m[:,0],ndim=self.ndim)
[docs]
    def getPotEnergy(self,**opts):
        """
        Return potential energy (last time or stored) of the system (a.u.)
        Optional arguments:
        - stepnum: step number to consider
        - time: time to consider
            ** use time OR stepnum
        """
        stepnum = opts.get('stepnum',None)
        time = opts.get('time',None)
        if time is not None:
            stepnum = self._indexfromtime(time)
        if stepnum is None:
            e = self.PotE[-1]
        else:
            e = self.PotE[stepnum]
        return e 
[docs]
    def centerOfMassVelocity(self,**opts):
        """
        Compute the velocity vector of the center of mass
        Optional arguments:
        - stepnum: step number to consider
        - time: time to consider
            ** use time OR stepnum
        """
        if self.ndim != 3:
            raise ValueError('Function centerOfMassVelocity is currently for three spatial dimensions')
        stepnum = opts.get('stepnum',None)
        time = opts.get('time',None)
        if time is not None:
            stepnum = self._indexfromtime(time)
        if stepnum is None:
            vv = self.V
        else:
            vv = self.TrajV[stepnum]
        totM = self.M.sum()/float(self.ndim)
        v = vv.copy()
        m = self.M.copy()
        v.shape = (len(v)/self.ndim,self.ndim)
        m.shape = (len(m)/self.ndim,self.ndim)
        px = np.dot(v[:,0],m[:,0])
        py = np.dot(v[:,1],m[:,1])
        pz = np.dot(v[:,2],m[:,2])
        vcm = np.array([px,py,pz])/totM
        return vcm 
[docs]
    def save(self,**opts):
        """
        Save object using pickle
        optional arguments:
            - fname: filename to save to. If not given a random
                     hash will be used
        """
        # Process arguments
        fname = opts.get('fname',self._NAME)
        # Undefine functions, these cannot be saved!
        fC = self.constraints
        fE  = self.func_E
        fEG = self.func_EG
        self.constraints = None
        self.func_EG = None
        # Start save
        s = shelve.open(fname,protocol=2)
        s['sbox'] = self
        s.close()
        self.constraints = fC
        self.func_E  = fE
        self.func_EG = fEG 
[docs]
    def load(self,a_fname,**opts):
        """
        Load a simulationbox object from file
        """
        try:
            s = shelve.open(a_fname,protocol=2)
        except:
            msg = 'File '+a_fname+' could not be accessed for reading'
            raise ValueError(msg)
        b = s['sbox']
        return b 
[docs]
    def loadDirectory(self,a_dir,match=None,**opts):
        """
        Load all simulationbox objects stored in a directory
        Input arguments:
            a_dir -- directory where to load simbox objects from
            match -- (optional) load only files that match "match"
        """
        sboxes = []
        try:
            listFiles = os.listdir(a_dir)
        except:
            raise ValueError('Could not open directory '+a_dir)
        for filename in listFiles:
            pathname = os.path.join(a_dir,filename)
            if os.path.isdir(pathname):
                sboxes = sboxes + self.loadDirectory(pathname,match=match)
            else:
                if match is not None:
                    if filename.find(match) == -1:
                        continue
                try:
                    sbox = self.load(pathname)
                    sboxes.append(sbox)
                except:
                    continue
        return sboxes 
[docs]
    def traj2VTF(self,**opts):
        """
        Generate a trajectory in .vtf format
        Useful for plots with VMD
        """
        if self.internals: return
        fileName = opts.get('fileName',self._NAME+'.vtf')
        out = file(fileName,'w')
        ncoords = self.X.size
        nat = ncoords/self.ndim
        for i,s in enumerate(self.atomSymbols):
            out.write( 'atom %i   radius 1.0 name %s\n' % (i,s) )
        out.write('\n')
        for geom in self.TrajX:
            out.write('timestep\n')
            geom.shape = (nat,self.ndim)
            for atom in geom:
                for coord in atom:
                    out.write(' %12.7f' % (coord*conv.au2an))
                out.write('\n')        
            out.write('\n')
            geom.shape = (ncoords,) 
[docs]
    def getPropertyTable(self,funcX=None,filename=None,usevel=False,**opts):
        """
        Return property as a function of time
        Works on a single SimulationBox instance
        Input arguments:
            funcX -- (optional) function of the coordinates
                     may have more than one return value
            filename -- (optional) file where to store the data in table form
            usevel -- (optional) if True, velocities are
                          used instead of positions
        on return the following arrays:
            time, ekin, etrans, erot, evib, epot
        or
            time, func_val
        if funcX provided
        Options
            finalTime -- final time for property analysis over patial time domain
        """
        finalTime = opts.get('finalTime',None)
        if finalTime is not None:
            # analyse over partial time domain terminating at [finalTime]
            if sbox.TIME < finalTime:
                return None
            else:
                times = np.linspace(0.0,finalTime,num=int(finalTime/sbox.DT)+1)
        else:
            # analyse over full time domain
            times = np.linspace(0.0,self.TIME,num=len(self.TrajX))
        tlen = len(times)
        if funcX is not None:
            if not usevel:
                fval = np.array([funcX(x) for x in self.TrajX[:tlen]])
            else:
                fval = np.array([funcX(x) for x in self.TrajV[:tlen]])
            if filename is not None:
                fout = file(filename,'w')
                fout.write('# t[as]  func( X(t) )\n')
                for i,t in enumerate(times):
                    fout.write('%13.7f  ' % (t,))
                    if isinstance(fval[i],np.ndarray): # multidimensional funcX
                        for v in fval[i]: fout.write(' %13.7f' % (v,))
                    else:
                        fout.write(' %13.7f' % (fval[i],))
                    fout.write('\n')
                fout.close()
            return times,fval
        else:
            ekin = np.array([self.getKineticEnergy(time=t) for t in times])
            etra = np.array([self.getTranslationalEnergy(time=t) for t in times])
            erot = np.array([self.getRotationalEnergy(time=t) for t in times])
            evib = np.array([self.getVibrationalEnergy(time=t) for t in times])
            epot = np.array([self.getPotEnergy(time=t) for t in times])
            if filename is not None:
                fout = file(filename,'w')
                fout.write('# time[as]  E_kin  E_trans  E_rot  E_vib\n')
                for i,t in enumerate(times):
                    fout.write('%13.7f  %13.7f %13.7f %13.7f %13.7f %13.7f\n' % 
                            (t,ekin[i],etra[i],erot[i],evib[i],epot[i]))
                fout.close()
            return times,ekin,etra,erot,evib,epot 
[docs]
    def getPropertyTableDir_F(self,a_dir,a_funcX,filename=None,match=None,usevel=False,**opts):
        """
        Return an ensemble average property as a function of time
        Input arguments:
            a_dir -- directory with SimulationBox to be averaged
            a_funcX -- function of the coordinates;
                may have more than one return value
            filename -- (optional) file where to store the data in table form
            match -- load only files matching match
            usevel -- (optional) if True, velocities are used instead of positions
        On return two arrays with:
            time, func_val
 
        Options
            finalTime -- final time for property analysis over patial time domain
        """
        finalTime = opts.get('finalTime',None)
        lboxes = self.loadDirectory(a_dir,match=match)
        if finalTime is not None:
            times = np.linspace(0.0,finalTime,num=int(finalTime/lbox[0].DT)+1)
        else:
            times = np.linspace(0.0,lboxes[0].TIME,num=len(lboxes[0].TrajX))
        N = len(lboxes) # number of systems
        S = len(times) # number of time steps
        data = []
        for sb in lboxes:
            if finalTime is not None:
                tab = sb.getPropertyTable(funX=a_funcX,usevel=usevel,finalTime=finalTime)
                if tab is not None:
                    data.append(tab[1])
            else:
                data.append(sb.getPropertyTable(funcX=a_funcX,usevel=usevel)[1])
        adata = np.array(data) # indices of adata: system,timestep,property
        if len(adata.shape) == 2: # single property
            fval = np.array([adata[:,it].sum()/N for it in range(S)])
            if filename is not None:
                fout = file(filename,'w')
                fout.write('# t[as]  func( X(t) )\n')
                for i,t in enumerate(times):
                    fout.write('%13.7f  %13.7f\n' % (t,fval[i]))
                fout.close()
        if len(adata.shape) == 3: # multiple properties
            nprop = adata.shape[2]
            fval = []
            for p in range(nprop):
                fval.append( np.array([adata[:,it,p].sum()/N for it in range(S)]) )
            fval = np.array(fval)
            if filename is not None:
                fout = file(filename,'w')
                fout.write('# t[as]  func( X(t) )\n')
                for i,t in enumerate(times):
                    fout.write('%13.7f  ' % (t,))
                    for p in range(nprop):
                        fout.write(' %13.7f' % (fval[p,i]))
                    fout.write('\n')
                fout.close()
        return times,fval 
[docs]
    def getPropertyTableDir_E(self,a_dir,filename=None,match=None,**opts):
        """
        Return ensemble averages of various energies
        Input arguments:
            a_dir -- directory with SimulationBox to be averaged
            filename -- (optional) file where to store the data in table form
            match -- load only files matching a_match
        On return two arrays with:
            time, func_val
        Options
            finalTime -- final time for property analysis over patial time domain
        """
        finalTime = opts.get('finalTime',None)
        lboxes = self.loadDirectory(a_dir,match=match)
        if finalTime is not None:
            times = np.linspace(0.0,finalTime,num=int(finalTime/lbox[0].DT)+1)
        else:
            times = np.linspace(0.0,lboxes[0].TIME,num=len(lboxes[0].TrajX))
        data = []
        for sb in lboxes:
            data.append(sb.getPropertyTable(finalTime=finalTime))
        adata = np.array(data)
        N = len(adata[:,0,0]) # number of systems
        S = len(adata[0,0,:]) # number of time steps
        ekin = np.array([adata[:,1,it].sum()/N for it in range(S)])
        etra = np.array([adata[:,2,it].sum()/N for it in range(S)])
        erot = np.array([adata[:,3,it].sum()/N for it in range(S)])
        evib = np.array([adata[:,4,it].sum()/N for it in range(S)])
        epot = np.array([adata[:,5,it].sum()/N for it in range(S)])
        if filename is not None:
            fout = file(filename,'w')
            fout.write('# time[as]  E_kin  E_trans  E_rot  E_vib\n')
            for i,t in enumerate(times):
                fout.write('%13.7f  %13.7f %13.7f %13.7f %13.7f %13.7f\n' % 
                        (t,ekin[i],etra[i],erot[i],evib[i],epot[i]))
            fout.close()
        return times,ekin,etra,erot,evib,epot 
[docs]
    def getProperty2D_F(self,a_dir,a_funcX,a_filename,match=None):
        """
        Return a 2D function of the system in block form, one block per time step
        Input arguments:
            a_dir -- directory with SimulationBox to be averaged
            a_filename -- file where to store the data in table form
            match -- load only files matching a_match
        On return two arrays with:
            time, func_val
        """
        lboxes = self.loadDirectory(a_dir,match=match)
        times = np.linspace(0.0,lboxes[0].TIME,num=len(lboxes[0].TrajX))
        N = len(lboxes) # number of systems
        S = len(times) # number of time steps
        data = []
        for sb in lboxes:
            data.append(sb.getPropertyTable(funcX=a_funcX)[1])
        adata = np.array(data) # indices of adata: system,timestep,property
        fout = file(a_filename,'w')
        for i,t in enumerate(times):
            for n in range(N):
                fout.write('%13.7f  %13.7f  %i  %13.7f\n' % (adata[n,i,0],adata[n,i,1],1,t))
            fout.write('\n\n')
        fout.close() 
#   def _updateLog(self,**opts):
#       is_first = opts.get('init',False)
#       # Initialize logfile
#       if is_first and self.islogfile and not self.isrestart:
#           self.logfile = file(self._NAME+'.log','w')
#           self.logfile.write(SEP)
#           self.logfile.write('# '+self._NAME+'\n')
#           self.logfile.write(SEP)
#           self.logfile.write('# step time KinE PotE TotE <state>\n')
#       # Appending logfile for [restart] calculation
#       if is_first and self.islogfile and self.isrestart:
#           self.logfile = file(self._NAME+'.log','a+w')
#           self.logfile.write(SEP)
#           self.logfile.write('# '+self._NAME+'\n')
#           self.logfile.write(SEP)
#           self.logfile.write('# step time KinE PotE TotE <state>\n')
#       kinE = self.getKineticEnergy()
#       potE = self.external_E(self.X)
#       totE = potE + kinE
    def _step(self,**opts):
        """
        Do an integration step
        """
        integrator = opts.get('integrator','VelocityVerlet')
        if integrator == 'VelocityVerlet':
           (self.X,self.V,self._g) = mdi.velocityVerlet(self.X,
                                                        self.V,
                                                        self._g,
                                                        self.M,
                                                        self.DT,
                                                        self.func_EG,
                                                        constraints=self.constraints)
    def _checkpoint(self):
        """
        Write a restart file and refresh the trajectory file.
        The box is stored via save() as <_NAME>.rst, and the .vtf
        trajectory is regenerated via traj2VTF() with its default name.
        """
        restart_name = self._NAME + '.rst'
        self.save(fname=restart_name)
        self.traj2VTF()
##    def _getTimeDeriv(self,**opts):
##change this part
    def _getNumGrad(self,**opts):
        dx = self.dxgrad
        def g(x):
            e = self.func_E(x)
            grad = np.zeros(x.size,float)
            for i in range(x.size):
                x1 = x.copy()
                x1[i] = x1[i] + 0.5*dx
                f1 = self.func_E(x1)
                x0 = x.copy()
                x0[i] = x0[i] - 0.5*dx
                f0 = self.func_E(x0)
                grad[i] = (f1-f0)/dx
            return e,grad
        return g
    def _indexfromtime(self,a_t):
        dt = self.DT
        tf = float(a_t)
        i,r = divmod(tf,dt)
        if r <= 0.5*dt:
            i = int(i)
        else:
            i = int(i) + 1
        return i 
[docs]
def load(filename):
    """
    Load a SimulationBox from file and return it
    Returns the object stored under the key 'sbox' (as written by
    SimulationBox.save).
    """
    # SimulationBox.load uses no instance state, so a bare instance is
    # enough; this skips __init__, which generates a new ID and name.
    return SimulationBox.__new__(SimulationBox).load(filename)