#ifndef SAM_RATIO_H
#define SAM_RATIO_H

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <numeric>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "FastaRecord.h"
#include "Alignments.h"
#include "args.h"

// Parsed CIGAR string: one (length, operation) pair per CIGAR element.
using cigar_str = std::vector<std::pair<int, char>>;

/**
 * Column titles of the comma-separated results table, as written by
 * bam_stream_header() and file_header().
 */
struct header {
	std::string level{"Level"};
	std::string iteration{"Iteration"};
	std::string gene_id{"Gene id"};
	std::string gene_fraction{"Gene Fraction"};
	std::string hits{"Hits"};
};

/**
 * Records one alignment against a gene: bumps the gene's hit counter
 * (record.update_gene_hits()) and marks every gene base the alignment
 * touches in record._base_hits.
 *
 * The touched window starts at the alignment's 1-based POS (field four)
 * and spans the number of reference bases consumed by the CIGAR, i.e. the
 * M, D, N, = and X operations. Soft clips (S), insertions (I), hard clips
 * (H) and padding (P) do not occupy gene positions. If the CIGAR consumes
 * no reference bases at all (e.g. it is empty), the length of SEQ (field
 * ten) is used as the span instead.
 *
 * The window is clipped to [0, record._base_hits.size()), so a POS of 0 or
 * a read running off the end of the gene never indexes outside the vector.
 *
 * @param record    gene whose per-base coverage flags are updated
 * @param alignment alignment mapped against that gene
 */
void analyze_coverage(FastaRecord &record, Alignments &alignment) {
	record.update_gene_hits();

	const cigar_str cigar = alignment.cigar();

	// Sum the lengths of the operations that advance along the reference.
	int ref_span = 0;
	for(const auto &element : cigar) {
		switch(element.second) {
			case 'M':
			case 'D':
			case 'N':
			case '=':
			case 'X':
				ref_span += element.first;
				break;
			default:
				break;
		}
	}

	// No usable CIGAR: assume an ungapped alignment over the whole read.
	if(ref_span == 0) {
		ref_span = static_cast<int>(alignment.seq().length());
	}

	const int gene_length = static_cast<int>(record._base_hits.size());
	const int first = alignment.pos() - 1;   // SAM POS is 1-based

	// Clip to the gene so out-of-range positions cannot write outside
	// _base_hits.
	const int start = std::max(first, 0);
	const int stop = std::min(first + ref_span, gene_length);

	for(int i = start; i < stop; ++i) {
		record._base_hits[i] = 1;
	}
}

/**
 * Computes the gene fraction of a FASTA record: the percentage of gene
 * bases marked as hit, i.e. record.get_base_hits() divided by the length
 * of record.gene(), times 100.
 *
 * @param record    gene to evaluate
 * @param threshold minimum percentage; the fraction must exceed it strictly
 * @return the percentage, or -1 when it is not greater than threshold
 */
double coverage(const FastaRecord &record, const int &threshold) {
	const int covered = record.get_base_hits();
	const int length = record.gene().length();
	const double percent = static_cast<double>(covered) / static_cast<double>(length) * 100;

	if(!(percent > threshold)) {
		return -1;
	}
	return percent;
}

/**
 * Writes header to output file when
 * reading from stdin
 */
void bam_stream_header() {
	header h;
	char sep = ',';

	std::cout << h.level << sep << h.iteration << sep 
                  << h.gene_id << sep << h.gene_fraction << sep 
                  << h.hits << '\n';
}

/**
 * Writes header to output file when
 * reading from sam file
 */
void file_header(const std::string &out_fp, const std::string &sam_fp) {
	header h;
	std::ofstream out(out_fp.c_str(), std::ofstream::app );
	char sep = ',';

	out << "@" << sam_fp << '\n';
	out << h.level << sep << h.iteration << sep 
            << h.gene_id << sep << h.gene_fraction << sep 
            << h.hits << '\n';
	out.close();
}

/**
 * Emits the results header to the destination selected by args. It goes to
 * std::cout (bam_stream_header) when args.bam_stream is set. Otherwise it
 * is appended to args.out_fp together with args.sam_fp (file_header).
 */
void create_header(cmd_args &args) {
	if(!args.bam_stream) {
		file_header(args.out_fp, args.sam_fp);
		return;
	}
	bam_stream_header();
}

/**
 * Writes results to output file when reading from
 * stdin
 */
void bam_stream_results(std::vector<FastaRecord> &records,
                        const int &level, const int &iteration,
                        cmd_args &args) {

	double gene_fraction;
	int hits_seen;
	std::string id;
	char sep = ',';

	for(auto &rec : records) {
		gene_fraction = coverage(rec, args.threshold);
		hits_seen = rec.gene_hits();
		id = rec.gene_id();

		if(gene_fraction > 0) {
			std::cout << level << sep << iteration << sep
			          << id << sep << gene_fraction << sep
			          << hits_seen << '\n';			
		}
	}
}

/**
 * Write results when reading sam file from
 * path
 */
void file_results(std::vector<FastaRecord> &records,
                  const int level, const int &iteration,
                  cmd_args &args) {

	std::ofstream out(args.out_fp.c_str(), std::ofstream::app);
	
	double gene_fraction;
	int hits_seen;
	std::string id;
	char sep = ',';

	for(auto &rec : records) {
		gene_fraction = coverage(rec, args.threshold);
		hits_seen = rec.gene_hits();
		id = rec.gene_id();

		if(gene_fraction > 0) {
			out << level << sep << iteration << sep
			    << id << sep << gene_fraction << sep
                            << hits_seen << '\n';
		}
	}
	out.close();
}

/**
 * Writes the per-gene results of one sample. They go to std::cout
 * (bam_stream_results) when args.bam_stream is set. Otherwise they are
 * appended to args.out_fp (file_results).
 */
void report_results(std::vector<FastaRecord> &records,
		    const int level, const int &iteration,
		    cmd_args &args) {
	args.bam_stream
		? bam_stream_results(records, level, iteration, args)
		: file_results(records, level, iteration, args);
}

/**
 * Generates a sequence of samples from sam file specified
 * by the starting level, max level, and skip pattern
 */
void generate_samples(std::vector<FastaRecord> &records,
                      std::vector<Alignments> &alignments,
		      cmd_args &args) {

	int read_count = alignments.size();
	int sample_size;

	srand(unsigned(time(0)));

	std::vector<int> sequence(read_count);
	iota(sequence.begin(), sequence.end(), 0);

	create_header(args);	

	for(int level = args.min; level <= args.max; level += args.skip) {
		for(int sample = 0; sample < args.s_per_lev; sample++) {
			random_shuffle(sequence.begin(), sequence.end(), randomize);
			sample_size = round(((static_cast<double>(level)/100)*read_count));
			std::vector<int> chosen(sequence.begin(), sequence.begin()+sample_size);

			for(int a_idx = 0; a_idx < chosen.size(); a_idx++) {
				std::string rname = alignments[chosen[a_idx]].rname();
				int gene_idx = FastaRecord::find_gene(records, rname);	
				analyze_coverage(records[gene_idx], alignments[chosen[a_idx]]);
			}
			report_results(records,level,sample+1,args);
			FastaRecord::reset_base_hits(records);
			FastaRecord::reset_gene_hits(records);
		}
	}
}

#endif /* SAM_RATIO_H */
