#!/usr/bin/env python

import sys
import csv
import argparse

# Column index (0-based) matched against the whitelist in BED input: the 4th field.
BED_NAME_COLUMN = 3


def load_whitelist(path):
    """Return the set of whitelisted sample names read from *path*.

    The first tab-separated field of each line is the sample name; further
    fields are ignored. Blank lines are skipped so they do not whitelist
    empty names.
    """
    with open(path) as handle:
        names = (line.rstrip().split("\t")[0] for line in handle)
        return {name for name in names if name}


def censor_columns(rows, whitelist):
    """Yield *rows* reduced to the first column plus whitelisted columns.

    The first row is the header; a column (other than the first) is kept
    when its header name is in *whitelist*. The first column is always kept.
    Nothing is yielded when the input is empty or when no column besides the
    first is whitelisted. Blank rows are skipped.
    """
    rows = iter(rows)
    header = next(rows, None)
    if header is None:
        return
    keep = [0] + [i for i, name in enumerate(header[1:], 1) if name in whitelist]
    if len(keep) < 2:
        # Only the label column would survive: emit nothing.
        return
    yield [header[i] for i in keep]
    for row in rows:
        if row:  # csv.reader yields [] for blank lines
            yield [row[i] for i in keep]


def censor_rows(rows, whitelist):
    """Yield the header plus every data row whose first field is whitelisted.

    The header is only yielded once a data row matches, so input with no
    matches produces no output at all. Blank rows are skipped.
    """
    rows = iter(rows)
    header = next(rows, None)
    matched = False
    for row in rows:
        if row and row[0] in whitelist:
            if not matched:
                matched = True
                yield header
            yield row


def censor_bed(rows, whitelist, name_column=BED_NAME_COLUMN):
    """Yield BED rows whose field at *name_column* is whitelisted.

    Rows too short to contain that field (e.g. blank lines) cannot match and
    are skipped instead of raising IndexError.
    """
    for row in rows:
        if len(row) > name_column and row[name_column] in whitelist:
            yield row


def _write_tsv(out, first, rows):
    """Write *first* and then the remaining *rows* to *out* as TSV."""
    writer = csv.writer(out, delimiter="\t", lineterminator="\n")
    writer.writerow(first)
    writer.writerows(rows)


def write_rows(rows, output=None):
    """Write *rows* tab-separated to the file *output*, or to stdout if None.

    The output file is only opened (and therefore created/truncated) once
    the first row is available, so an empty result leaves *output* alone.
    stdout is never closed, so later sections can still write to it.
    Returns True if anything was written.
    """
    rows = iter(rows)
    first = next(rows, None)
    if first is None:
        return False
    if output is None:
        _write_tsv(sys.stdout, first, rows)
    else:
        with open(output, "w") as out:
            _write_tsv(out, first, rows)
    return True


def _censor_file(path, delim, censor, whitelist, output):
    """Run *censor* over the *delim*-separated file at *path*; write the result."""
    with open(path) as handle:
        reader = csv.reader(handle, delimiter=delim)
        # The generator reads lazily, so writing must happen while handle is open.
        write_rows(censor(reader, whitelist), output)


def main(argv=None):
    """Censor matrices / BED files down to whitelisted samples.

    *argv* defaults to ``sys.argv[1:]``. Exits with status 1 when no
    whitelist is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--col-matrix', help='Matrix to censor by columns', dest="col_matrix", default=None)
    parser.add_argument('-r', '--row-matrix', help='Matrix to censor by rows', dest="row_matrix", default=None)
    parser.add_argument('-b', '--bed', help='BED file to censor', dest="bed", default=None)
    parser.add_argument('-w', '--whitelist', help='White list of samples', dest="white_list", default=None)
    parser.add_argument("-d", "--delim", help="Field Delimiter (Default \\t)", dest="delim", default="\t")
    parser.add_argument("-o", "--out", help="Output File", dest="output", default=None)

    args = parser.parse_args(argv)

    if args.white_list is None:
        sys.stderr.write("Must Provide whitelist\n")
        sys.exit(1)

    whitelist = load_whitelist(args.white_list)

    # NOTE(review): each section opens --out with mode "w", so when several of
    # -c/-r/-b are combined with -o only the last section that produced output
    # remains in the file. Without -o all sections go to stdout in order.
    sections = (
        (args.col_matrix, censor_columns),
        (args.row_matrix, censor_rows),
        (args.bed, censor_bed),
    )
    for path, censor in sections:
        if path is not None:
            _censor_file(path, args.delim, censor, whitelist, args.output)


if __name__ == "__main__":
    main()
