#!/usr/bin/env python3

# coding: utf-8
"""
The script allows to estimate the number of sub-trajectories to obtain.
It also allows to split the trajectory into a number of sub-trajectories.
# USAGE : estimate_nb_sub_trajectories.py -c : file obtaining with gmx check
  -log : name of log file (optional) -d : output directory (optional)
  -o : name of output file (optional)
  -f : desired number of frames per sub-trajectory (optional)
  -start : start time of the trajectory (optional)
  -end : end time of the trajectory (optional)
"""

# Module metadata. ``__all__`` is empty, so ``from <module> import *``
# exports no names; the functionality is driven by the ``__main__`` block.
__all__ = []
__author__ = "Agnès-Elisabeth Petit"
__date__ = "01/06/2022"
__version__ = "1.0"
__copyright__ = "(c) 2022 CC-BY-NC-SA"

# Library import
import argparse
import logging
import os
import sys

import numpy as np


def test_setup():
    """
    Description : Build the module-level ``args`` from the command line and
      force verbose mode on, so that the functions of this module (which
      read ``args.verbose``) can be called outside the ``__main__`` block.
    """
    global args
    parsed_args = parse_arguments()
    parsed_args.verbose = True
    args = parsed_args


def parse_arguments():
    """
    Description : Build the command-line parser and parse sys.argv.
    return: argparse.Namespace with the parsed options. ``input_check`` and
      ``output_directory`` are one-element lists (``nargs=1``), ``nb_frames``
      is an int, ``start_traj`` and ``end_traj`` are strings or None.
    """
    parser = argparse.ArgumentParser(
        description="The script allows to estimate the number of "
                    "sub-trajectories to obtain. It also allows to split"
                    " the trajectory into a number of sub-trajectories.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        prefix_chars="-",
        add_help=True,
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="""Information messages to stderr""",
    )
    parser.add_argument(
        "-c",
        "--input_check",
        type=str,
        nargs=1,
        help=""".txt file obtained with gmx check -f.
        It contains information about the trajectory""",
    )
    parser.add_argument(
        "-log", "--log_output", type=str,
        default="log/estimate_nb_sub_trajectories.log",
        help="""Output for log file. Default :
        log/estimate_nb_sub_trajectories.log"""
    )
    parser.add_argument(
        "-d", "--output_directory", type=str, nargs=1,
        # With nargs=1 a value given on the command line is a one-element
        # list; the default must be a list too, otherwise
        # ``args.output_directory[0]`` is only the first character ".".
        default=["./"],
        help="""It's output Directory. Default : ./"""
    )
    parser.add_argument(
        "-o",
        "--output_file",
        type=str,
        default="estimated_number_of_sub_trajectories.tsv",
        help="""Output file. Default :
         estimated_number_of_sub_trajectories.tsv""",
    )
    parser.add_argument(
        "-f",
        "--nb_frames",
        type=int,
        default=30,
        help="""Number of frames per sub-trajectory. Default : 30""",
    )
    parser.add_argument(
        "-start", "--start_traj", type=str,
        help="""Start of the trajectory to be cut"""
    )
    parser.add_argument(
        "-end", "--end_traj", type=str,
        help="""End of the trajectory to be cut"""
    )
    return parser.parse_args()


def search_nbr_steps_time_step(txt_file):
    """
    Description : Keeping the number of frames of the complete trajectory and
     the time between each frame, read from the line starting with "Step"
     in the gmx check output (if several lines match, the last one is used).
    param txt_file: file obtained with gmx check.
    return: tuple (number of frames of the complete trajectory, time between
     each frame in ps). The time step is an int when it is a whole number
     and a float otherwise (e.g. 0.002).
    raise: ValueError if no usable "Step" line is found in txt_file.
    """
    if args.verbose:
        logging.info("\nFunction search_nbr_steps_time_step")
        logging.info("The input file is " + txt_file)
    len_traj = None
    time_step = None
    with open(txt_file, "r") as f:
        for li in f:
            li2 = li.split()
            # Expected layout: "Step  <nb frames>  <time step (ps)>".
            if not li.startswith("Step") or len(li2) < 3:
                continue
            len_traj = int(li2[1])
            # The time step may be fractional, which int() cannot parse.
            # Whole numbers stay int so the value written in the output
            # table is unchanged for integer time steps.
            step = float(li2[2])
            time_step = int(step) if step.is_integer() else step
    if len_traj is None:
        raise ValueError(
            "No 'Step' line with a number of frames and a time step was "
            "found in " + txt_file
        )
    if args.verbose:
        logging.info("The length of the trajectory is " + str(len_traj))
        logging.info(
            "The elapsed time between two steps is : " + str(time_step) + " ps"
        )
        logging.info("End search_nbr_steps_time_step functions")
    return len_traj, time_step


def estimate_nbr_sub_trajectories(nbr_step_time_step, nbr_frames_traj,
                                  out_file):
    """
    Description: Creation of a tsv file that contains the number of frames
      of the part of the trajectory that is used, the number of
      sub-trajectories to create, the duration between each frame, the
      number of frames per sub-trajectory and the first and last frames
      used.
    param nbr_step_time_step: sequence [number of frames of the complete
      trajectory, time between each frame (ps), first frame, last frame].
      The first and last frames are optional: when they are missing or
      None, frames 0 to (number of frames - 1) are used.
    param nbr_frames_traj: number of frames per sub-trajectory (> 0)
    param out_file: output file name
    return: None
    raise: ValueError if nbr_frames_traj is not strictly positive or if the
      last frame comes before the first frame.
    output: tsv file with a header row followed by one row of values.
    """
    if args.verbose:
        logging.info("\nFunction estimate_nbr_sub_trajectories")
        logging.info("The length of the trajectory is " +
                     str(nbr_step_time_step[0]))
        logging.info(
            "The elapsed time between two steps is : "
            + str(nbr_step_time_step[1])
            + " ps"
        )
        logging.info("The output file is  " + str(out_file))
    if nbr_frames_traj <= 0:
        raise ValueError(
            "The number of frames per sub-trajectory must be strictly "
            "positive, got " + str(nbr_frames_traj)
        )
    name_columns = [
        "Length_trajectory (frames)",
        "Number_sub_trajectories",
        "Time_steps (ps)",
        "Number_frames_per_sub_trajectory",
        "Start_trajectory (frames)",
        "End_trajectory (frames)",
    ]
    time_step = nbr_step_time_step[1]
    # Optional first/last frames: fall back to the whole trajectory.
    start_traj = 0
    if len(nbr_step_time_step) > 2 and nbr_step_time_step[2] is not None:
        start_traj = nbr_step_time_step[2]
    end_traj = nbr_step_time_step[0] - 1
    if len(nbr_step_time_step) > 3 and nbr_step_time_step[3] is not None:
        end_traj = nbr_step_time_step[3]
    if args.verbose:
        logging.info(
            "The first frame of the trajectory is the number "
            + str(start_traj)
        )
        logging.info("The last frame of the trajectory is the number "
                     + str(end_traj))
    if end_traj < start_traj:
        raise ValueError(
            "The last frame (" + str(end_traj) + ") comes before the first "
            "frame (" + str(start_traj) + ")"
        )
    # Both bounds are included, hence the + 1.
    len_traj = end_traj - start_traj + 1
    # A trailing chunk shorter than nbr_frames_traj is not counted.
    n_sub_traj = len_traj // nbr_frames_traj
    if args.verbose:
        logging.info("The estimated number of sub-trajectories is : "
                     + str(n_sub_traj))
    list_values = [
        str(len_traj),
        str(n_sub_traj),
        str(time_step),
        str(nbr_frames_traj),
        str(start_traj),
        str(end_traj),
    ]
    tab_values = np.asarray([name_columns, list_values])
    np.savetxt(out_file, tab_values, delimiter="\t", fmt="%s")
    if args.verbose:
        logging.info("Save table in the file : " + str(out_file))
        logging.info("End estimate_nbr_sub_trajectories function")


if __name__ == "__main__":
    args = parse_arguments()
    if args.output_directory[0].endswith("/"):
        out_directory = args.output_directory[0]
    else:
        out_directory = args.output_directory[0] + "/"
    if args.verbose:
        if "/" in args.log_output:
            log_dir = args.log_output.rsplit("/", 1)[0]
        else:
            log_dir = "log/"
        log_file = args.log_output
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        if os.path.isfile(log_file):
            os.remove(log_file)
        if args.log_output:
            logging.basicConfig(
                filename=log_file,
                format="%(levelname)s - %(message)s",
                level=logging.INFO,
            )
        else:
            logging.basicConfig(
                filename=log_file,
                format="%(levelname)s - %(message)s",
                level=logging.INFO,
            )
        logging.info("verbose mode on")
    nb_frames_traj = args.nb_frames
    if args.start_traj and args.start_traj != "":
        start_trajectory = int(args.start_traj)
    else:
        start_trajectory = None
    if args.end_traj and args.end_traj != "":
        end_trajectory = int(args.end_traj)
    else:
        end_trajectory = None
    if args.verbose:
        logging.info("Start estimate number sub-trajectories")
    output_file = args.output_file
    if not args.input_check:
        print("Please enter the file created with 'gmx_check -f'")
        if args.verbose:
            logging.error("Please enter the file created with 'gmx_check -f'")
        sys.exit(1)
    nb_steps_time_step = list(search_nbr_steps_time_step(
        args.input_check[0]))
    nb_steps_time_step.append(start_trajectory)
    nb_steps_time_step.append(end_trajectory)
    output_file = out_directory + output_file
    estimate_nbr_sub_trajectories(nb_steps_time_step, nb_frames_traj,
                                  output_file)
