#!/usr/bin/env python

from __future__ import print_function

import argparse
import csv
import errno
import json
import os
import re
import shutil
import sys

from pprint import pprint

# Column names of a mob_typer report, in file order.  parse_mob_typer_report()
# uses them as the csv.DictReader keys and main() uses them as the
# csv.DictWriter header, so the order here is the order written out.
MOB_TYPER_FIELDNAMES = [
    # basic assembly statistics
    "file_id",
    "num_contigs",
    "total_length",
    "gc",
    # rep / relaxase / mpf / oriT typing calls and their accessions
    "rep_type(s)",
    "rep_type_accession(s)",
    "relaxase_type(s)",
    "relaxase_type_accession(s)",
    "mpf_type",
    "mpf_type_accession(s)",
    "orit_type(s)",
    "orit_accession(s)",
    "PredictedMobility",
    # Mash nearest-neighbour result (main() matches mash_nearest_neighbor
    # against reference GenBank accessions)
    "mash_nearest_neighbor",
    "mash_neighbor_distance",
    "mash_neighbor_cluster",
    # host-range and literature-derived fields
    "NCBI-HR-rank",
    "NCBI-HR-Name",
    "LitRepHRPlasmClass",
    "LitPredDBHRRank",
    "LitPredDBHRRankSciName",
    "LitRepHRRankInPubs",
    "LitRepHRNameInPubs",
    "LitMeanTransferRate",
    "LitClosestRefAcc",
    "LitClosestRefDonorStrain",
    "LitClosestRefRecipientStrain",
    "LitClosestRefTransferRate",
    "LitClosestConjugTemp",
    "LitPMIDs",
    "LitPMIDsNumber",
]

def parse_mob_typer_report(mob_typer_report_path):
    """Read a tab-separated mob_typer report into a list of row dicts.

    Rows are always keyed by ``MOB_TYPER_FIELDNAMES``; the file's own header
    line is never used for keys.  Because explicit ``fieldnames`` make
    ``csv.DictReader`` treat every line (including a header) as data, any row
    that merely repeats the column names -- the report's header, or the
    headers of several concatenated reports -- is dropped here, so callers
    only see data rows.  A report without a header parses exactly as before.

    :param mob_typer_report_path: path to the mob_typer TSV report.
    :returns: list of dicts mapping field name -> string value (``None`` for
        columns missing from a short row, per ``csv.DictReader``'s default
        ``restval``).
    """
    with open(mob_typer_report_path) as f:
        reader = csv.DictReader(f, delimiter="\t", quotechar='"', fieldnames=MOB_TYPER_FIELDNAMES)
        # A header row carries the literal column names in its first columns;
        # keeping it would make int(row['num_contigs']) fail downstream.
        return [
            row for row in reader
            if not (row["file_id"] == "file_id" and row["num_contigs"] == "num_contigs")
        ]

def parse_genbank_accession(genbank_file_path):
    """Return the first accession listed on a GenBank ``ACCESSION`` line.

    The file is scanned line by line.  For the first line that starts with
    ``ACCESSION`` and carries a value, the first whitespace-separated token
    after the keyword is returned (``"CP000001"`` for
    ``"ACCESSION   CP000001"``).

    :param genbank_file_path: path to a GenBank-format file.
    :returns: the accession string, or ``None`` if no ``ACCESSION`` line with
        a value is found before end of file.
    """
    with open(genbank_file_path, 'r') as f:
        # Iterating the file stops at EOF; the previous readline() loop spun
        # forever on files lacking an ACCESSION line.
        for line in f:
            if not line.startswith('ACCESSION'):
                continue
            fields = line.split()
            # Skip a bare "ACCESSION" keyword instead of raising IndexError.
            if len(fields) > 1:
                return fields[1]
    return None


def count_contigs(plasmid_fasta_path):
    """Return how many FASTA records a file holds.

    A record is counted for every line whose very first character is '>';
    all other lines (sequence, blank lines) are ignored.
    """
    with open(plasmid_fasta_path, 'r') as fasta:
        return sum(1 for text in fasta if text.startswith('>'))

def count_bases(plasmid_fasta_path):
    """Return the total number of sequence characters in a FASTA file.

    Each line is stripped of surrounding whitespace first; lines that then
    begin with '>' are treated as headers and skipped, and every remaining
    line contributes its stripped length (blank lines contribute 0).
    """
    with open(plasmid_fasta_path, 'r') as fasta:
        stripped_lines = (raw.strip() for raw in fasta)
        return sum(len(seq) for seq in stripped_lines if not seq.startswith('>'))

def main(args):
    """Stage the mob_typer record and reference plasmid matching ``args.plasmid``.

    Steps:

    1. Create ``args.outdir`` if it does not already exist.
    2. Count the contigs and bases in the ``args.plasmid`` FASTA file.
    3. Write ``mob_typer_record.tsv`` in ``args.outdir``: a header plus every
       report row whose ``num_contigs`` and ``total_length`` equal those
       counts and whose ``mash_nearest_neighbor`` equals the GenBank
       accession of one of ``args.reference_plasmids``.  For each such match
       the reference file is copied to ``reference_plasmid.gbk`` (a later
       match overwrites an earlier copy).
    4. Copy the plasmid assembly to ``plasmid.fasta`` in ``args.outdir``.

    :param args: parsed command-line namespace with ``plasmid``,
        ``reference_plasmids``, ``mob_typer_report`` and ``outdir``.
    :raises OSError: if the output directory cannot be created (other than
        because it already exists) or an input/output file cannot be opened.
    """
    # create output directory; an already-existing directory is fine
    try:
        os.mkdir(args.outdir)
    except OSError as exc:
        if not (exc.errno == errno.EEXIST and os.path.isdir(args.outdir)):
            raise

    # parse mob_typer report and measure the plasmid assembly
    mob_typer_report = parse_mob_typer_report(args.mob_typer_report)
    num_plasmid_contigs = count_contigs(args.plasmid)
    num_plasmid_bases = count_bases(args.plasmid)

    # --reference_plasmids is optional on the command line, so it may be None.
    reference_plasmids = args.reference_plasmids or []
    # Memoize accessions so each GenBank file is parsed at most once, however
    # many report rows match (files are still parsed lazily, in order).
    accession_by_path = {}

    with open(os.path.join(args.outdir, 'mob_typer_record.tsv'), 'w') as f:
        mob_typer_record_writer = csv.DictWriter(f, delimiter="\t", quotechar='"', fieldnames=MOB_TYPER_FIELDNAMES)
        mob_typer_record_writer.writeheader()
        for record in mob_typer_report:
            try:
                record_contigs = int(record['num_contigs'])
                record_bases = int(record['total_length'])
            except (TypeError, ValueError):
                # Not a data row (e.g. a header line read as data, or a
                # truncated row with None values) -- it cannot match.
                continue
            if record_contigs != num_plasmid_contigs or record_bases != num_plasmid_bases:
                continue
            for reference_plasmid in reference_plasmids:
                if reference_plasmid not in accession_by_path:
                    accession_by_path[reference_plasmid] = parse_genbank_accession(reference_plasmid)
                accession = accession_by_path[reference_plasmid]
                # A missing accession must not match a short row's None value.
                if accession is not None and accession == record['mash_nearest_neighbor']:
                    shutil.copy2(reference_plasmid, os.path.join(args.outdir, "reference_plasmid.gbk"))
                    mob_typer_record_writer.writerow(record)

    shutil.copy2(args.plasmid, os.path.join(args.outdir, "plasmid.fasta"))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # main() always opens --plasmid and --mob_typer_report, so require them
    # for a clear usage error instead of a TypeError from open(None).
    parser.add_argument("--plasmid", required=True, help="plasmid assembly (fasta)")
    # Optional; with no reference plasmids no report row can be matched.
    parser.add_argument("--reference_plasmids", nargs='+', default=[], help="reference plasmids (genbank)")
    parser.add_argument("--mob_typer_report", required=True, help="mob_typer reports (tsv)")
    parser.add_argument("--outdir", dest="outdir", default=".", help="Output directory")
    args = parser.parse_args()
    main(args)
