#!/usr/bin/env python3

# coding: utf-8
"""
The script allows to estimate the number of sub-trajectories to obtain.
It also allows to split the trajectory into a number of sub-trajectories.
# USAGE to estimate number of sub-trajectories :
# USAGE : cut_trajectory.py
  -s : number of desired sub-trajectories (number or file)
  -gro : .gro file -xtc : .xtc file -log : name of log file (optional)
  -d : output directory (optional) -o : name of output file (optional)
  -g : group for output (optional)
  -start : start time of the trajectory (optional)
  -end : end time of the trajectory (optional)
"""

__all__ = []
__author__ = "Agnès-Elisabeth Petit"
__date__ = "01/06/2022"
__version__ = "1.0"
__copyright__ = "(c) 2022 CC-BY-NC-SA"

# Library import
import argparse
import logging
import os
import subprocess
import sys

from joblib import Parallel, delayed


def test_setup():
    """
    Description: Parse the command line, force the verbose flag on and
      publish the result as the module-level ``args`` object that the other
      functions of this file read.
    return: None
    """
    global args
    parsed = parse_arguments()
    parsed.verbose = True
    args = parsed


def parse_arguments():
    """
    Description: Build the command-line parser of the script and parse
      ``sys.argv``.
    return: argparse.Namespace with the attributes verbose, input_check,
      nbr_sub_traj, gro_file, xtc_file, log_output, output_directory,
      group_output, start_traj, end_traj and number_cpus.
    """
    # Accepted values for -g: the group indices "0" to "16", as strings
    # because the option itself is declared with type=str.
    list_choices = [str(x) for x in range(17)]
    parser = argparse.ArgumentParser(
        description="The script allows to estimate the number of "
                    "sub-trajectories to obtain. It also allows to split"
                    " the trajectory into a number of sub-trajectories.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        prefix_chars="-",
        add_help=True,
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="""Information messages to stderr""",
    )
    parser.add_argument(
        "-c",
        "--input_check",
        type=str,
        help=""".txt file obtained with gmx check -f.
        It contains information about the trajectory""",
    )
    parser.add_argument(
        "-s",
        "--nbr_sub_traj",
        type=str,
        help="""numbers of sub_trajectories""",
    )
    parser.add_argument(
        "-gro", "--gro_file", type=str, help="""My input .gro file"""
    )
    parser.add_argument(
        "-xtc", "--xtc_file", type=str, help="""My input .xtc file"""
    )
    parser.add_argument(
        "-log", "--log_output", type=str,
        default="log/cut_trajectory.log",
        help="""Output for log file. Default :
        log/cut_trajectory.log"""
    )
    parser.add_argument(
        "-d", "--output_directory", type=str,
        default="out/",
        help="""It's output Directory. Default : out/"""
    )
    parser.add_argument(
        "-g",
        "--group_output",
        type=str,
        default="0",
        choices=list_choices,
        help="""Select group for output. 0 : system, 1: protein, 2: protein-H,
         3: C-alpha, 4: Backbone, 5: MainChain, 6: MainChain+Cb,
         7: MainChain+H, 8: SideChain, 9: SideChain-H, 10: Prot-Masses,
         11: non-Protein, 12: Other, 13: POPC, 14: POT, 15: CLA,
         16: TIP3""",
    )
    parser.add_argument(
        "-start", "--start_traj", type=str,
        help="""Start of the trajectory to be cut"""
    )
    parser.add_argument(
        "-end", "--end_traj", type=str,
        help="""End of the trajectory to be cut"""
    )
    parser.add_argument(
        "-cpus", "--number_cpus", type=int,
        default=1,
        help="""Number of cpus. Default : 1"""
    )
    return parser.parse_args()


def search_nbr_steps_time_step(txt_file):
    """
    Description : Keeping the number of frames of the complete trajectory and
     the time between each frame. Both values are read from the line of the
     file that starts with "Step" (second and third whitespace-separated
     fields); if several such lines exist the last one wins.
    param txt_file: file obtained with gmx check.
    return: tuple (number of frames of the complete trajectory, time between
     each frame). The time is an int when the field is an integer literal,
     otherwise a float.
    raises ValueError: if the file contains no line starting with "Step".
    """
    if args.verbose:
        logging.info("\nFunction search_nbr_steps_time_step")
        logging.info("The input file is %s", txt_file)
    len_traj = None
    time_step = None
    with open(txt_file, "r") as f:
        for li in f:
            if not li.startswith("Step"):
                continue
            fields = li.split()
            len_traj = int(fields[1])
            # Keep integer time steps as int (historical behaviour) but accept
            # non-integer values such as "0.5" instead of crashing on them.
            try:
                time_step = int(fields[2])
            except ValueError:
                time_step = float(fields[2])
    if len_traj is None:
        # Without this check empty placeholders would be returned and the
        # failure would only surface later, far from its cause.
        raise ValueError(
            "No line starting with 'Step' found in " + str(txt_file)
        )
    if args.verbose:
        logging.info("The length of the trajectory is %s", len_traj)
        logging.info(
            "The elapsed time between two steps is : %s ps", time_step
        )
        logging.info("End search_nbr_steps_time_step functions")
    return len_traj, time_step


def search_nbr_sub_traj(tsv_nb_sub_traj_file):
    """
    Description: Obtaining the number of frames of the complete trajectory,
      the number of sub-trajectories to create, the time between each
      frame and the number of frames per sub-trajectory. The values are read
      from the last non-blank line of the file that does not start with
      "Length" (the header line).
    param tsv_nb_sub_traj_file: tsv file obtained with the estimation of
      the number of sub-trajectories
    return: list of int that contains the number of images of the complete
      trajectory, the number of sub-trajectories to create, the time between
      each image and the number of images per sub-trajectory (followed by any
      extra column present in the file).
    raises ValueError: if no data line with at least four values is found.
    """
    if args.verbose:
        logging.info("\nFunction search_nbr_sub_traj")
        logging.info("The input file is %s", tsv_nb_sub_traj_file)
    list_number_sub_traj = []
    with open(tsv_nb_sub_traj_file, "r") as f:
        for li in f:
            li = li.rstrip()
            # Skip blank lines: a trailing empty line must not wipe out the
            # values already read from the data line.
            if li and not li.startswith("Length"):
                list_number_sub_traj = li.split()
    if len(list_number_sub_traj) < 4:
        raise ValueError(
            "No data line with at least four values found in "
            + str(tsv_nb_sub_traj_file)
        )
    if args.verbose:
        logging.info("The length of complete trajectory is %s",
                     list_number_sub_traj[0])
        logging.info(
            "The number of sub-trajectories required is  %s",
            list_number_sub_traj[1]
        )
        logging.info(
            "The elapsed time between two steps is : %sps",
            list_number_sub_traj[2]
        )
        logging.info(
            "The number of frames per sub-trajectory is %s",
            list_number_sub_traj[3]
        )
        logging.info("End search_nbr_sub_traj function")
    return [int(v) for v in list_number_sub_traj]


def _sub_trajectory_windows(list_file):
    """Return {rank (from 1): "begin,end"} time window of each sub-trajectory."""
    n_sub_traj = list_file[1]
    times_step = list_file[2]
    nb_frame_sub_traj = list_file[3]
    # The first frame is optional: a list holding only the four documented
    # values (length, count, time step, frames per part) starts at frame 0.
    first_frame = list_file[4] if len(list_file) > 4 else 0
    windows = {}
    for nb_traj in range(n_sub_traj):
        start = first_frame + nb_traj * nb_frame_sub_traj
        end = start + nb_frame_sub_traj - 1
        # Frame indices are converted to times with the time step.
        windows[nb_traj + 1] = (str(start * times_step) + ","
                                + str(end * times_step))
    return windows


def launch_cut_traj(list_file, gro_file, xtc_file, out_dir,
                    logging_file, n_group, n_cpus):
    """
    Description: function that allows to split the trajectory
      into a number of sub-trajectories. One cut_traj job is run per
      sub-trajectory through joblib.
    param list_file: list that contains the number of images of the complete
      trajectory, the number of sub-trajectories to create, the time between
      each image and the number of images per sub-trajectory. An optional
      fifth value is the first frame to use (0 when absent).
    param gro_file: .gro file
    param xtc_file: .xtc file
    param out_dir: output directory (created if missing)
    param logging_file: name of log file; in verbose mode the per-job log
      files returned by cut_traj are appended to it and then deleted
    param n_group: group passed on to cut_traj
    param n_cpus: number of parallel jobs (joblib n_jobs)
    return: None
    output: create a number of sub-trajectories
    """
    if args.verbose:
        logging.info("\nFunction cut_traj")
        logging.info("The length of the trajectory is " + str(list_file[0])
                     + " frames")
        logging.info(
            "The elapsed time between two steps is : " + str(list_file[2])
            + " ps"
        )
        logging.info("The number of sub-trajectories required is  "
                     + str(list_file[1]))
        logging.info(".gro file is " + gro_file)
        logging.info(".xtc file is " + xtc_file)
        logging.info("Output directory is " + str(out_dir))
        logging.info("The name of .log file is " + logging_file)
        logging.info("The number of cpus used is " + str(n_cpus))
    os.makedirs(out_dir, exist_ok=True)
    prefix_name_file = os.path.basename(gro_file).rsplit(".", 1)[0]
    dict_sub_traj = _sub_trajectory_windows(list_file)
    if args.verbose:
        for k, v in dict_sub_traj.items():
            begin_time, end_time = v.split(",")
            logging.info("Sub_trajectory " + str(k) + " starts at "
                         + begin_time + " ps and ends at "
                         + end_time + " ps")
        logging.info("Launch gmx trjconv\n")
    # Outside verbose mode the jobs receive an empty log name.
    job_logging_file = logging_file if args.verbose else ""
    list_log_files = Parallel(n_jobs=n_cpus)(
        delayed(cut_traj)(out_dir, prefix_name_file, k, v, n_group,
                          xtc_file, gro_file, job_logging_file, args)
        for k, v in dict_sub_traj.items())
    if args.verbose:
        # Merge the per-job logs into the main log file, in job order, and
        # remove each temporary file once copied.
        with open(logging_file, "a") as log_file_complete:
            for job_log in list_log_files:
                with open(job_log, "r") as open_f:
                    log_file_complete.writelines(open_f)
                os.remove(job_log)
        logging.info("End cut_traj function")


def cut_traj(out_dir, prefix_name_file, k, v, n_group, xtc_file, gro_file,
             logging_file, arguments):
    """
    Description: run "gmx trjconv" through the shell to write one
      sub-trajectory to <out_dir><prefix_name_file>_traj_<k>.xtc.
    param out_dir: output directory, used as a plain string prefix (it is
      expected to end with "/")
    param prefix_name_file: prefix of the output file name
    param k: rank of the sub-trajectory, used in the output and log names
    param v: string "begin,end" giving the values passed to -b and -e
    param n_group: group sent to gmx trjconv on its standard input
    param xtc_file: .xtc file (-f)
    param gro_file: .gro file (-s)
    param logging_file: name of the main log file; only its base name (up
      to the first ".") is used to name the per-job log file
    param arguments: parsed command-line arguments; only "verbose" is read
    return: in verbose mode the path of the per-job log file
      (log/tmp/<base>_<k>.log) that received the output of the command,
      otherwise None
    """
    import shlex  # local import so the block does not depend on file imports

    window = v.split(",")
    out_traj = str(out_dir) + prefix_name_file + "_traj_" + str(k) + ".xtc"
    # Every value is shell-quoted so that paths containing spaces or shell
    # metacharacters reach gmx unchanged; plain values are left as they are.
    bash_command = (
        "echo " + shlex.quote(str(n_group))
        + " | gmx trjconv -f " + shlex.quote(str(xtc_file))
        + " -s " + shlex.quote(str(gro_file))
        + " -b " + shlex.quote(str(window[0]))
        + " -e " + shlex.quote(str(window[1]))
        + " -o " + shlex.quote(out_traj)
    )
    if not arguments.verbose:
        subprocess.run(bash_command, shell=True)
        return None
    log_directory = "log/tmp/"
    # basename() also works when the log name has no directory part, where
    # splitting on "/" and taking the second piece raised IndexError.
    log_base_name = os.path.basename(logging_file).split(".")[0]
    job_log_file = log_directory + log_base_name + "_" + str(k) + ".log"
    try:
        os.makedirs(log_directory, exist_ok=True)
    except OSError:
        print("Directory '%s' can not be created" % log_directory)
    with open(job_log_file, "w") as f_log:
        subprocess.run(bash_command, shell=True, stdout=f_log, stderr=f_log)
    return job_log_file


if __name__ == "__main__":
    args = parse_arguments()
    log_file = ""
    if args.output_directory.endswith("/"):
        out_directory = args.output_directory
    else:
        out_directory = args.output_directory + "/"
    if args.verbose:
        if "/" in args.log_output:
            log_dir = args.log_output.rsplit("/", 1)[0]
        else:
            log_dir = "log/"
        log_file = args.log_output
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        if os.path.isfile(log_file):
            os.remove(log_file)
        if args.log_output:
            logging.basicConfig(
                filename=log_file,
                format="%(levelname)s - %(message)s",
                level=logging.INFO,
            )
        else:
            logging.basicConfig(
                filename=log_file,
                format="%(levelname)s - %(message)s",
                level=logging.INFO,
            )
        logging.info("verbose mode on")
    gro = ""
    xtc = ""
    list_nbr_sub_traj = []
    if args.verbose:
        logging.info("Start cut trajectory")
    if args.gro_file and ".gro" in args.gro_file:
        gro = os.path.basename(args.gro_file)
    else:
        if args.verbose:
            logging.error("The entry is not a file or is not a .gro file")
        sys.exit(1)
    if args.xtc_file and ".xtc" in args.xtc_file:
        xtc = os.path.basename(args.xtc_file)
    else:
        if args.verbose:
            logging.error("The entry is not a file or is not a .xtc file")
        sys.exit(1)
    if ".tsv" in args.nbr_sub_traj:
        list_nbr_sub_traj = search_nbr_sub_traj(args.nbr_sub_traj)
        nb_sub_traj = list_nbr_sub_traj[1]
    else:
        nb_sub_traj = int(args.nbr_sub_traj)
        nb_steps_time_step = search_nbr_steps_time_step(
            args.input_check)
        if args.start_traj and args.start_traj != "":
            start_trajectory = int(args.start_traj)
        else:
            start_trajectory = 0
        if args.end_traj and args.end_traj != "":
            end_trajectory = int(args.end_traj)
        else:
            end_trajectory = nb_steps_time_step[0] - 1
        len_trajectory = end_trajectory - start_trajectory + 1
        list_nbr_sub_traj = [
            len_trajectory,
            nb_sub_traj,
            nb_steps_time_step[1],
            len_trajectory // nb_sub_traj,
            start_trajectory,
            end_trajectory,
        ]
    if nb_sub_traj >= (int(list_nbr_sub_traj[0]) // 2):
        if args.verbose:
            logging.error(
                "The number of averages requested is greater than "
                "half the trajectory size"
            )
        print(
            "Number of requested sub-trajectories too large compared "
            "to the size of the trajectory"
        )
        sys.exit(2)
    nb_cpus = args.number_cpus
    launch_cut_traj(
        list_nbr_sub_traj, gro, xtc, out_directory, log_file,
        args.group_output, nb_cpus
    )
