import argparse
import csv
import os
import sys

from fasta_extract_utils import Fasta


# Upper-case IUPAC nucleotide codes mapped to their complements. Ambiguity
# codes (N, R, Y, ...) occur in reference genomes and must not crash the
# extraction; self-complementary codes (S, W, N) map to themselves.
_IUPAC_COMPLEMENTS = {
    'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', 'U': 'A',
    'R': 'Y', 'Y': 'R', 'S': 'S', 'W': 'W', 'K': 'M', 'M': 'K',
    'B': 'V', 'V': 'B', 'D': 'H', 'H': 'D', 'N': 'N',
}


def reverse_complement(bases):
    """Return the reverse complement of a nucleotide sequence.

    Every IUPAC nucleotide code is supported. The case of each base is
    preserved (lowercase soft-masked bases stay lowercase), so the output
    matches how forward-strand sequence is emitted unchanged.

    :param bases: iterable/sequence of single-character base strings.
    :returns: the reverse-complemented sequence as a ``str``.
    :raises KeyError: if a character is not an IUPAC nucleotide code.
    """
    result = []
    for base in reversed(bases):
        complement = _IUPAC_COMPLEMENTS[base.upper()]
        result.append(complement.lower() if base.islower() else complement)
    return ''.join(result)


def get_output_path(hid, subtract_from_start, add_to_end, extend_existing, consider_strand, orphan=False):
    """Build the output file path for one input dataset.

    The file name encodes the extraction settings:
    ``u<subtract_from_start>d<add_to_end>``, followed by ``x`` when
    ``extend_existing`` is true, ``_s`` when ``consider_strand`` is true,
    and ``_orphan`` for the orphan output. For example,
    ``get_output_path(3, 100, 200, True, True)`` gives
    ``output_dir/u100d200x_s_on_data_3.fasta``.

    :param hid: dataset identifier embedded in the file name.
    :param subtract_from_start: integer distance subtracted from the start.
    :param add_to_end: integer distance added to the end.
    :param extend_existing: whether existing boundaries are extended.
    :param consider_strand: whether strand is taken into account.
    :param orphan: select the orphan GFF output instead of the FASTA output.
    :returns: a path under ``output_dir`` (``.fasta``) or, when ``orphan``
        is true, under ``output_orphan_dir`` (``.gff``).
    """
    attrs = 'u%dd%d' % (subtract_from_start, add_to_end)
    if extend_existing:
        attrs += 'x'
    if consider_strand:
        attrs += '_s'
    if orphan:
        attrs += '_orphan'
        extension = 'gff'
        output_dir = 'output_orphan_dir'
    else:
        extension = 'fasta'
        output_dir = 'output_dir'
    return os.path.join(output_dir, '%s_on_data_%s.%s' % (attrs, hid, extension))


def stop_err(msg):
    """Report *msg* on standard error (no newline is appended) and exit with status 1."""
    error_stream = sys.stderr
    error_stream.write(msg)
    raise SystemExit(1)


def _extraction_window(start, end, subtract_from_start, add_to_end, extend_existing):
    """Return the (start, end) coordinates to extract for one feature.

    When ``extend_existing`` is true, the feature's own start/end are moved
    outward by the given distances. Otherwise both ends are measured from the
    feature's integer midpoint ``(start + end) // 2``.
    """
    if extend_existing:
        return start - subtract_from_start, end + add_to_end
    midpoint = (start + end) // 2
    return midpoint - subtract_from_start, midpoint + add_to_end


def _extract_dataset(input_filename, hid, fasta, subtract_from_start, add_to_end,
                     extend_existing, consider_strand):
    """Extract sequences for every feature of one tab-separated, 9-column dataset.

    A row whose extraction window lies within ``1..len(fasta[chrom])`` is
    written as a FASTA record. Every other 9-column, non-comment row is
    copied unchanged to the orphan GFF output. Rows with a column count
    other than 9, or starting with ``#``, are skipped.
    """
    fasta_output_path = get_output_path(hid, subtract_from_start, add_to_end,
                                        extend_existing, consider_strand)
    gff_output_path = get_output_path(hid, subtract_from_start, add_to_end,
                                      extend_existing, consider_strand, orphan=True)
    # The with-block guarantees all three files are flushed and closed.
    # newline='' is what the csv module requires. Outputs use text mode
    # because str records are written.
    with open(input_filename, newline='') as input_file, \
            open(fasta_output_path, 'w') as output, \
            open(gff_output_path, 'w', newline='') as orphan_file:
        reader = csv.reader(input_file, delimiter='\t')
        # Use '\n' so the orphan GFF matches the FASTA output; csv defaults to '\r\n'.
        orphan_writer = csv.writer(orphan_file, delimiter='\t', lineterminator='\n')
        for row in reader:
            if len(row) != 9 or row[0].startswith('#'):
                continue
            cname = row[0]
            strand = row[6]
            start, end = _extraction_window(int(row[3]), int(row[4]),
                                            subtract_from_start, add_to_end,
                                            extend_existing)
            # The start test short-circuits, so fasta is only indexed for
            # windows that begin at position 1 or later.
            if 1 <= start and end <= len(fasta[cname]):
                # start/end are treated as 1-based inclusive; the slice is 0-based.
                bases = fasta[cname][start - 1:end]
                if consider_strand and strand == '-':
                    bases = reverse_complement(bases)
                output.write('>%s:%s-%s_%s\n' % (cname, start, end, strand))
                output.write('%s\n' % bases)
            else:
                orphan_writer.writerow(row)


parser = argparse.ArgumentParser()
parser.add_argument('--input', dest='inputs', action='append', nargs=2, required=True,
                    metavar=('FILE', 'HID'), help="Input datasets")
parser.add_argument('--genome_file', dest='genome_file', help='Reference genome fasta index file.')
parser.add_argument('--subtract_from_start', dest='subtract_from_start', type=int, help='Distance to subtract from start.')
parser.add_argument('--add_to_end', dest='add_to_end', type=int, help='Distance to add to end.')
parser.add_argument('--extend_existing', dest='extend_existing', help='Extend existing start/end rather than from computed midpoint.')
parser.add_argument('--strand', dest='strand', help='Consider strandedness: reverse complement extracted sequence on reverse strand.')
args = parser.parse_args()

fasta = Fasta(args.genome_file)

# These settings are identical for every input dataset, so resolve them once.
extend_existing = args.extend_existing == 'existing'
consider_strand = args.strand == 'yes'

for input_filename, hid in args.inputs:
    _extract_dataset(input_filename, hid, fasta,
                     args.subtract_from_start, args.add_to_end,
                     extend_existing, consider_strand)
