"""
Map PlateWell labels written as PLATE_BARCODE:WELL_LABEL to labels written as
PLATE_LABEL:WELL_LABEL which is the PlateWell label format required by the map_vid
application.
The inputs are a TSV file and the label of the column of this file containing the
PlateWell labels that are going to be mapped.
"""

import csv, argparse, sys, copy

from bl.vl.kb import KnowledgeBase as KB
from bl.vl.utils import LOG_LEVELS, get_logger
import bl.vl.utils.ome_utils as vlu


def get_wells_map(kb, plate_barcodes, logger):
    """
    Build the PLATE_BARCODE:WELL_LABEL -> PLATE_LABEL:WELL_LABEL lookup table.

    The plates are fetched from the knowledge base by barcode and, for each
    of them, every well returned by ``kb.get_wells_by_plate`` is mapped.

    :param kb: knowledge base proxy; must provide ``TiterPlate``,
      ``get_by_field`` and ``get_wells_by_plate``
    :param plate_barcodes: the barcodes of the plates to look up
    :param logger: logger used to report progress
    :return: a dict whose keys are '<plate barcode>:<well label>' strings and
      whose values are the matching '<plate label>:<well label>' strings
    """
    wells_map = {}
    logger.info('Start building PlateWells map')
    res = kb.get_by_field(kb.TiterPlate, 'barcode', plate_barcodes)
    logger.debug('Plates %r --- Results: %r', plate_barcodes, res)
    # Only the retrieved objects are needed, not the keys they are filed
    # under. values() is available on both Python 2 and Python 3 mappings
    # (assumes get_by_field returns a dict-like object, as iteritems() did).
    for pl in res.values():
        if pl.OME_TABLE != 'TiterPlate':
            logger.debug('Object is a %r, skipping it', pl.OME_TABLE)
            continue
        if not pl.barcode:
            logger.debug('TiterPlate %s has no barcode', pl.label)
            continue
        for w in kb.get_wells_by_plate(pl):
            logger.debug('Mapping well %s of plate %s', w.label, w.container.label)
            source_label = '%s:%s' % (w.container.barcode, w.label)
            wells_map[source_label] = '%s:%s' % (w.container.label, w.label)
    logger.info('Mapped %d PlateWells', len(wells_map))
    return wells_map


def get_plates_list(records, plates_column, logger):
    """
    Collect the distinct plate barcodes referenced by the given records.

    The barcode is the part of the PlateWell label that precedes the first
    ':' character. Records whose value is missing (``None``, as produced by
    csv.DictReader for short rows), empty, or without a barcode part are
    skipped instead of crashing or producing an empty barcode.

    :param records: a sequence of dict-like rows
    :param plates_column: the key of the field holding the PlateWell label
    :param logger: logger used to report progress
    :return: the sorted list of distinct barcodes
    """
    plates = set()
    skipped = 0
    logger.info('Retrieving TiterPlate barcodes from %d records', len(records))
    for r in records:
        value = r[plates_column]
        barcode = value.split(':')[0] if value else ''
        if not barcode:
            skipped += 1
            continue
        plates.add(barcode)
    if skipped:
        logger.warning('Skipped %d records without a plate barcode', skipped)
    logger.info('Found %d TiterPlate objects', len(plates))
    # Sorted so that the query and the logs are deterministic across runs.
    return sorted(plates)


def make_parser():
    """
    Build the command line parser of the tool.

    :return: an argparse.ArgumentParser configured with the logging, OMERO
      connection, input/output file and mapping options
    """
    # The text must be passed as ``description``: the first positional
    # argument of ArgumentParser is ``prog``, which would replace the program
    # name in the usage line with the whole sentence.
    parser = argparse.ArgumentParser(
        description='Map barcodes in PlateWell labels to TiterPlate labels')
    parser.add_argument('--logfile', type=str, help='log file (default=stderr)')
    parser.add_argument('--loglevel', type=str, choices=LOG_LEVELS,
                        help='logging level', default='INFO')
    parser.add_argument('-H', '--host', type=str, help='OMERO host')
    parser.add_argument('-U', '--user', type=str, help='OMERO user')
    parser.add_argument('-P', '--passwd', type=str, help='OMERO password')
    parser.add_argument('--in-file', type=str, required=True,
                        help='input TSV file')
    parser.add_argument('--column-label', type=str, required=True,
                        help='the label of the column containing the values that will be mapped')
    parser.add_argument('--out-file', type=str, required=True,
                        help='output TSV file')
    parser.add_argument('--strict-mapping', action='store_true',
                        help='if output records are less than the input ones, raise an error')
    return parser


def main(argv):
    """
    Run the barcode-to-label mapping job.

    Reads the input TSV file, replaces the PLATE_BARCODE:WELL_LABEL values of
    the selected column with the matching PLATE_LABEL:WELL_LABEL values and
    writes the mapped records to the output TSV file. Records that cannot be
    mapped are dropped from the output.

    :param argv: the command line arguments, without the program name
    :raises RuntimeError: if the input file has no column with the given label
    :raises SystemExit: if the OMERO connection parameters cannot be
      resolved (a ValueError raised while reading them) or if
      --strict-mapping is set and some records could not be mapped
    """
    parser = make_parser()
    args = parser.parse_args(argv)

    logger = get_logger('wells_barcode_to_label', level=args.loglevel,
                        filename=args.logfile)
    try:
        host = args.host or vlu.ome_host()
        user = args.user or vlu.ome_user()
        passwd = args.passwd or vlu.ome_passwd()
    except ValueError as ve:
        logger.critical(ve)
        sys.exit(ve)

    logger.info('Starting job')

    kb = KB(driver='omero')(host, user, passwd)

    # The input is read and validated completely before the output file is
    # opened, so that a failure never truncates an existing output file.
    with open(args.in_file) as in_file:
        reader = csv.DictReader(in_file, delimiter='\t')
        # fieldnames is None when the input file is empty
        fieldnames = reader.fieldnames or []
        if args.column_label not in fieldnames:
            msg = 'No column %s in file %s' % (args.column_label, args.in_file)
            logger.critical(msg)
            raise RuntimeError(msg)
        records = list(reader)

    plates = get_plates_list(records, args.column_label, logger)
    wells_map = get_wells_map(kb, plates, logger)
    logger.info('Mapping %d records', len(records))
    mapped_records = []
    for rec in records:
        value = rec[args.column_label]
        logger.debug('Mapping value %s', value)
        if value in wells_map:
            # A shallow copy is enough: only one field is replaced and the
            # original record is left untouched.
            mapped = dict(rec)
            mapped[args.column_label] = wells_map[value]
            mapped_records.append(mapped)

    unmapped = len(records) - len(mapped_records)
    if args.strict_mapping and unmapped > 0:
        msg = 'Mapped %d record of %d' % (len(mapped_records), len(records))
        logger.critical(msg)
        sys.exit(msg)
    if unmapped:
        logger.warning('%d records could not be mapped and will be dropped', unmapped)
    logger.info('%d records mapped', len(mapped_records))

    with open(args.out_file, 'w') as out_file:
        writer = csv.DictWriter(out_file, fieldnames, delimiter='\t')
        writer.writeheader()
        writer.writerows(mapped_records)
    logger.info('Job completed')


# Script entry point: run the job with the command line arguments, leaving
# out the program name.
if __name__ == '__main__':
    main(sys.argv[1:])