#!/usr/bin/env python

import sys
import csv
import argparse


class IDDag:
    """
    Parent -> child ID graph loaded from a two-column, tab-separated file.

    Each line of the file is ``parent<TAB>child``; a parent may appear on
    many lines. The forward (parent -> children) and reverse
    (child -> parents) lookup graphs are built lazily on first query.
    """

    def __init__(self):
        # parent id -> list of child entries, in file order
        self.parents = {}
        # Lazily-built adjacency maps; None means "needs (re)building".
        self.graph = None
        self.rev_graph = None

    def load(self, path):
        """
        Replace the DAG contents with the edges listed in ``path``.

        Blank lines are skipped; columns beyond the second are ignored.

        :raises ValueError: if a non-blank line has fewer than two
            tab-separated fields.
        """
        parents = {}
        with open(path) as handle:
            for lineno, line in enumerate(handle, 1):
                if not line.strip():
                    continue
                fields = line.rstrip().split("\t")
                if len(fields) < 2:
                    raise ValueError(
                        "%s:%d: expected 'parent<TAB>child', got %r"
                        % (path, lineno, line.rstrip()))
                parents.setdefault(fields[0], []).append(fields[1])
        self.parents = parents
        # Any graph built from previously loaded data is now stale.
        self.graph = None
        self.rev_graph = None

    def get_key_list(self):
        """Return the list of parent ids that have at least one child."""
        return list(self.parents)

    def get_by(self, key):
        """Return the child entries recorded for parent ``key`` (KeyError if absent)."""
        return self.parents[key]

    def _ensure_graph(self):
        """Build the adjacency maps if they have not been built yet."""
        if self.graph is None:
            self._build_graph()

    def _build_graph(self):
        """
        Build ``graph`` (parent -> {child: True}) and ``rev_graph``
        (child -> {parent: True}) from ``self.parents``.

        Child entries may be plain id strings (as produced by ``load``) or
        edge objects exposing ``child`` and ``id`` attributes.
        """
        graph = {}
        rev_graph = {}
        for pid in self.get_key_list():
            children = graph.setdefault(pid, {})
            for entry in self.get_by(pid):
                child = getattr(entry, "child", entry)
                parent = getattr(entry, "id", pid)
                children[child] = True
                rev_graph.setdefault(child, {})[parent] = True
        self.graph = graph
        self.rev_graph = rev_graph

    def is_descendant(self, parent, child):
        """
        Return True if ``parent`` is reachable from ``child`` by following
        child -> parent links through any of each node's parents.
        """
        self._ensure_graph()
        seen = set()
        pending = list(self.rev_graph.get(child, ()))
        while pending:
            node = pending.pop()
            if node == parent:
                return True
            if node in seen:
                continue
            seen.add(node)
            # Follow every parent, not just the first: ids may have several.
            pending.extend(self.rev_graph.get(node, ()))
        return False

    def _desc_crawl(self, parent):
        """
        Return every id reachable from ``parent`` via child links, in
        depth-first pre-order. Empty ids and ``parent`` itself are excluded;
        the visited set keeps cyclic input from looping forever.
        """
        result = []
        seen = set([parent])
        pending = list(reversed(list(self.graph.get(parent, ()))))
        while pending:
            node = pending.pop()
            if not node or node in seen:
                continue
            seen.add(node)
            result.append(node)
            pending.extend(reversed(list(self.graph.get(node, ()))))
        return result

    def get_descendants(self, parent):
        """Return all ids below ``parent`` (empty list if it has no children)."""
        self._ensure_graph()
        return self._desc_crawl(parent)

    def get_children(self, node):
        """Return ``node``'s children as a {child: True} dict, or [] if none."""
        self._ensure_graph()
        return self.graph.get(node, [])

    def get_parents(self, node):
        """Return ``node``'s parents as a {parent: True} dict, or [] if none."""
        self._ensure_graph()
        return self.rev_graph.get(node, [])

    def in_graph(self, name):
        """Return True if ``name`` appears as a parent or a child."""
        self._ensure_graph()
        return name in self.graph or name in self.rev_graph


class IDReducer(object):
    """
    The IDReducer class uses an IDDag to 'reduce' ids and objects to
    common parent objects.

    Assume Matrix 1 has aliquot ids like
        - sample1-aliquot1
        - sample2-aliquot1
        - sample3-aliquot1
    And that Matrix 2 has aliquot ids like
        - sample1-aliquot2
        - sample2-aliquot2
        - sample3-aliquot2

    Both files deal with the same samples, but different aliquots were
    run on different machines, producing matrices of different datatypes.
    But for data integration purposes, we need to refer to aliquots by their
    parent sample name.

    The idDag file for this data (tab-separated parent/child pairs) would be::

        sample1 sample1-aliquot1
        sample1 sample1-aliquot2
        sample2 sample2-aliquot1
        sample2 sample2-aliquot2
        sample3 sample3-aliquot1
        sample3 sample3-aliquot2

    If this file was loaded into an IDDag and used to initialize an
    IDReducer, the following transformations would be valid::

        > idReducer.reduce_id( 'sample1-aliquot1' )
        'sample1'
        > idReducer.reduce_id( 'sample1-aliquot2' )
        'sample1'
    """
    def __init__(self, idDag):
        # child id -> {parent id: edge type}
        self.revGraph = {}
        for pid in idDag.get_key_list():
            for entry in idDag.get_by(pid):
                # Entries are either edge objects (child / id / edgeType
                # attributes) or plain child-id strings, such as those
                # produced by IDDag.load, which carry no edge type.
                child = getattr(entry, "child", entry)
                parent = getattr(entry, "id", pid)
                edge_type = getattr(entry, "edgeType", None)
                self.revGraph.setdefault(child, {})[parent] = edge_type

    def reduce_id(self, id, edgeStop=None):
        """
        Walk up from ``id`` through parent links and return the top-most id.

        Edges whose type equals ``edgeStop`` are not followed. When an id has
        several followable parents, the last one in iteration order is used.
        The walk stops (returning the current id) if it would revisit an id,
        so cyclic input terminates.
        """
        out_id = id
        visited = set([out_id])
        while out_id in self.revGraph:
            candidate = None
            for parent, edge_type in self.revGraph[out_id].items():
                if edgeStop is None or edgeStop != edge_type:
                    candidate = parent
            if candidate is None or candidate in visited:
                return out_id
            visited.add(candidate)
            out_id = candidate
        return out_id

    def reduce_matrix(self, matrix, edgeStop=None):
        """
        Return a new matrix whose columns are the reduced ids of ``matrix``'s
        columns. Each output cell is the mean of the source columns that
        reduce to the same id, so cell values must be numeric (they are
        summed).
        """
        ncols = {}
        for col in matrix.get_col_list():
            ncols.setdefault(self.reduce_id(col, edgeStop), []).append(col)
        row_list = matrix.get_row_list()
        # NOTE(review): CGData is not imported in this file, so this line
        # raises NameError unless CGData is placed in this module's globals
        # elsewhere -- confirm the intended import.
        out = CGData.GenomicMatrix.GenomicMatrix()
        out.init_blank(cols=list(ncols), rows=row_list)
        for row in row_list:
            for col, sources in ncols.items():
                vals = [matrix.get_val(col_name=nc, row_name=row) for nc in sources]
                out.set_val(row_name=row, col_name=col,
                            value=sum(vals) / float(len(vals)))
        return out

class IDExpander(object):
    """
    Expands ids to their descendants using an IDDag-like object (anything
    providing ``get_key_list()`` and ``get_by(key)``), and expands the rows
    of a matrix accordingly.
    """

    def __init__(self, idDag):
        # parent id -> list of child entries, copied from the DAG
        self.expGraph = {}
        for pid in idDag.get_key_list():
            self.expGraph.setdefault(pid, []).extend(idDag.get_by(pid))

    def _children(self, node):
        """Return ``node``'s children, ignoring self-edges."""
        return [c for c in self.expGraph.get(node, ()) if c != node]

    def expand_id(self, id, leaf_only=False):
        """
        Return the ids that ``id`` expands to.

        An id with no children expands to ``[id]``. Otherwise the result holds
        every descendant of ``id`` (``id`` itself excluded), or only the
        childless descendants when ``leaf_only`` is true. Self-edges are
        ignored and each node is visited once, so cyclic input terminates.
        """
        children = self._children(id)
        if not children:
            return [id]
        out = []
        seen = set([id])
        pending = list(reversed(children))
        while pending:
            node = pending.pop()
            if node in seen:
                continue
            seen.add(node)
            grandchildren = self._children(node)
            if not leaf_only or not grandchildren:
                out.append(node)
            pending.extend(reversed(grandchildren))
        return out

    def expand_matrix(self, matrix, leaf_only=False):
        """
        Return a new ClinicalMatrix with one row per expanded id (sorted),
        copying each row's values from the source row(s) it expanded from.

        When several source rows expand to the same id, later source rows
        (in ``matrix.get_row_list()`` order) overwrite earlier ones.
        """
        col_list = matrix.get_col_list()
        nrows = {}
        for row in matrix.get_row_list():
            for e_val in self.expand_id(row, leaf_only):
                nrows.setdefault(e_val, []).append(row)
        out = ClinicalMatrix()
        out.init_blank(rows=sorted(nrows), cols=col_list)
        for row, sources in nrows.items():
            for pid in sources:
                for col in col_list:
                    out.set_val(row_name=row, col_name=col,
                                value=matrix.get_val(row_name=pid, col_name=col))
        return out
    

class ClinicalMatrix:
    """
    Small in-memory matrix read from / written to a tab-separated file.

    The first line of the file is a header: a corner label followed by the
    column names. Every following line is a row name followed by its values.
    Cells hold the strings read from the file, or whatever is passed to
    ``set_val``.
    """
    # Label written in the top-left header cell by ``write``.
    corner_name = "sample"

    def load(self, path):
        """
        Load the matrix from the tab-separated file at ``path``.

        Duplicate column names are made unique by appending ``#1``, ``#2``,
        ... Rows shorter than the header are padded with ``""``; extra
        trailing fields are ignored. Blank lines are skipped. If a row name
        repeats, the later row replaces the earlier one in lookups.
        """
        self.col_map = {}
        self.row_map = {}
        self.matrix = []
        ncols = None
        with open(path) as handle:
            for row in csv.reader(handle, delimiter="\t"):
                if not row:
                    continue
                if ncols is None:
                    self._load_header(row[1:])
                    ncols = len(self.col_map)
                    continue
                values = row[1:1 + ncols]
                values.extend([""] * (ncols - len(values)))
                self.row_map[row[0]] = len(self.matrix)
                self.matrix.append(values)
        self.loaded = True

    def _load_header(self, names):
        """Assign each header name a column position, de-duplicating names."""
        for pos, name in enumerate(names):
            unique = name
            suffix = 1
            while unique in self.col_map:
                unique = name + "#" + str(suffix)
                suffix += 1
            self.col_map[unique] = pos

    def get_row_list(self):
        """
        Returns names of rows, in stored order
        """
        return sorted(self.row_map, key=self.row_map.get)

    def get_col_list(self):
        """
        Returns names of columns, in stored order
        """
        return sorted(self.col_map, key=self.col_map.get)

    def get_row(self, row_name):
        """Return the list of cell values for ``row_name`` (KeyError if absent)."""
        return self.matrix[self.row_map[row_name]]

    def set_val(self, col_name, row_name, value):
        """
        Set cell value based on row and column names
        """
        self.matrix[self.row_map[row_name]][self.col_map[col_name]] = value

    def get_val(self, col_name, row_name):
        """
        Get cell value based on row and column names
        """
        return self.matrix[self.row_map[row_name]][self.col_map[col_name]]

    def init_blank(self, cols, rows):
        """
        Reset the matrix to ``len(rows)`` x ``len(cols)`` empty-string cells.

        ``cols`` and ``rows`` give the column and row names in order; any
        iterable of names is accepted.
        """
        cols = list(cols)
        rows = list(rows)
        self.matrix = [[""] * len(cols) for _ in rows]
        self.col_map = {c: i for i, c in enumerate(cols)}
        self.row_map = {r: i for i, r in enumerate(rows)}
        self.loaded = True

    def write(self, handle, missing=''):
        """
        Write the matrix to ``handle`` as tab-separated text.

        Rows and columns are written in stored order, after a header line that
        starts with ``corner_name``. Cells holding ``None`` are written as
        ``missing``.
        """
        writer = csv.writer(handle, delimiter="\t", lineterminator='\n')
        col_list = self.get_col_list()
        positions = [self.col_map[col] for col in col_list]
        writer.writerow([self.corner_name] + col_list)
        for row_name in self.get_row_list():
            row = self.get_row(row_name)
            cells = [row[pos] for pos in positions]
            writer.writerow([row_name] + [missing if val is None else val for val in cells])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Expand the rows of a tab-separated matrix to the "
                    "descendant ids listed in an IDDag file.")
    # The matrix and DAG are mandatory: without them load() would fail on
    # open(None) with an opaque traceback instead of a usage message.
    parser.add_argument('-c', '--col-matrix', help='Matrix to saturate by columns',
                        dest="col_matrix", required=True)
    parser.add_argument('-d', '--iddag', help='IDDag to use for saturation',
                        dest="iddag", required=True)
    parser.add_argument("-o", "--out", help="Output File", dest="output", default=None)
    parser.add_argument("-l", "--leaf-only", help="Leaf Only", dest="leaf_only",
                        action="store_true", default=False)
    args = parser.parse_args()

    matrix = ClinicalMatrix()
    matrix.load(args.col_matrix)
    iddag = IDDag()
    iddag.load(args.iddag)

    expander = IDExpander(iddag)
    out = expander.expand_matrix(matrix, args.leaf_only)
    if args.output is None:
        out.write(sys.stdout)
    else:
        with open(args.output, "w") as handle:
            out.write(handle)