#!/usr/bin/env python
#* **************************************************************************
#*
#* CDTK, Chemical Dynamics Toolkit
#* A modular system for chemical dynamics applications and more
#*
#* Copyright (C) 2011, 2012, 2013, 2014, 2015, 2016
#* Oriol Vendrell, DESY, <oriol.vendrell@desy.de>
#*
#* Copyright (C) 2017, 2018, 2019
#* Ralph Welsch, DESY, <ralph.welsch@desy.de>
#*
#* Copyright (C) 2020, 2021, 2022, 2023
#* Ludger Inhester, DESY, ludger.inhester@cfel.de
#*
#* This file is part of CDTK.
#*
#* CDTK is free software: you can redistribute it and/or modify
#* it under the terms of the GNU General Public License as published by
#* the Free Software Foundation, either version 3 of the License, or
#* (at your option) any later version.
#*
#* This program is distributed in the hope that it will be useful,
#* but WITHOUT ANY WARRANTY; without even the implied warranty of
#* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#* GNU General Public License for more details.
#*
#* You should have received a copy of the GNU General Public License
#* along with this program. If not, see <http://www.gnu.org/licenses/>.
#*
#* **************************************************************************
from builtins import str
from builtins import map
import sys
import os
import shutil
from optparse import OptionParser
import pickle as cpickle
import numpy as np
import CDTK.Dynamics.Trajectory as trj
import CDTK.Dynamics.Trajectory_SH as trjsh
import CDTK.Dynamics.MCElectronDynamics as mced
import CDTK.Tools.Inputfile as inf
import CDTK.Tools.EField as efi
import CDTK.Tools.MolecularMechanics as mm
import CDTK.Interfaces.WrapperInterface as win
import CDTK.Tools.Conversion as cv
import CDTK.Tools.Utils as uti
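# start() is the entry point of this script; it is called from the
# __main__ guard at the end of the file.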
def start():
def pickle_trajSH(traj):
    """
    Pickle a trajectory_SH object into the file 'pickle.dat' (written to
    the current working directory).

    Attributes that can't be pickled (see pickle documentation) are set to
    None while the object is dumped and are put back afterwards. They are
    restored even if dumping raises, so the trajectory stays usable.

    Input:
         traj --- the trajectory_SH to be pickled
    """
    # TODO: move this into a method of the trajectory class
    # attributes of the trajectory that are detached while pickling
    detached = ('f_EGrad', 'f_W', 'f_D', 'f_DD', 'f_NAC',
                'f_overlap', 'q', 'getPartial')
    # read everything first: a missing attribute fails before anything
    # has been modified
    saved = {name: getattr(traj, name) for name in detached}
    #hasConst = traj.const is not None
    es = traj.ES
    es_saved = None
    if es is not None:
        es_saved = (es.QCE, es.field)
    try:
        # remove stuff that can't be pickled
        for name in detached:
            setattr(traj, name, None)
        if es is not None:
            es.QCE = None
            es.field = None
        # pickle (protocol -1 selects the highest available protocol)
        with open('pickle.dat', "wb") as picklefile:
            cpickle.dump(traj, picklefile, -1)
    finally:
        # restore stuff that can't be pickled
        for name, value in saved.items():
            setattr(traj, name, value)
        if es is not None:
            es.QCE, es.field = es_saved
    return
def pickle_traj(traj):
    """
    Pickle a trajectory object into the file 'pickle.dat' (written to the
    current working directory).

    Attributes that can't be pickled (see pickle documentation) are set to
    None while the object is dumped and are put back afterwards. They are
    restored even if dumping raises, so the trajectory stays usable.

    Input:
         traj --- the trajectory to be pickled
    """
    # read everything first: a missing attribute fails before anything
    # has been modified
    tmp_EGrad = traj.f_EGrad
    tmp_getPartial = traj.getPartial
    #hasConst = traj.const is not None
    es = traj.ES
    es_saved = None
    if es is not None:
        es_saved = (es.QCE, es.field)
    try:
        # remove stuff that can't be pickled
        traj.f_EGrad = None
        traj.getPartial = None
        if es is not None:
            es.QCE = None
            es.field = None
        # pickle (protocol -1 selects the highest available protocol)
        with open('pickle.dat', "wb") as picklefile:
            cpickle.dump(traj, picklefile, -1)
    finally:
        # restore stuff that can't be pickled
        traj.f_EGrad = tmp_EGrad
        traj.getPartial = tmp_getPartial
        if es is not None:
            es.QCE, es.field = es_saved
    return
# Fixed settings of the driver
IDLEN = 8          # length passed to uti.getUniqueID for the project id
TINY = 1.0e-8
PREFIX = 'trj'     # leading part of automatically generated project names

# --------------------------------------------------------------------------
# Parse command line options
# --------------------------------------------------------------------------
parser = OptionParser()

# Options that take a string value: (short flag, long flag, dest, help)
for short_flag, long_flag, dest_name, help_text in (
        ('-i', '--input-file', 'input_file',
         'Path to input file'),
        ('-d', '--project-dir', 'project_dir',
         'Project name used to for output directory'),
):
    parser.add_option(short_flag, long_flag,
                      dest=dest_name,
                      type='str',
                      default=None,
                      help=help_text)

# Boolean switches, all off by default: (short flag, long flag, dest, help)
for short_flag, long_flag, dest_name, help_text in (
        ('-l', '--local', 'is_local',
         'Trajectory is run in current (PWD) directory'),
        ('-r', '--restart', 'restart',
         'Whether to restart the calculation.'),
        ('-p', '--pickle', 'doPickle',
         'Save data in binary format using pickle'),
):
    parser.add_option(short_flag, long_flag,
                      dest=dest_name,
                      action='store_true',
                      default=False,
                      help=help_text)

opts, args = parser.parse_args(sys.argv[1:])
restart = opts.restart
# --------------------------------------------------------------------------
# Parse input file
# --------------------------------------------------------------------------
# The input file is the one given with -i/--input-file, otherwise a file
# called 'input_trj' in the current directory.
# basename is used for the automatically generated project directory name
# and stays empty when the default 'input_trj' is picked up.
basename = ''
if opts.input_file:
    input_file = opts.input_file
    # Take the file name without its directories, up to the first dot.
    # Splitting the full path instead would give '' for './x.inp' and would
    # leak path separators into the project directory name for 'dir/x.inp'.
    basename = os.path.basename(input_file).split('.')[0]
elif os.path.exists('input_trj'):
    input_file = 'input_trj'
else:
    raise ValueError('No input file specified')
# I maps section names to the options of that section of the input file
I = inf.Inputfile(input_file).sections
# --------------------------------------------------------------------------
# Set projectname
# --------------------------------------------------------------------------
# Priority: --local (current directory), then --project-dir, then a
# generated name of the form <PREFIX>_<basename>_<date>_<unique id>.
if opts.is_local:
    projdir = '.'
elif opts.project_dir:
    projdir = opts.project_dir
else:
    unique_tag = uti.getUniqueID(IDLEN)
    date_tag = uti.getDateString()
    projdir = '_'.join((PREFIX, basename, date_tag, unique_tag))
# --------------------------------------------------------------------------
# Create name directory, copy input file and change to it
# --------------------------------------------------------------------------
if not os.path.isdir(projdir):
    # makedirs also creates missing parent directories (e.g. "-d runs/a")
    os.makedirs(projdir)
if not restart:
    # The input file is stored as <projdir>/input_trj. The copy goes through
    # a temporary name first; presumably this is so that it also works when
    # the input file already is <projdir>/input_trj (copying a file onto
    # itself raises an error).
    tmp_copy = os.path.join(projdir, 'input_trj_tmp')
    shutil.copy(input_file, tmp_copy)
    shutil.move(tmp_copy, os.path.join(projdir, 'input_trj'))
# remember where we came from; the script changes back at the very end
currentDir = os.getcwd()
os.chdir(projdir)
# empty "stop" file: the propagation loop only integrates while it exists
open('stop', 'w').close()
# --------------------------------------------------------------------------
# Read atom list and atomic coordinates/velocities
# --------------------------------------------------------------------------
# For each quantity a file in the project directory takes precedence over
# the corresponding section of the input file.

# - Atom list
if os.path.exists('atomlist'):
    atomlist = uti.readColumnFile('atomlist')
elif 'cartpos' in I:
    atomlist = I['cartpos']['atomlist']
else:
    raise ValueError('atom list (./atomlist) not found')

# - Atomic positions
if os.path.exists('atompos'):
    xcart = np.array(uti.readColumnFile('atompos'), float)
elif 'cartpos' in I:
    xcart = np.array(I['cartpos']['coordinates'], float)
    # only coordinates from the input file are rescaled when $system|xunit
    # is 'an' (presumably Angstrom -> atomic units; see cv.an2au)
    if I['system']['xunit'][0] == 'an':
        xcart = xcart * cv.an2au
else:
    raise ValueError('atomic positions (./atompos) not found')

# - Atomic velocities
if os.path.exists('atomvel'):
    vcart = np.array(uti.readColumnFile('atomvel'), float)
elif 'cartvel' in I:
    vcart = np.array(I['cartvel']['coordinates'], float)
else:
    # Start at rest. The zeros take the shape of the coordinates so that
    # positions and velocities have the same number of entries whether the
    # coordinates are stored flat or as one row per atom.
    vcart = np.zeros(xcart.shape, float)

nat = len(atomlist)

# Per-atom data looked up in the periodic table; atommass3 repeats every
# mass three times (one entry per Cartesian component).
atomnums = [cv.periodicTable[atom]['atomic_number'] for atom in atomlist]
atommass = [cv.periodicTable[atom]['atomic_mass'] for atom in atomlist]
atommass3 = [mass for mass in atommass for _ in range(3)]

uti.listToFile(atomlist, 'atomlist')
uti.listToFile(atomnums, 'atomnums')
uti.listToFile(atommass, 'atommass')
# --------------------------------------------------------------------------
# (MM) Read atom list and atomic coordinates/velocities (Optional)
# --------------------------------------------------------------------------
# A molecular mechanics region is set up only if the input file has an
# $mmregion section; otherwise MM is None.
if 'mmregion' in I:
    # - MM time step, falling back to the trajectory time step
    if 'dt' in I['mmregion']:
        mm_dt = cv.numval(I['mmregion']['dt'], 'time')
    else:
        mm_dt = cv.numval(I['trajectory']['dt'], 'time')

    # - MM mode
    if 'mode' in I['mmregion']:
        mm_mode = I['mmregion']['mode'][0]
    else:
        mm_mode = 'electron'  # default value

    # - MM atom list and atomic positions (both from $cartpos_mm)
    if 'cartpos_mm' in I:
        mm_atomlist = I['cartpos_mm']['atomlist']
        mm_xcart = np.array(I['cartpos_mm']['coordinates'], float)
        if I['system']['xunit'][0] == 'an':
            mm_xcart = mm_xcart * cv.an2au
    else:
        mm_atomlist = []
        mm_xcart = []

    # - MM atomic velocities
    if 'cartvel_mm' in I:
        mm_vcart = np.array(I['cartvel_mm']['coordinates'], float)
    else:
        # at rest, with the same shape as the MM coordinates
        mm_vcart = np.zeros(np.shape(mm_xcart), float)

    # - MM atom masses; mm_atommass3 repeats every mass three times
    mm_atommass = [cv.periodicTable[atom]['atomic_mass'] for atom in mm_atomlist]
    mm_atommass3 = [mass for mass in mm_atommass for _ in range(3)]

    # Molecular mechanics instance
    MM = mm.MolecularMechanics(dt=mm_dt, mode=mm_mode)
    MM.R = mm_xcart
    MM.V = mm_vcart
    MM.atomlist = mm_atomlist
    MM.atommasses = mm_atommass
    MM.atommasses3 = mm_atommass3

    # QM region properties
    MM.natoms_qm = len(atomlist)
    MM.qm_R = xcart
    MM.qm_V = vcart
    MM.qm_atomlist = atomlist
    MM.qm_M = atommass3

    # Point charge mass and atom number constructed in QCE
    MM.mm_to_point_charge(qce=I['system']['qchemistry'][0], mode=mm_mode)

    uti.listToFile(MM.atomlist, 'MM_atomlist')
    uti.listToFile(MM.atommasses, 'MM_atommass')
    uti.listToFile(MM.x_C, 'MM_X_charge')
    uti.listToFile(MM.x_M, 'MM_X_atommass')
else:
    MM = None
# --------------------------------------------------------------------------
# Create interface to quantum chemistry engine (QCE)
# --------------------------------------------------------------------------
# The mere presence of these keys in $system switches the feature on.
is_keepinp = 'keepinp' in I['system']
is_keeplog = 'keeplog' in I['system']

# Exactly one engine is selected by $system|qchemistry. The interface
# modules are imported lazily so that only the selected one is loaded.
qchemistry = I['system']['qchemistry']
if qchemistry == ['gamess']:
    import CDTK.Interfaces.GamessUSInterface as gin
    QCE = gin.gamess()
    QCE.is_storeinp = is_keepinp  # Keep inp file of every calculation performed
    QCE.is_storelog = is_keeplog  # Keep log file of every calculation performed
    QCE.atomicSymbols = uti.changeIsotopeSymbols(atomlist)
    QCE.atomicNumbers = atomnums
    IG = I.get('gamess', None)  # IG -> $gamess input section, if any
    if IG:
        QCE.init_input_options(IG)
elif qchemistry == ['molcas']:
    import CDTK.Interfaces.MolcasInterface as mol
    QCE = mol.molcas()
    QCE.is_storeinp = is_keepinp  # Keep inp file of every calculation performed
    QCE.is_storelog = is_keeplog  # Keep log file of every calculation performed
    QCE.atomicSymbols = uti.changeIsotopeSymbols(atomlist)
    QCE.atomicNumbers = atomnums
    IM = I.get('molcas', None)  # IM -> $molcas input section, if any
    if IM:
        QCE.init_input_options(IM)
elif qchemistry == ['xmolecule']:
    # for XMolecule as a quantum chemistry tool
    import CDTK.Interfaces.XMoleculeInterface as xmim
    QCE = xmim.xmolecule()
    QCE.is_storeinp = is_keepinp  # Keep inp file of every calculation performed
    QCE.is_storelog = is_keeplog  # Keep log file of every calculation performed
    QCE.atomicSymbols = uti.changeIsotopeSymbols(atomlist)
    QCE.atomicNumbers = atomnums
    QCE.atomicMasses = np.array(atommass3, float) * cv.am2au
    IX = I.get('xmolecule', {})  # IX -> $xmolecule input section, if any
    # also give position to prepare the initialization
    if MM is not None:
        # Pass point charges as additional arguments: the point charges are
        # appended to the atom data, and the engine is initialised with the
        # first nat atoms plus the point charge positions and charges.
        xcart = np.append(xcart, MM.x_R)
        vcart = np.append(vcart, MM.x_V)
        atomlist = np.append(np.asarray(atomlist), MM.x_L)
        atommass = np.append(atommass, MM.x_M)
        atommass3 = np.append(atommass3, MM.x_M3)
        QCE.init_input_options(IX, atomlist[:nat], xcart[:3*nat],
                               x_pos=MM.x_R, x_charges=MM.x_C)
    else:
        if restart:
            # Initialise with the geometry of the last record in R.log.
            # The first column is skipped (presumably the time stamp).
            # Blank lines are ignored, so a trailing empty line no longer
            # yields an empty geometry. If the file holds no record, the
            # coordinates read from the input are kept.
            last_record = None
            with open('R.log', 'r') as posfile:
                for line in posfile:
                    fields = line.split()
                    if fields:
                        last_record = fields
            if last_record is not None:
                xcart = np.array([float(val) for val in last_record[1:]])
        QCE.init_input_options(IX, atomlist, xcart)
elif qchemistry == ['interpolation']:
    import CDTK.Interfaces.Interpolation_Interface as inter
    QCE = inter.interpolation()
    QCE.atomicSymbols = uti.changeIsotopeSymbols(atomlist)
    QCE.atomicNumbers = atomnums
    II = I.get('interpolation', None)  # II -> $interpolation input section, if any
    if II:
        QCE.init_input_options(II)
elif qchemistry == ['embedding']:
    import CDTK.Interfaces.EmbeddingInterface as emb
    IM = I.get('embedding', None)  # IM -> $embedding input section
    if IM is None:
        raise ValueError('$embedding section needs to be specified')
    # hand the input section of the embedded quantum engine
    # ($embedding|quantum names that section) to the interface
    IM['quantum_options'] = I.get(IM['quantum'][0])
    QCE = emb.embedding(atomlist, atomnums, atommass3, xcart, **IM)
else:
    # without this, an unknown engine only shows up later as a NameError
    raise ValueError('unknown value for $system|qchemistry: ' + str(qchemistry))
# --------------------------------------------------------------------------
# Add electric field, if any
# --------------------------------------------------------------------------
# An $efield section in the input creates one EField object, configured
# from that section and handed to the quantum chemistry engine as a
# single-element list.
field_opts = I.get('efield', None)
if field_opts is not None:
    EF = efi.EField()
    EF.init_from_input_options(field_opts)
    QCE.efields = [EF]
# --------------------------------------------------------------------------
# Attach quantum chemistry engine to general wrapper interface
# --------------------------------------------------------------------------
W = win.wrapperInterface()
W.QCE = QCE
# $quantum|forces selects how forces are obtained; 'analytical' by default
forces = I['quantum'].get('forces', None)
W.forces = forces[0] if forces else 'analytical'
# --------------------------------------------------------------------------
# Initialize trajectory object
# --------------------------------------------------------------------------
# NOTE(review): block indentation is missing in this copy of the file; the
# statements of this section are kept verbatim and only comments were added.
# Logging interval $trajectory|logn (default 1), never smaller than 1
logn = max(int( I['trajectory'].get('logn', [1])[0]),1)
# Time step and final time, converted with cv.numval(..., 'time')
dt = cv.numval( I['trajectory']['dt'] , 'time' )
tf = cv.numval( I['trajectory']['tf'] , 'time' )
# dt_max starts out as the time step itself
dt_max = dt
field_explicit = I['trajectory'].get('field_explicit', None)
# threshold for energy conservation (None when $trajectory|de_thresh is absent)
if 'de_thresh' in I['trajectory']:
dE_thresh = float( I['trajectory']['de_thresh'][0] )
else:
dE_thresh = None
# NOTE(review): unlike de_thresh, p_thresh and noreturn are passed on exactly
# as stored in the input section, with no numeric conversion here -- confirm
# that the trajectory class handles the type it receives.
p_thresh = I['trajectory'].get('p_thresh', [1.0])[0]
noreturn = I['trajectory'].get('noreturn', [None])[0]
# Random seed: $system|rseed as int, otherwise a random non-negative int32.
# NOTE(review): the bare except also hides unrelated errors; it is not clear
# from this copy whether the print belongs to the except branch or runs
# unconditionally.
try:
rseed = int( I['system']['rseed'][0] )
except:
rseed = np.random.randint(np.iinfo(np.int32).max)
print("# Random seed initialized with rseed=" + str(rseed))
# rcom / rrot are True only if the respective $trajectory entry is 'yes'
# (compared case-insensitively); any failure to read the entry gives False.
try:
rcom = I['trajectory']['rcom'][0].lower() == 'yes'.lower()
except:
rcom = False
try:
rrot = I['trajectory']['rrot'][0].lower() == 'yes'.lower()
except:
rrot = False
# check for electron dynamics input: an $mced section enables it
# NOTE(review): block indentation is missing in this copy of the file; the
# statements of this section are kept verbatim and only comments were added.
mc_electron_dynamics = 'mced' in I
# ES stays None when no electron dynamics is requested
ES = None
if mc_electron_dynamics:
# $mced|event == 'read' (case-insensitive): create a state marked
# 'doNotRun' and load the events from event.log; every line with more than
# one whitespace-separated field is stored as a tuple in ES.log['event']
if I['mced'].get('event', ['none'])[0].lower() == 'read':
ES = mced.ElectronicState(state='doNotRun', dte=0.0, ratio=0.0, pe=0.0, stop=False)
with open('event.log','r') as eventfile:
for line in eventfile:
if len(line.split()) > 1:
ES.log['event'].append(tuple(line.split()))
else:
# otherwise the state is built from $mced options and the occupation
# reported by the quantum chemistry engine
dte = cv.numval( I['mced']['dte'] , 'time' )
state = QCE.get_occ()
# NOTE(review): ratio and stop are passed on as stored in the input section
# (defaults 0.1 and False), without conversion -- confirm expected types
ratio = I['mced'].get('ratio', [0.1])[0]
stop = I['mced'].get('stop', [False])[0]
# pe = QCE.pe
# optional $mced|pdelay, converted with cv.numval(..., 'time')
if I['mced'].get('pdelay',False):
pdelay = cv.numval( I['mced']['pdelay'] , 'time' )
else:
pdelay=None
ES = mced.ElectronicState(state=state, dte=dte, ratio=ratio, stop=stop, pdelay=pdelay)
# NOTE(review): it cannot be told from this copy whether the next three
# statements belong to the else branch only or run for both branches.
# ES.QCE and ES.field are the attributes the pickle helpers detach.
ES.QCE = QCE
ES.field = QCE.efields
ES.initLogs()
# Create and configure the trajectory object according to $quantum|type:
#   'gs'   -> trj.Trajectory using W.f_EGrad_GS
#   'bo'   -> trj.Trajectory using W.f_EGrad_ES on state $quantum|istate
#   'fssh' -> trjsh.Trajectory_SH with nstates states
# Any other value ends the program through sys.exit().
# NOTE(review): block indentation is missing in this copy of the file, so the
# nesting of the statements cannot be read off the text; the code lines are
# reproduced verbatim and only comments were added.
if I['quantum']['type'][0] == 'gs':
# 3*nat degrees of freedom
trajectory = trj.Trajectory(3*nat)
trajectory.rseed = rseed
trajectory.DIR = os.getcwd()
# positions/velocities are handed over flattened; masses are scaled by cv.am2au
trajectory.R = xcart.flatten()
trajectory.V = vcart.flatten()
trajectory.M = np.array(atommass3,float)*cv.am2au
trajectory.dt = dt
trajectory.f_EGrad = W.f_EGrad_GS
trajectory.rcom = rcom
trajectory.rrot = rrot
trajectory.dE_thresh = dE_thresh
# only this branch passes the MM region and the atom list to the trajectory
trajectory.MM = MM
trajectory.atomlist = atomlist
trajectory.logn = logn
# NOTE(review): bool() of a non-empty string is always True, so a textual
# "false" would still enable log_MOE -- confirm the value type delivered by
# the input parser.
if 'log_moe' in I['trajectory']:
trajectory.f_MOE = W.f_MOE
trajectory.log_MOE = bool( I['trajectory']['log_moe'][0] )
# attach the electron-dynamics state prepared earlier, if requested
if mc_electron_dynamics:
ES.DIR = trajectory.DIR
trajectory.ES = ES
# with --restart the trajectory's own restart() method is invoked
if restart:
trajectory.restart()
elif I['quantum']['type'][0] == 'bo':
trajectory = trj.Trajectory(3*nat)
trajectory.DIR = os.getcwd()
trajectory.rseed = rseed
trajectory.R = xcart.flatten()
trajectory.V = vcart.flatten()
trajectory.M = np.array(atommass3,float)*cv.am2au
trajectory.dt = dt
trajectory.f_EGrad = W.f_EGrad_ES
# electronic state the trajectory runs on
trajectory.S = int( I['quantum']['istate'][0] )
trajectory.rcom = rcom
trajectory.rrot = rrot
trajectory.dE_thresh = dE_thresh
trajectory.logn = logn
if mc_electron_dynamics:
ES.DIR = trajectory.DIR
trajectory.ES = ES
if restart:
trajectory.restart()
elif I['quantum']['type'][0] == 'fssh':
nstates = int( I['quantum']['nstates'][0] )
istate = int( I['quantum']['istate'][0] )
QCE.nstates = nstates
trajectory = trjsh.Trajectory_SH(3*nat,nstates)
trajectory.rseed = rseed
trajectory.S = istate
trajectory.logn = logn
# Initial coefficients: $quantum|icoeffs is a comma-separated list with one
# entry per state. If exactly one entry is nonzero it has to be the one of
# istate. The list is normalised before it is stored in trajectory.C.
# Without icoeffs, the coefficient of istate is 1 and all others are 0.
if 'icoeffs' in I['quantum']:
icoeffs = list(map(float, I['quantum']['icoeffs'][0].split(',')))
if len( icoeffs ) != nstates:
raise ValueError('icoeffs, if specified, must be given for all states')
if len(np.nonzero(icoeffs)[0]) == 1 and np.nonzero(icoeffs)[0] != istate:
raise ValueError('Pure state specified by icoeffs does not match istate')
icoeffs = [ i/ np.sqrt(np.sum(np.abs(icoeffs)**2)) for i in icoeffs]
trajectory.C[:] = icoeffs[:]
else:
trajectory.C[:] = 0
trajectory.C[istate] = 1
# NOTE(review): no_hops is passed on as stored in the input section
if 'no_hops' in I['quantum']:
trajectory.no_hops = I['quantum']['no_hops'][0]
trajectory.DIR = os.getcwd()
trajectory.R = xcart.flatten()
trajectory.V = vcart.flatten()
trajectory.M = np.array(atommass3,float)*cv.am2au
trajectory.dt = dt
trajectory.dt_max = dt_max
trajectory.p_thresh = p_thresh
trajectory.noreturn = noreturn
trajectory.field_explicit = field_explicit
trajectory.dE_thresh = dE_thresh
# callbacks provided by the wrapper interface
trajectory.f_EGrad = W.f_EGrad_NA
trajectory.f_W = W.f_W
trajectory.f_D = W.f_D
trajectory.f_overlap = None
# $quantum|trivialcrossing: 'overlap' is the only value handled here; it
# wires in W.f_overlap and switches on overlap calculation in the engine.
# The ValueError below reports any other value.
if 'trivialcrossing' in I['quantum']:
trajectory.trivialCrossing = I['quantum']['trivialcrossing'][0]
if I['quantum']['trivialcrossing'][0] == 'overlap':
trajectory.f_overlap = W.f_overlap
W.QCE.calcOverlaps = True
else:
raise ValueError('unknown value for trivialCrossing {:s}'.format(I['quantum']['trivialcrossing'][0]))
trajectory.rcom = rcom
trajectory.rrot = rrot
# dE_thresh and logn were already assigned above in this branch
trajectory.dE_thresh = dE_thresh
trajectory.logn = logn
if mc_electron_dynamics:
ES.DIR = trajectory.DIR
trajectory.ES = ES
if restart:
trajectory.restart()
# Rescaling mode: $quantum|rescaling of 'nac' (which also wires in W.f_NAC)
# or 'grd'. Presumably the engine-dependent choices further down are the
# defaults used when no rescaling is given: 'GamessUS' -> 'nac',
# 'molcas' and 'interpolation' -> 'grd', any other engine -> 'nac'.
rescaling = I['quantum'].get('rescaling',None)
if rescaling:
if rescaling[0] == 'nac':
trajectory.rescaling = 'nac'
trajectory.f_NAC = W.f_NAC
if rescaling[0] == 'grd':
trajectory.rescaling = 'grd'
else:
if QCE.engine == 'GamessUS':
trajectory.rescaling = 'nac'
trajectory.f_NAC = W.f_NAC
elif QCE.engine == 'molcas':
trajectory.rescaling = 'grd'
elif QCE.engine == 'interpolation':
trajectory.rescaling = 'grd'
else:
trajectory.rescaling = 'nac'
trajectory.f_NAC = W.f_NAC
# Field coupling: f_DD is W.f_Dmu for e_dot_coupling and W.f_Dmu_E for
# e_coupling; presumably both are only considered when $quantum|efield is set
efield = I['quantum'].get('efield',None)
e_dot_coupling = I['quantum'].get('e_dot_coupling',None)
e_coupling = I['quantum'].get('e_coupling',None)
if efield:
if e_dot_coupling:
trajectory.f_DD = W.f_Dmu
elif e_coupling:
trajectory.f_DD = W.f_Dmu_E
# a second entry 'nohop' in $quantum|type sets trajectory.is_sh to False
if len( I['quantum']['type'] ) > 1:
if I['quantum']['type'][1] == 'nohop':
trajectory.is_sh = False
# optional $quantum|first_step, stored as float
first_step = I['quantum'].get('first_step',None)
if first_step:
trajectory.first_step = float(first_step[0])
else:
sys.exit('$quantum|type needs to be specified')
# Partial charges are taken from the engine whenever the input file has an
# $xmolecule section.
# NOTE(review): the test looks at the section, not at $system|qchemistry --
# confirm that this is intended.
if 'xmolecule' in I:
    trajectory.getPartial = QCE.getPartialCharges

# log electric field / pulse on the time grid 0, dt, 2*dt, ...
if field_opts:
    time_grid = np.arange(0., tf + dt, dt)
    EF.writeLog(time_grid, filename="field.log")

# Integrator of the quantum step; 'ipicture' unless the input overrides it
qintegrator = I['quantum'].get('quantum_integrator', None)
trajectory.quantum_step_integrator = qintegrator[0] if qintegrator else 'ipicture'

# The embedding engine may supply a constraints correction
uses_embedding = I['system']['qchemistry'] == ['embedding']
if uses_embedding and QCE.hasConstraints():
    trajectory.set_constraints(QCE.getConstraintsCorrection)

# Start time $trajectory|ti (default 0.0 fs); left untouched on restart
if not restart:
    start_time = I['trajectory'].get('ti', ['0.0', 'fs'])
    trajectory.t = cv.numval(start_time, 'time')
# --------------------------------------------------------------------------
# Run trajectory
# --------------------------------------------------------------------------
# NOTE(review): block indentation is missing in this copy of the file; the
# statements are kept verbatim and only comments were added.
# NOTE(review): `id` shadows the builtin and is not used again in this section
id = trajectory._ID
# Propagate until the final time tf is reached. Presumably a step is taken
# only while the file "stop" exists in the working directory and the loop is
# left through the final break once that file is gone.
while trajectory.t < tf:
sys.stdout.flush()
# print trajectory.t
if os.path.isfile('stop'):
trajectory.integrate() # dt step
# write the logs every logn-th step and when the next step reaches tf
if trajectory._stepnum % trajectory.logn == 0 or \
trajectory.t + trajectory.dt >= tf:
trajectory.log_to_file()
# if pickling enabled (-p/--pickle): dump the trajectory with the helper
# matching the trajectory type
# NOTE(review): it cannot be told from this copy whether this happens after
# every step or only after logged steps.
if opts.doPickle == True:
# if necessary save XMolecule
# if I['system']['qchemistry'] == ['xmolecule']:
# QCE.write_xm_chkpt()
if I['quantum']['type'][0] == 'fssh':
pickle_trajSH(trajectory)
else:
pickle_traj(trajectory)
else:
break
# go back to original directory
os.chdir(currentDir)
if __name__ == "__main__":
    # run the trajectory driver only when executed as a script
    start()