#!/usr/bin/env python

import sys
import csv
import re
from argparse import ArgumentParser


class FloatSplitter:
	"""Dichotomize a numeric column around a fixed threshold.

	Attributes:
		name: label used for this split in the output header.
		column: index of the matrix column the split applies to.
		splitval: threshold; values strictly greater than it are positive.
	"""

	def __init__(self, name, column, splitval):
		self.name = name
		self.column = column
		self.splitval = splitval

	def split(self, val):
		"""Return True when ``val`` (parsed as a float) exceeds the threshold.

		Raises:
			ValueError: if ``val`` cannot be parsed as a float.
		"""
		return float(val) > self.splitval

	def __str__(self):
		# Previously a duplicated __repr__; defined explicitly so the class
		# mirrors SetMember instead of relying on the repr fallback.
		return self.name

	def __repr__(self):
		return self.name

class SetMember:
	"""Dichotomize a categorical column by membership in a single label.

	Attributes:
		name: label used for this split in the output header.
		column: index of the matrix column the split applies to.
		label: the category value that counts as positive.
	"""

	def __init__(self, name, column, label):
		self.name, self.column, self.label = name, column, label

	def split(self, val):
		"""Return True when ``val`` equals this splitter's label."""
		is_member = (val == self.label)
		return is_member

	def __repr__(self):
		return self.name

	# The printable form and the debug form are the same string.
	__str__ = __repr__


class SplitGenerator:
	"""Base class for objects that propose splitters for one matrix column.

	Attributes:
		args: parsed command-line options (subclasses read ``args.max``).
		name: the column's header text, used to build splitter names.
		values: the raw string values of the column, one per data row.
	"""

	def __init__(self, args, name, values):
		self.args = args
		self.name = name
		self.values = values

	def float_values(self):
		"""Return the column's non-missing values converted to floats.

		Values listed in the module-level ``NA_SET`` are skipped.

		Raises:
			ValueError: if a non-missing value cannot be parsed as a float.
		"""
		# Iterate the column itself; the previous ``self.values[i]`` picked a
		# single cell using an unrelated global index.
		return [float(v) for v in self.values if v not in NA_SET]

	def value_set(self):
		"""Return a dict whose keys are the distinct raw values (all map to True)."""
		return dict.fromkeys(self.values, True)

	def is_valid(self):
		"""Return True when every non-missing value parses as a float.

		A column whose values are all missing is therefore also reported valid.
		"""
		try:
			self.float_values()
		except ValueError:
			return False
		return True

class EnumerationGen(SplitGenerator):
	"""Propose one SetMember splitter per distinct value of a column."""

	def is_valid(self):
		"""Return True when the column has more than one and fewer than ``args.max`` distinct values."""
		distinct = len(self.value_set())
		return distinct < self.args.max and distinct > 1

	def __iter__(self):
		"""Yield a ``<name>:label=<value>`` SetMember for every distinct value.

		NOTE: ``i`` is not a local; it resolves to the module-level name ``i``
		at iteration time and is used as the column index of each splitter.
		"""
		for label in self.value_set():
			split_name = "%s:label=%s" % (self.name, label)
			yield SetMember( split_name, i, label )

class MeanGen(SplitGenerator):
	"""Propose a single FloatSplitter at the mean of the column's numeric values."""

	def __iter__(self):
		"""Yield a ``<name>:mean`` FloatSplitter, or nothing if no value is usable.

		NOTE: ``i`` is not a local; it resolves to the module-level name ``i``
		at iteration time and is used as the splitter's column index.
		"""
		fvals = self.float_values()
		if not fvals:
			# A column of only missing values has no mean; avoid dividing by zero.
			return
		yield FloatSplitter( "%s:mean" % (self.name), i, sum(fvals) / float(len(fvals)) )

class MedianGen(SplitGenerator):
	"""Propose a single FloatSplitter at the median of the column's numeric values."""

	def __iter__(self):
		"""Yield a ``<name>:median`` FloatSplitter, or nothing if no value is usable.

		NOTE: ``i`` is not a local; it resolves to the module-level name ``i``
		at iteration time and is used as the splitter's column index.
		"""
		fvals = sorted(self.float_values())
		if not fvals:
			# No usable values: there is no median to split on.
			return
		# Floor division keeps the index an int on Python 3 as well.
		mid = len(fvals) // 2
		if len(fvals) % 2:
			# Odd count: the middle element is the median.
			median = fvals[mid]
		else:
			# Even count: average the two central elements.
			median = (fvals[mid - 1] + fvals[mid]) / 2.0
		yield FloatSplitter("%s:median" % (self.name), i, median)

# Maps the method keyword found in the first field of a split-script line
# to the generator class that produces the splitters for that method.
class_map = {
	"enumerate" : EnumerationGen,
	"mean" : MeanGen,
	"median" : MedianGen
}


# Raw cell strings treated as missing data. Membership is tested with ``in``
# against the unmodified cell text, so the match is exact and case-sensitive.
NA_SET = [ 'nan', 'na', 'n/a', '' ]

if __name__ == "__main__":
	# Command-line driver: read a tab-separated phenotype matrix (first row is
	# the header, first column the sample name), build candidate splitters,
	# keep those whose two groups are both large enough, and write the
	# dichotomized matrix to stdout.
	parser = ArgumentParser()
	parser.add_argument("src", help="Source Phenotype Matrix", default=None)
	parser.add_argument("-p", "--pos", help="Positive Set", default="1")
	parser.add_argument("-n", "--neg", help="Negative Set", default="0")
	parser.add_argument("-m", "--max", help="Max group dichotomizations", type=int, default=3)
	parser.add_argument("-g", "--group-min", dest="min", help="Min members of group", type=int, default=1)
	parser.add_argument("-s", "--script-file", dest="script_file", help="Split Script", default=None)

	args = parser.parse_args()

	# Load the matrix column-wise: header maps index -> name, colMap maps
	# name -> index, colVals maps index -> list of raw cell strings.
	header = None
	colMap = {}
	colVals = {}
	rowCount = 0
	with open(args.src) as handle:
		reader = csv.reader(handle, delimiter="\t")
		for row in reader:
			if not row:
				# csv yields [] for blank lines; they carry no sample.
				continue
			if header is None:
				header = {}
				for i, a in enumerate(row):
					header[i] = a
					colMap[a] = i
					colVals[i] = []
			else:
				rowCount += 1
				for col in header:
					# Rows shorter than the header are padded with "" (a
					# missing value) instead of raising IndexError.
					colVals[col].append(row[col] if col < len(row) else "")

	if header is None:
		parser.error("source matrix %s has no header row" % args.src)

	start_splitter = []
	if args.script_file is None:
		# Default mode: every column after the sample column gets a mean split
		# when it is numeric, otherwise per-label splits when it is a small
		# enumeration.
		# NOTE: the loop variable must stay named ``i``; the generator classes
		# read the module-level name ``i`` as the splitter's column index.
		for i in range(1, len(header)):
			gen = MeanGen(args, header[i], colVals[i])
			if not gen.is_valid():
				gen = EnumerationGen(args, header[i], colVals[i])
				if not gen.is_valid():
					continue
			for split in gen:
				start_splitter.append(split)
	else:
		# Script mode: each "<method>,<column>" line selects a generator and a
		# column, given either by header name or as c<index>.
		with open(args.script_file) as handle:
			for line in handle:
				row = line.rstrip().split(",")
				if len(row) != 2:
					continue
				method, target = row
				if method not in class_map:
					parser.error("unknown split method %r in %s" % (method, args.script_file))
				if target in colMap:
					index = colMap[target]
				else:
					res = re.search(r'^c(\d+)', target)
					if res is None:
						# Previously fell through with ``index`` unbound or
						# left over from the preceding line.
						parser.error("unknown column %r in %s" % (target, args.script_file))
					index = int(res.group(1))
					if index not in header:
						parser.error("column index %d out of range in %s" % (index, args.script_file))
				# The generators read the module-level ``i`` as the column
				# index, so point it at the column selected by this line.
				i = index
				obj = class_map[method](args, header[index], colVals[index])
				for split in obj:
					start_splitter.append(split)

	# Keep only splitters whose positive and negative groups both reach the
	# minimum size; missing values count toward neither group.
	theories = []
	for t in start_splitter:
		set_count = { True : 0, False : 0 }
		for val in colVals[t.column]:
			if val not in NA_SET:
				set_count[t.split(val)] += 1
		if min(set_count.values()) >= args.min:
			theories.append(t)

	writer = csv.writer( sys.stdout, delimiter="\t", lineterminator="\n" )
	writer.writerow(["sample"] + [str(t) for t in theories])

	for r in range(rowCount):
		out = [colVals[0][r]]
		for t in theories:
			val = colVals[t.column][r]
			if val in NA_SET:
				# Missing input stays missing in the output.
				out.append("")
			elif t.split(val):
				out.append( args.pos )
			else:
				out.append( args.neg )
		writer.writerow(out)




