#!/usr/bin/env python
#*  **************************************************************************
#*
#*  CDTK, Chemical Dynamics Toolkit
#*  A modular system for chemical dynamics applications and more
#*
#*  Copyright (C) 2011, 2012, 2013, 2014, 2015, 2016
#*  Oriol Vendrell, DESY, <oriol.vendrell@desy.de>
#*
#*  Copyright (C) 2017, 2018, 2019
#*  Ralph Welsch, DESY, <ralph.welsch@desy.de>
#*
#*  Copyright (C) 2020, 2021, 2022, 2023
#*  Ludger Inhester, DESY, ludger.inhester@cfel.de
#*
#*  This file is part of CDTK.
#*
#*  CDTK is free software: you can redistribute it and/or modify
#*  it under the terms of the GNU General Public License as published by
#*  the Free Software Foundation, either version 3 of the License, or
#*  (at your option) any later version.
#*
#*  This program is distributed in the hope that it will be useful,
#*  but WITHOUT ANY WARRANTY; without even the implied warranty of
#*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#*  GNU General Public License for more details.
#*
#*  You should have received a copy of the GNU General Public License
#*  along with this program.  If not, see <http://www.gnu.org/licenses/>.
#*
#*  **************************************************************************
"""
This module has several functions to transform coordinates and velocities obtained
from a Gromacs calculation into input files suitable for CDTK. Use ``gmx2cdtk -h``
to find out how to use this program.
For the basic units used by Gromacs, see:
https://manual.gromacs.org/documentation/2019/reference-manual/definitions.html#table-basicunits

1 ps = 41341.374575751 atomic units of time
1 nm = 18.8973 bohr
1 nm/ps = 18.8973/41341.374575751 bohr/a.u. = 0.00045710381413114194 bohr/a.u.
1 nm/ps ~ 4.571038E-04 bohr/a.u.
"""
import subprocess
from optparse import OptionParser
import numpy as np
import CDTK.Tools.Conversion as cv
# GROMACS -> atomic-unit conversion factors (see the module docstring).
# Length: 1 nm = 10 Angstrom, converted to bohr (about 18.8973).
nm2bohr = 10 * cv.an2au
# Velocity: nm/ps -> bohr per atomic unit of time, with 1 ps = 1000 fs
# (about 4.571038e-04).
nm_ps2bohr_au = nm2bohr / (cv.fs2au * 1000)
[docs]
def get_parser():
    """
    Build the command-line option parser of ``gmx2cdtk``.

    Returns
    -------
    optparse.OptionParser
        Parser defining the options read by ``start``:
        ``gmxFile`` (-g), ``trrFile`` (-t), ``centeredMolecule`` (-n),
        ``skip`` (-s), ``centerMethod``, ``gmxOut`` (-o), ``trrOut`` (-r)
        and ``doReorder`` (--noreorder).
    """
    parser = OptionParser(usage="usage: %prog -g filename.gro [further options]")
    parser.add_option(
        "-g",
        "--GRO_file",
        dest="gmxFile",
        type="str",
        default=None,
        help="GMX .gro input file.",
    )
    parser.add_option(
        "-t",
        "--trr_file",
        dest="trrFile",
        type="str",
        default=None,
        help="GMX .trr input file (optional).",
    )
    parser.add_option(
        "-n",
        "--cen_mol",
        dest="centeredMolecule",
        type="int",
        default=13,
        help="Please enter the index of the molecule that you want to center. It is an integer number and not exceeding the number of water molecules in the box.",
    )
    parser.add_option(
        "-s",
        "--skip",
        dest="skip",
        type="float",
        default=None,
        help="Skip molecules (residues) with center of mass distance from the center larger than this value. Value is given in atomic units",
    )
    parser.add_option(
        "--centerMethod",
        type="choice",
        dest="centerMethod",
        default="first",
        choices=["first", "com"],
        help="How to center: according to first atom of residue (first) or center of mass of residue (com)",
    )
    parser.add_option(
        "-o",
        type="str",
        dest="gmxOut",
        default=None,
        help="GMX .gro output file (optional)",
    )
    parser.add_option(
        "-r",
        "--output_trr",
        type="str",
        dest="trrOut",
        default=None,
        help="GMX .trr output file (optional)",
    )
    parser.add_option(
        "--noreorder",
        action='store_false',
        dest="doReorder",
        default=True,
        help="do not reorder the molecules according to distance (optional, default: do sort the molecules)",
    )
    return parser
# Atom-name aliases that the atomName2* converters map onto a chemical element.
# The converters strip digits and blanks from a name before looking it up, so
# entries containing digits (e.g. 'HW1', 'HW2') are never matched through them.
oxygenNames = [
    'O', 'OW', 'OD', 'O_', 'OS',
]
hydrogenNames = [
    'H', 'HW', 'HW1', 'HW2', 'HD', 'HC',
]
carbonNames = [
    'C', 'CA', 'CR', 'CW', 'C_', 'CS',
]
[docs]
def atomName2Element(atomName):
    """
    Map an atom name onto its chemical element symbol.

    Blanks and digits are removed from ``atomName`` first (e.g. "HW1" -> "HW").
    The cleaned name is then matched against the oxygen, hydrogen and carbon
    alias lists (in that order), and finally against the keys of
    ``cv.periodicTable``.

    Parameters
    ----------
    atomName : string
        Atom name as found in the topology (e.g. "OW", "HW1").

    Returns
    -------
    element name: string
        "O", "H", "C", or the cleaned name itself if it is a periodic-table key.

    Raises
    ------
    ValueError
        If the cleaned name is neither an alias nor a periodic-table key.
    """
    cleaned = ''.join(ch for ch in atomName.strip()
                      if ch != ' ' and not ch.isnumeric())
    aliasGroups = (("O", oxygenNames), ("H", hydrogenNames), ("C", carbonNames))
    for element, aliases in aliasGroups:
        if cleaned in aliases:
            return element
    if cleaned not in cv.periodicTable.keys():
        raise ValueError("AtomName {:s} not recognized".format(cleaned))
    return cleaned
[docs]
def atomName2atomNum(atomName):
    """
    Convert an atom name into the atomic number of its element.

    The name is mapped onto an element symbol by ``atomName2Element``, so the
    normalization (stripping blanks and digits) and the alias lists live in
    one place and all atomName2* converters agree on the element.

    Parameters
    ----------
    atomName : string
        Atom name as found in the topology (e.g. "OW", "HW1").

    Returns
    -------
    element number: index
        ``cv.periodicTable[element]['atomic_number']``.

    Raises
    ------
    ValueError
        If the atom name is not recognized (raised by atomName2Element).
    """
    return cv.periodicTable[atomName2Element(atomName)]['atomic_number']
[docs]
def atomName2atomMass(atomName):
    """
    Convert an atom name into the atomic mass of its element.

    The name is mapped onto an element symbol by ``atomName2Element``, so the
    normalization (stripping blanks and digits) and the alias lists live in
    one place and all atomName2* converters agree on the element.

    Parameters
    ----------
    atomName : string
        Atom name as found in the topology (e.g. "OW", "HW1").

    Returns
    -------
    element mass: float
        ``cv.periodicTable[element]['atomic_mass']`` (the caller ``start``
        multiplies it by ``cv.am2au`` to obtain atomic units).

    Raises
    ------
    ValueError
        If the atom name is not recognized (raised by atomName2Element).
    """
    return cv.periodicTable[atomName2Element(atomName)]['atomic_mass']
[docs]
def readtrr(xtcfilename, gmxfilename):
    """
    Read the last frame of a GROMACS trajectory into a molec dictionary.

    Uses the MDAnalysis package. MDAnalysis reports lengths in Angstrom and
    velocities in Angstrom/ps; both are converted to atomic units here.

    Parameters
    ----------
    xtcfilename : string
        Trajectory file (e.g. a .trr file) readable by MDAnalysis.
    gmxfilename : string
        Structure/topology file (e.g. the .gro file) matching the trajectory.

    Returns
    -------
    molec : dict of ndarray
        Same keys as returned by ``readgmx``: geom, vel, topol, atomlist,
        atomnums, atommass, boxsize, residueIndex, residueName, atomName.
        If the last frame has no velocities, ``vel`` is filled with NaN
        (``writegmx`` then omits the velocity columns).

    Raises
    ------
    ModuleNotFoundError
        If MDAnalysis is not installed.
    """
    try:
        import MDAnalysis as mda
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError("Please install module MDAnalysis for reading or writing gromacs trr files") from err
    u = mda.Universe(gmxfilename, xtcfilename)
    t = u.trajectory[-1]  # load the last frame; u.dimensions below refers to it
    residueIndex = u.atoms.resids
    residueName = u.atoms.resnames
    atomName = np.array([a.replace('_', '') for a in u.atoms.names])
    topol = np.array([str(ri) + rn.strip() + "_" + aN.strip()
                      for ri, rn, aN in zip(residueIndex, residueName, atomName)])  # topology with cdtk format
    atomnums = np.array(list(map(atomName2atomNum, atomName)))
    atomlist = np.array(list(map(atomName2Element, atomName)))
    atommass = np.array(list(map(atomName2atomMass, atomName)))
    # Angstrom -> nm -> bohr
    v_box = u.dimensions[0:3] / 10. * nm2bohr
    if t.has_velocities:
        # Angstrom/ps -> nm/ps -> bohr per atomic unit of time
        vel = t.velocities / 10. * nm_ps2bohr_au
    else:
        vel = np.full((t.n_atoms, 3), np.nan)
    molec = {
        "geom": t.positions / 10. * nm2bohr,
        "vel": vel,
        "topol": topol,
        "atomlist": atomlist,
        "atomnums": atomnums,
        "atommass": atommass,
        "boxsize": v_box,
        "residueIndex": residueIndex,
        "residueName": residueName,
        "atomName": atomName,
    }
    return molec
[docs]
def writetrr(trrFilename, molec):
    """Write a gromacs .trr file based on the data from molec.

    Uses the MDAnalysis python package. Coordinates, velocities and box size
    are converted from atomic units back to MDAnalysis units (Angstrom,
    Angstrom/ps). Residue indices do not need to be contiguous or start at 1
    (e.g. after ``sort_molec`` dropped residues); each atom is attached to its
    residue via its position among the sorted unique residue indices.

    Args:
        trrFilename (string): filename
        molec (dictionary): data with keys geom, vel, atomName, atomlist,
            atommass, residueIndex, residueName and boxsize

    Raises:
        ModuleNotFoundError: if MDAnalysis is not installed
    """
    try:
        import MDAnalysis as mda
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError("Please install module MDAnalysis for reading or writing gromacs trr files") from err
    n_atoms = molec["geom"].shape[0]
    # resid: sorted unique residue indices; atResIdx: first atom of each residue;
    # atom_resindex: 0-based position of every atom's residue within resid
    resid, atResIdx, atom_resindex = np.unique(
        molec['residueIndex'], return_index=True, return_inverse=True)
    resnames = np.array([a.strip() + b for a, b in zip(molec['residueName'][atResIdx], resid.astype(str))])
    n_residues = resid.shape[0]
    u = mda.Universe.empty(n_atoms,
                           n_residues=n_residues,
                           n_segments=1,
                           atom_resindex=atom_resindex,
                           velocities=True
                           )
    u.add_TopologyAttr('name', [a.strip() for a in molec["atomName"]])
    u.add_TopologyAttr('types', [a.strip() for a in molec["atomlist"]])
    u.add_TopologyAttr('masses', molec['atommass'])
    u.add_TopologyAttr('segid', ['SYSTEM'])
    u.add_TopologyAttr('resname', resnames)
    u.add_TopologyAttr('resid', resid)
    # bohr -> nm -> Angstrom
    u.load_new(molec['geom'] * 10 / nm2bohr, order='fac')
    # bohr/a.u. -> nm/ps -> Angstrom/ps
    u.trajectory[0].velocities = molec['vel'] * 10. / nm_ps2bohr_au
    dimensions = np.zeros(6)
    dimensions[0:3] = molec["boxsize"] * 10. / nm2bohr
    dimensions[3:6] = np.array([90., 90., 90.])  # rectangular box angles
    u.trajectory[0].dimensions = dimensions
    u.atoms.write(trrFilename)
[docs]
def readgmx(filename, check_water=False):
    """
    Read a gromacs .gro file and return the topology, element types,
    coordinates, velocities of all atoms and the box size.

    Parameters
    ----------
    filename : string
        Name of gromacs file.
    check_water : bool, optional
        Currently unused; kept for backward compatibility.

    Returns
    -------
    molec : dict of ndarray.
        Dictionary containing the following arrays:
        - geom : (n_atoms, 3) ndarray of float.
            Coordinates of all atoms (bohr).
        - vel : (n_atoms, 3) ndarray of float.
            Velocities of all atoms (bohr per atomic unit of time).
        - topol : (n_atoms) ndarray of str.
            String identifiers for the topology of the system (atoms in molecules).
        - atomlist : (n_atoms) ndarray of str.
            Element symbol of every atom (e.g., "O", "H").
        - atomnums : (n_atoms) ndarray of int.
            Atomic number of every atom.
        - atommass : (n_atoms) ndarray of float.
            Atomic mass of every atom.
        - boxsize : ndarray of float.
            Box vectors from the last line of the file (bohr); (len_x, len_y,
            len_z) for a rectangular box.
        - residueName : (n_atoms) ndarray of str
            Name of residue for each atom
        - residueIndex : (n_atoms) ndarray of int
            Index of residue for each atom
        - atomName : (n_atoms) ndarray of str
            Name of atoms with '_' removed (e.g. "OW", "HW1")
    """
    # the .gro file format has fixed column positions as delimiters ("%5d%-5s%5s%5d%8.3f%8.3f%8.3f%8.4f%8.4f%8.4f\n")
    delimiter = [5, 5, 5, 5, 8, 8, 8, 8, 8, 8]
    # skip the title and atom-count lines at the top and the box line at the bottom
    gro = dict(skip_header=2, skip_footer=1, delimiter=delimiter)
    # atleast_1d/atleast_2d keep the array rank for a file with a single atom,
    # for which genfromtxt returns 0-d/1-d arrays
    residueIndex = np.atleast_1d(np.genfromtxt(filename, usecols=0, dtype=int, **gro))
    residueName = np.atleast_1d(np.genfromtxt(filename, usecols=1, dtype='<U5', **gro))
    atomName = np.atleast_1d(np.genfromtxt(filename, usecols=2, dtype='<U5', **gro))
    atomName = np.array([a.replace('_', '') for a in atomName])
    topol = np.array([str(ri) + rn.strip() + "_" + aN.strip()
                      for ri, rn, aN in zip(residueIndex, residueName, atomName)])  # topology with cdtk format
    # Get position (nm) and velocities (nm/ps) of all atoms
    posvel = np.atleast_2d(np.genfromtxt(filename, usecols=range(4, 10), dtype=float, **gro))
    xyz = posvel[:, :3] * nm2bohr
    vxyz = posvel[:, 3:] * nm_ps2bohr_au
    # The box vectors are on the last non-empty line of the file. Read it in
    # Python rather than calling the external `tail` program, which is not
    # available everywhere and returns an empty line if the file ends with a
    # blank line.
    last_line = ""
    with open(filename) as f:
        for line in f:
            if line.strip():
                last_line = line
    # NOTE(review): a triclinic box line holds 9 values; downstream code
    # (shift_box) assumes 3 — confirm only rectangular boxes are used.
    v_box = np.array(last_line.split(), dtype=float) * nm2bohr
    # Get atomic numbers, atomic masses, and element names
    atomnums = np.array(list(map(atomName2atomNum, atomName)))
    atommass = np.array(list(map(atomName2atomMass, atomName)))
    atomlist = np.array(list(map(atomName2Element, atomName)))
    molec = {
        "geom": xyz,
        "vel": vxyz,
        "topol": topol,
        "atomlist": atomlist,
        "atomnums": atomnums,
        "atommass": atommass,
        "boxsize": v_box,
        "residueIndex": residueIndex,
        "residueName": residueName,
        "atomName": atomName,
    }
    return molec
[docs]
def writegmx(filename, molec, reIndexResidues=True):
    """
    Write a gromacs .gro file from a molec structure.

    Residue and atom numbers are written modulo 100000 so that they always fit
    the 5-character columns of the .gro format. If any velocity is NaN, the
    velocity columns are omitted for all atoms.

    Parameters
    ----------
    filename : string
        Name of gromacs file.
    molec : dict of ndarray.
        Dictionary containing the following arrays:
        - geom : (n_atoms, 3) ndarray of float.
            Coordinates of all atoms (bohr).
        - vel : (n_atoms, 3) ndarray of float.
            Velocities of all atoms (bohr per atomic unit of time).
        - boxsize : (3) ndarray of float.
            Box size (len_x, len_y, len_z) in bohr.
        - residueName : (n_atoms) ndarray of str
            Name of residue for each atom
        - residueIndex : (n_atoms) ndarray of int
            Index of residue for each atom
        - atomName : (n_atoms) ndarray of str
            Name of atoms (e.g. "OW", "HW1")
    reIndexResidues : bool, optional
        If True (default), residues are renumbered 1, 2, 3, ... in the order
        in which they appear; otherwise molec['residueIndex'] is written.
    """
    geom = molec['geom']
    vel = molec['vel']
    boxsize = molec['boxsize']
    # decide once for the whole file whether velocity columns are written
    writeVelocities = not np.any(np.isnan(vel))
    resN = 0
    # None = no residue seen yet (0 can be a valid, wrapped .gro residue number)
    prevResIdx = None
    with open(filename, 'w') as f:
        f.write(" generated by gmx2cdtk\n")
        f.write(" {:d}\n".format(geom.shape[0]))
        for i in range(geom.shape[0]):
            residx = molec['residueIndex'][i]
            resname = molec['residueName'][i]
            atomname = molec['atomName'][i]
            if prevResIdx is None or prevResIdx != residx:
                resN += 1
                prevResIdx = residx
            printResIdx = resN if reIndexResidues else residx
            line = "{:5d}{:5s}{:5s}{:5d}{:8.3f}{:8.3f}{:8.3f}".format(
                printResIdx % 100000, resname, atomname, (i + 1) % 100000,
                geom[i, 0] / nm2bohr,
                geom[i, 1] / nm2bohr,
                geom[i, 2] / nm2bohr)
            if writeVelocities:
                line += "{:8.4f}{:8.4f}{:8.4f}".format(
                    vel[i, 0] / nm_ps2bohr_au,
                    vel[i, 1] / nm_ps2bohr_au,
                    vel[i, 2] / nm_ps2bohr_au)
            f.write(line + "\n")
        f.write("{:f} {:f} {:f}\n".format(
            boxsize[0] / nm2bohr, boxsize[1] / nm2bohr, boxsize[2] / nm2bohr))
[docs]
def get_coord_center(n_center, molec, centerMethod='first'):
    """
    Return the center coordinates of the residue with residue index n_center.

    Parameters
    ----------
    n_center : int
        Residue index (value of molec['residueIndex']) of the molecule that
        should become the center.
    molec : dict of ndarrays.
        Needs 'residueIndex', 'residueName', 'geom' and 'atommass'.
    centerMethod: either "com" or "first"
        "first": position of the first atom of the residue;
        "com": mass-weighted center of the residue's atoms.

    Returns
    -------
    ndarray of float.
        X, Y, Z coordinates of the center (a copy, not a view into geom).

    Raises
    ------
    ValueError
        If no atom has residue index n_center, or centerMethod is unknown.
    """
    atomIdx = np.flatnonzero(molec['residueIndex'] == n_center)
    if atomIdx.size == 0:
        raise ValueError("no atoms found for residue index {}".format(n_center))
    if centerMethod == "first":
        center = molec['geom'][atomIdx[0]].copy()
    elif centerMethod == "com":
        g = molec['geom'][atomIdx]
        m = molec['atommass'][atomIdx]
        center = np.sum(g * m[:, np.newaxis], axis=0) / np.sum(m)
    else:
        raise ValueError("unknown centerMethod={:s} ".format(centerMethod))
    # report name/index of the centered residue via its first atom
    first = atomIdx[0]
    print("Centering {:s} {:d} at position {:f} {:f} {:f}\n".
          format(molec['residueName'][first], molec['residueIndex'][first], center[0], center[1], center[2]))
    return center
[docs]
def shift_box(shift, xyz, v_box):
    """
    Shift all atoms by -shift and wrap them into the periodic box centered at
    the origin.

    Parameters
    ----------
    shift : (3) array_like of float.
        Position that becomes the new origin (e.g. from get_coord_center).
    xyz : (n_atoms, 3) ndarray of float.
        Original coordinates of all atoms. Not modified.
    v_box : (3) array_like of float.
        Box lengths along x, y, z (rectangular box).

    Returns
    -------
    (n_atoms, 3) ndarray of float.
        New coordinates; every component lies in (-v_box/2, v_box/2].
    """
    v_box = np.asarray(v_box)
    # the modulo maps every component into [0, v_box) ...
    new_xyz = (xyz - np.asarray(shift)) % v_box
    # ... then the upper half is folded onto the negative side
    return np.where(new_xyz > v_box / 2, new_xyz - v_box, new_xyz)
[docs]
def makeResiduesWhole(molec):
    """
    Undo periodic wrapping inside each residue.

    For every residue, an atom whose distance to the residue's center of mass
    exceeds half the box length along an axis is moved by one box length
    towards the center of mass. The center of mass is recomputed after every
    single move and the residue is scanned again, until no atom needs moving.

    Parameters
    ----------
    molec : dict of ndarray.
        Needs 'boxsize', 'residueIndex', 'geom' and 'atommass'. Not modified.

    Returns
    -------
    xyz : (n_atoms, 3) ndarray of float.
        new coordinates.
    """
    box = np.asarray(molec['boxsize'])[:3]
    halfBox = box / 2
    residueIndex = molec['residueIndex']
    masses = molec['atommass']
    newGeom = molec['geom'].copy()
    for r in np.unique(residueIndex):
        atomIdx = np.flatnonzero(residueIndex == r)
        m = masses[atomIdx]
        while True:
            g = newGeom[atomIdx]
            resCenter = np.sum(g * m[:, np.newaxis], axis=0) / np.sum(m)
            d = g - resCenter
            above = d > halfBox
            below = d < -halfBox
            # (atom, axis) pairs in atom-major order, i.e. the scan order
            offenders = np.argwhere(above | below)
            if offenders.shape[0] == 0:
                break  # residue is whole
            a, i = offenders[0]
            # shift exactly one coordinate, then recompute the center of mass
            if above[a, i]:
                newGeom[atomIdx[a], i] -= box[i]
            else:
                newGeom[atomIdx[a], i] += box[i]
    return newGeom
[docs]
def sort_molec(molec, skipDistance=None):
    """
    Sort the molecules (residues) by the distance of their center of mass
    from the origin.

    Parameters
    ----------
    molec : dict of ndarray.
        Dictionary built from reading a gromacs file (check readgmx documentation).
    skipDistance : float or None, optional
        If given, residues whose center-of-mass distance from the origin is
        larger than this value are dropped. None (default) keeps all
        residues; False is treated like None.

    Returns
    -------
    molec : dict of ndarray
        The same dictionary, updated in place, with all per-atom arrays
        reordered (and filtered if skipDistance is given).
    """
    if skipDistance is False:
        skipDistance = None
    residueIndex = molec['residueIndex']
    residues = np.unique(residueIndex)
    # row r holds residue r; rows of unused indices (e.g. 0) stay at the origin
    resCOM = np.zeros((residues.max() + 1, 3))
    for r in residues:
        atomIdx = np.flatnonzero(residueIndex == r)
        m = molec['atommass'][atomIdx]
        resCOM[r] = np.sum(molec['geom'][atomIdx] * m[:, np.newaxis], axis=0) / np.sum(m)
    resCOMdist = np.linalg.norm(resCOM, axis=1)
    # atom order following the residues sorted by distance
    ind_sort = []
    for r in np.argsort(resCOMdist):
        atomIdx = np.flatnonzero(residueIndex == r)
        if atomIdx.size == 0:
            continue  # index not used by any residue (e.g. 0)
        if skipDistance is not None and resCOMdist[r] > skipDistance:
            continue
        ind_sort.extend(atomIdx.tolist())
    for key in ('geom', 'vel', 'topol', 'atomlist', 'atomnums', 'atommass',
                'residueIndex', 'residueName', 'atomName'):
        molec[key] = molec[key][ind_sort]
    return molec
[docs]
def print_geom_files(molec):
    """
    Write the per-atom arrays of molec as CDTK input files into the current
    working directory, one value per line, in the atom order stored in molec.

    Files written (file name <- molec key, format):
        atompos  <- 'geom', flattened (x1 y1 z1 x2 ...), "%-10.6f"
        atomvel  <- 'vel', flattened, "%-10.6f"
        topol    <- 'topol', "%s"
        atomlist <- 'atomlist', "%s"
        atomnums <- 'atomnums', "%d"
        atommass <- 'atommass', "%-10.6f"

    Parameters
    ----------
    molec : dict of ndarray.
        Dictionary built from reading a gromacs file (check readgmx documentation).
    """
    # (output file, molec key, number format, flatten before writing?)
    outputs = (
        ("atompos", "geom", "%-10.6f", True),
        ("atomvel", "vel", "%-10.6f", True),
        ("topol", "topol", "%s", False),
        ("atomlist", "atomlist", "%s", False),
        ("atomnums", "atomnums", "%d", False),
        ("atommass", "atommass", "%-10.6f", False),
    )
    for fileName, key, fmt, flatten in outputs:
        values = molec[key]
        if flatten:
            values = values.ravel()
        np.savetxt(fileName, values, fmt=fmt)
    return None
[docs]
def print_xsample(filename, center_new, xyz, vxyz, elems):
    """
    Placeholder for writing an xsample input file; not implemented.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError()
[docs]
def print_xpyder(filename, center_new, xyz, vxyz, elems):
    """
    Placeholder for writing an xpyder input file; not implemented.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError()
[docs]
def start():
    """
    Command-line entry point of gmx2cdtk.

    Reads a .gro file (or, with -t, the last frame of a .trr trajectory
    together with the .gro file), moves the selected residue to the origin,
    wraps all atoms into the box, makes residues whole, optionally sorts the
    residues by distance, prints the kinetic energy and writes the CDTK input
    files. With -o / -r it also writes .gro / .trr output.
    """
    parser = get_parser()
    options, _ = parser.parse_args()
    if options.gmxFile is None:
        parser.error("You need to provide a .gro file!")

    molec = (readtrr(options.trrFile, options.gmxFile)
             if options.trrFile is not None
             else readgmx(options.gmxFile))

    # Move the chosen residue to the origin and wrap all atoms into the box.
    center = get_coord_center(
        options.centeredMolecule, molec, centerMethod=options.centerMethod)
    molec['geom'] = shift_box(center, molec["geom"], molec["boxsize"])
    # Undo the periodic wrapping inside each residue.
    molec['geom'] = makeResiduesWhole(molec)
    # Sorting (and dropping residues beyond --skip) only happens with reordering enabled.
    if options.doReorder:
        molec = sort_molec(molec, skipDistance=options.skip)

    ekin = 0.5 * np.sum(molec["vel"][:, :] ** 2 * cv.am2au * molec["atommass"][:, np.newaxis])
    print(f"kinetic Energy: {ekin} a.u.")
    # Disabled: removal of the overall (center-of-mass) momentum.
    #print(f"removing overall momentum")
    #totalMomentum = np.sum(molec_data["vel"][:,:] * cv.am2au * molec_data["atommass"][:,np.newaxis], axis=0)
    #comVel = totalMomentum / (cv.am2au * molec_data["atommass"].sum())
    #molec_data["vel"][:,:] = molec_data["vel"][:,:] - comVel[np.newaxis,:]
    #ekin = 0.5 * np.sum(molec_data["vel"][:,:]**2 * cv.am2au * molec_data["atommass"][:,np.newaxis])
    #print(f"kinetic Energy: {ekin} a.u.")

    # Create the CDTK files and the optional GROMACS outputs.
    print_geom_files(molec)
    if options.gmxOut is not None:
        writegmx(options.gmxOut, molec)
    if options.trrOut is not None:
        writetrr(options.trrOut, molec)


if __name__ == "__main__":
    start()