diff galaxy-tools/biobank/tools/check_merge_individuals.py @ 0:ba6cf6ede027 draft default tip

Uploaded
author ric
date Wed, 28 Sep 2016 06:03:30 -0400
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/galaxy-tools/biobank/tools/check_merge_individuals.py	Wed Sep 28 06:03:30 2016 -0400
@@ -0,0 +1,104 @@
+import sys, csv, argparse, os
+from collections import Counter
+
+from bl.vl.kb import KnowledgeBase as KB
+import bl.vl.utils.ome_utils as vlu
+from bl.vl.utils import LOG_LEVELS, get_logger
+
+
def make_parser():
    """Build the command-line argument parser for this tool."""
    p = argparse.ArgumentParser(
        description='check data that will be passed to the merge_individuals tool')
    p.add_argument('--logfile', type=str, help='log file (default=stderr)')
    p.add_argument('--loglevel', type=str, default='INFO', choices=LOG_LEVELS,
                   help='logging level (default=INFO)')
    p.add_argument('-H', '--host', type=str, help='omero hostname')
    p.add_argument('-U', '--user', type=str, help='omero user')
    p.add_argument('-P', '--passwd', type=str, help='omero password')
    p.add_argument('--in_file', type=str, required=True, help='input file')
    p.add_argument('--out_file', type=str, required=True, help='output file')
    return p
+
+
def get_invalid_vids(records, logger):
    """Return the set of IDs that cannot be safely merged.

    An ID is marked invalid when it appears more than once in the
    'source' column, or when it appears in both the 'source' and the
    'target' columns.

    :param records: list of dicts, each with 'source' and 'target' keys
    :param logger: logger used to report every invalid ID found
    :return: set of invalid ID strings (empty for empty input)
    """
    records_map = {}
    invalid_vids = set()

    for rec in records:
        # items() instead of iteritems(): works on both Python 2 and 3
        for k, v in rec.items():
            records_map.setdefault(k, []).append(v)
    # Use .get with a default so an empty record list does not raise KeyError
    sources = records_map.get('source', [])
    targets = records_map.get('target', [])
    # Check for duplicated sources (Counter counts in one pass)
    for vid, count in Counter(sources).items():
        if count > 1:
            logger.error('ID %s appears %d times as source, this ID has been marked as invalid' % (vid, count))
            invalid_vids.add(vid)
    # Check for VIDs that appear both in 'source' and 'target' fields
    commons = set(sources).intersection(targets)
    for c in commons:
        logger.error('ID %s appears both in \'source\' and \'target\' columns, this ID has been marked as invalid' % c)
        invalid_vids.add(c)

    return invalid_vids
+
+
def check_row(row, individuals, logger):
    """Check that both IDs of a record map to known individuals.

    :param row: dict with 'source' and 'target' keys holding individual IDs
    :param individuals: mapping from individual ID to individual object
    :param logger: logger used for per-ID diagnostics
    :return: True if both IDs are known, False otherwise
    """
    try:
        source = individuals[row['source']]
        logger.debug('%s is a valid Individual ID' % source.id)
        target = individuals[row['target']]
        logger.debug('%s is a valid Individual ID' % target.id)
        return True
    # 'except ... as' works on Python 2.6+ and Python 3
    # (the original 'except KeyError, ke' is Python-2-only syntax)
    except KeyError as ke:
        logger.error('%s is not a valid Individual ID' % ke)
        return False
+        
+
def main(argv):
    """Validate input for the merge_individuals tool.

    Reads a tab-separated file with 'source' and 'target' ID columns,
    drops records whose IDs are duplicated, conflicting, or unknown to
    the knowledge base, and writes the surviving records (same columns,
    same order) to the output file.

    :param argv: command line arguments (without the program name)
    """
    parser = make_parser()
    args = parser.parse_args(argv)

    logger = get_logger('check_merge_individuals', level=args.loglevel,
                        filename=args.logfile)

    # Fall back to the OMERO environment configuration when connection
    # parameters are not given on the command line
    try:
        host = args.host or vlu.ome_host()
        user = args.user or vlu.ome_user()
        passwd = args.passwd or vlu.ome_passwd()
    # 'except ... as' works on Python 2.6+ and Python 3
    except ValueError as ve:
        logger.critical(ve)
        sys.exit(ve)

    kb = KB(driver='omero')(host, user, passwd)

    # Preload every individual once so per-row checks are O(1) dict lookups
    logger.info('Preloading all individuals')
    inds = kb.get_objects(kb.Individual)
    logger.info('Loaded %d individuals' % len(inds))
    inds_map = dict((i.id, i) for i in inds)

    with open(args.in_file) as infile, open(args.out_file, 'w') as outfile:
        reader = csv.DictReader(infile, delimiter='\t')
        records = [row for row in reader]
        # Collect duplicated / conflicting IDs before validating rows
        invalid_vids = get_invalid_vids(records, logger)

        writer = csv.DictWriter(outfile, reader.fieldnames, delimiter='\t')
        writer.writeheader()

        for record in records:
            if record['source'] in invalid_vids or record['target'] in invalid_vids:
                logger.error('Skipping record %r because at least one ID was marked as invalid' % record)
            elif check_row(record, inds_map, logger):
                writer.writerow(record)
                logger.debug('Record %r written in output file' % record)


if __name__ == '__main__':
    main(sys.argv[1:])