from bl.core.io.illumina import GenomeStudioSampleSheetReader as gsr
from bl.vl.utils import LOG_LEVELS, get_logger
import csv, argparse, sys, re


def make_parser():
    """Build the command-line parser for the samplesheet splitter.

    Returns:
        argparse.ArgumentParser: parser exposing the logging options, the
        required ``--input-file`` and ``--study`` arguments, and one
        ``--*-out-file`` option (each defaulting to a path in the current
        directory) for every TSV file that is produced.
    """
    # The summary must be given as ``description``: the first positional
    # parameter of ArgumentParser is ``prog``, which would print the whole
    # sentence in the usage line in place of the program name.
    parser = argparse.ArgumentParser(
        description='Split GenomeStudio samplesheet in TSV files to import data within OMERO')
    parser.add_argument('--logfile', type=str, help='log file (default=stderr)')
    parser.add_argument('--loglevel', type=str, choices=LOG_LEVELS,
                        help='logging level', default='INFO')
    parser.add_argument('--input-file', type=str, required=True,
                        help='GenomeStudio samplesheet')
    parser.add_argument('--arrays-out-file', type=str,
                        help='output file containing IlluminaArrayOfArrays definitions',
                        default='./array_of_arrays.tsv')
    parser.add_argument('--bead-chip-out-file', type=str,
                        help='output file containing IlluminaBeadChipArray definitions',
                        default='./bead_chip.tsv')
    parser.add_argument('--array-measure-out-file', type=str,
                        help='output file containing IlluminaBeadChipMeasure definitions',
                        default='./array_measure.tsv')
    parser.add_argument('--array-measures-out-file', type=str,
                        help='output file containing IlluminaBeadChipMeasures definitions',
                        default='./array_measures.tsv')
    parser.add_argument('--study', type=str, required=True,
                        help='Study label that will be used in the import procedure')
    return parser


def get_assay_type_enum(manifest_file):
    """Turn a manifest file name into an upper-case, enum-style label.

    Surrounding whitespace is stripped, every ``.bpm`` occurrence is removed,
    dashes and spaces become underscores, and the result is upper-cased.
    """
    label = manifest_file.strip()
    # Substitutions are applied in this fixed order.
    for old, new in (('.bpm', ''), ('-', '_'), (' ', '_')):
        label = label.replace(old, new)
    return label.upper()


def prepare_array_of_arrays_input(barcode, study, elements):
    """Build the array-of-arrays record for one chip barcode.

    The grid size is derived from the ``array_label`` of every element: each
    label must look like ``RxxCyy`` (two-digit row and column, matched case
    insensitively) and the largest row and column seen become the
    dimensions of the array.

    Args:
        barcode: chip barcode, used both as ``barcode`` and as ``label``.
        study: study label copied into the record.
        elements: iterable of dicts, each with an ``array_label`` key.

    Returns:
        dict with the keys ``barcode``, ``rows``, ``columns``, ``label`` and
        ``study``.

    Raises:
        ValueError: if ``elements`` is empty or a label does not match the
            ``RxxCyy`` pattern.
    """
    ICHIPCORDS_PATTERN = re.compile(r'^r(\d{2})c(\d{2})$', re.IGNORECASE)
    rows = []
    cols = []
    for x in elements:
        m = ICHIPCORDS_PATTERN.match(x['array_label'])
        if m is None:
            # Fail with an explicit message instead of an AttributeError on
            # None when the samplesheet holds an unexpected position.
            raise ValueError('Invalid array label %r for array %s'
                             % (x['array_label'], barcode))
        row, col = m.groups()
        rows.append(int(row))
        cols.append(int(col))
    if not rows:
        # max() of an empty list would raise a ValueError with no context.
        raise ValueError('No elements found for array %s' % barcode)
    return {
        'barcode': barcode,
        'rows': max(rows),
        'columns': max(cols),
        'label': barcode,
        'study': study,
    }


def barcodes_to_labels(elements, wells_map, strict_mapping, logger):
    """Replace the ``source`` of each element using ``wells_map``.

    Elements are deep-copied before being changed, so the input is left
    untouched. An element whose ``source`` is not a key of ``wells_map`` is
    reported with a warning and left out of the result.

    Args:
        elements: sequence of dicts, each with a ``source`` key.
        wells_map: mapping from the current ``source`` value to its new one.
        strict_mapping: when true, any unmapped element is a fatal error.
        logger: logger used for the warning and critical messages.

    Returns:
        list of the mapped copies, in input order.

    Raises:
        SystemExit: via ``sys.exit`` when ``strict_mapping`` is true and at
            least one element could not be mapped.
    """
    from copy import deepcopy

    mapped_elements = []
    for e in elements:
        if e['source'] in wells_map:
            new_el = deepcopy(e)
            new_el['source'] = wells_map[e['source']]
            mapped_elements.append(new_el)
        else:
            logger.warning('Unable to map well %s' % e['source'])

    if strict_mapping and len(mapped_elements) < len(elements):
        # Mapped count first, total second, matching the message wording.
        msg = 'Mapped %d records of %d' % (len(mapped_elements), len(elements))
        logger.critical(msg)
        sys.exit(msg)
    return mapped_elements


def prepare_bead_chip_array_input(array_barcode, assay_type, study, elements):
    """Build one bead-chip-array record per element of a chip.

    Each record ties the element's ``array_label`` and ``source`` to the
    given chip barcode, assay type and study. Records keep the input order.
    """
    records = []
    for element in elements:
        record = {
            'illumina_array': array_barcode,
            'label': element['array_label'],
            'source': element['source'],
            'bead_chip_assay_type': assay_type,
            'study': study,
        }
        records.append(record)
    return records


def prepare_bead_chip_measure_input(array_barcode, study, elements,
                                    device='generic_illumina_scanner',
                                    status='USABLE'):
    """Build the per-channel measure records for every element of a chip.

    Two records are produced for each element, one for the ``Grn`` channel
    and one for the ``Red`` channel. All ``Grn`` records come first, followed
    by all ``Red`` records, each group in input order.
    """
    records = []
    for channel in ('Grn', 'Red'):
        for element in elements:
            label = '%s_%s_%s' % (array_barcode, element['array_label'], channel)
            source = '%s:%s' % (array_barcode, element['array_label'])
            records.append({
                'label': label,
                'source': source,
                'scanner': device,
                'status': status,
                'study': study,
            })
    return records


def prepare_bead_chip_array_measures_input(array_barcode, study, elements):
    """Build the red/green measure-pair record for every element of a chip.

    Each record is labelled ``<barcode>_<array_label>`` and names its two
    channels by appending ``_Red`` and ``_Grn`` to that label. Records keep
    the input order.
    """
    records = []
    for element in elements:
        position = element['array_label']
        base_label = '%s_%s' % (array_barcode, position)
        records.append({
            'study': study,
            'label': base_label,
            'red_channel': base_label + '_Red',
            'green_channel': base_label + '_Grn',
            'source': '%s:%s' % (array_barcode, position),
        })
    return records


def main(argv):
    """Split a GenomeStudio samplesheet into the four TSV import files.

    The samplesheet rows are grouped by ``SentrixBarcode_A``; for every
    barcode one array-of-arrays row is written, plus the bead chip, the
    per-channel measure and the red/green measure-pair rows of each of its
    positions.

    Args:
        argv: command-line arguments, without the program name.
    """
    parser = make_parser()
    args = parser.parse_args(argv)

    logger = get_logger('prepare_illumina_import_inputs', level=args.loglevel,
                        filename=args.logfile)

    logger.info('Processing file %s', args.input_file)
    with open(args.input_file) as in_file:
        reader = gsr(in_file)
        # NOTE(review): header['A'] is presumably the manifest file name,
        # since it is what get_assay_type_enum() receives -- confirm.
        assay_type = get_assay_type_enum(reader.header['A'])
        arrays_map = {}
        for r in reader:
            arrays_map.setdefault(r['SentrixBarcode_A'], []).append({'source': r['Sample_ID'],
                                                                     'array_label': r['SentrixPosition_A']})
        with open(args.arrays_out_file, 'w') as array_file,\
            open(args.bead_chip_out_file, 'w') as chip_file,\
            open(args.array_measure_out_file, 'w') as measure_file,\
            open(args.array_measures_out_file, 'w') as measures_file:
            arrays_writer = csv.DictWriter(array_file,
                                           ['study', 'label', 'barcode', 'rows', 'columns'],
                                           delimiter='\t')
            arrays_writer.writeheader()
            chip_writer = csv.DictWriter(chip_file,
                                         ['study', 'illumina_array', 'label', 'source',
                                          'bead_chip_assay_type'],
                                         delimiter='\t')
            chip_writer.writeheader()
            # Each writer is bound to the file its command-line option
            # documents: single-channel rows go to --array-measure-out-file,
            # red/green pair rows go to --array-measures-out-file.
            measure_writer = csv.DictWriter(measure_file,
                                            ['study', 'label', 'source', 'scanner', 'status'],
                                            delimiter='\t')
            measure_writer.writeheader()
            measures_writer = csv.DictWriter(measures_file,
                                             ['study', 'label', 'red_channel', 'green_channel',
                                              'source'],
                                             delimiter='\t')
            measures_writer.writeheader()
            # items() rather than iteritems() so the loop runs on both
            # Python 2 and Python 3.
            for k, v in arrays_map.items():
                arrays_writer.writerow(prepare_array_of_arrays_input(k, args.study, v))
                chip_writer.writerows(prepare_bead_chip_array_input(k, assay_type, args.study, v))
                measure_writer.writerows(prepare_bead_chip_measure_input(k, args.study, v))
                measures_writer.writerows(prepare_bead_chip_array_measures_input(k, args.study, v))
    logger.info('Job completed')


# Run the conversion only when the file is executed as a script, passing
# the command-line arguments without the program name.
if __name__ == '__main__':
    main(sys.argv[1:])