comparison correct.py @ 18:e4d75f9efb90 draft

planemo upload commit b'4303231da9e48b2719b4429a29b72421d24310f4\n'-dirty
author nick
date Thu, 02 Feb 2017 18:44:31 -0500
parents
children
comparison
equal deleted inserted replaced
17:836fa4fe9494 18:e4d75f9efb90
1 #!/usr/bin/env python
2 from __future__ import division
3 from __future__ import print_function
4 import os
5 import sys
6 import gzip
7 import logging
8 import argparse
9 import resource
10 import subprocess
11 import networkx
12 import swalign
13
# Custom log level halfway between DEBUG and INFO; selected by the -v/--verbose flag.
VERBOSE = (logging.DEBUG + logging.INFO) // 2
# Defaults applied to the argument parser via set_defaults().
# NOTE(review): 'qual' is not the dest of any argument defined in main(), and there is no 'mapq'
# entry, so --mapq defaults to None. Confirm whether 'qual' was meant to be 'mapq'.
ARG_DEFAULTS = {
    'sam': sys.stdin,
    'qual': 20,
    'pos': 2,
    'dist': 1,
    'choose_by': 'reads',
    'output': True,
    'visualize': 0,
    'viz_format': 'png',
    'log': sys.stderr,
    'volume': logging.WARNING,
}
USAGE = "%(prog)s [options]"
DESCRIPTION = """Correct barcodes using an alignment of all barcodes to themselves. Reads the
alignment in SAM format and corrects the barcodes in an input "families" file (the output of
make-barcodes.awk). It will print the "families" file to stdout with barcodes (and orders)
corrected."""
22
23
def main(argv):
    """Command-line entry point: parse arguments, build the barcode graph, print corrected families.

    argv is the full argument vector (argv[0] is the program name and is skipped).
    Returns None.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.set_defaults(**ARG_DEFAULTS)

    # Input files.
    parser.add_argument('families', type=open_as_text_or_gzip,
        help='The sorted output of make-barcodes.awk. The important part is that it\'s a tab-delimited '
             'file with at least 2 columns: the barcode sequence and order, and it must be sorted in '
             'the same order as the "reads" in the SAM file.')
    parser.add_argument('reads', type=open_as_text_or_gzip,
        help='The fasta/q file given to the aligner. Used to get barcode sequences from read names.')
    parser.add_argument('sam', type=argparse.FileType('r'), nargs='?',
        help='Barcode alignment, in SAM format. Omit to read from stdin. The read names must be '
             'integers, representing the (1-based) order they appear in the families file.')

    # Correction and filtering options.
    parser.add_argument('-P', '--prepend', action='store_true',
        help='Prepend the corrected barcodes and orders to the original columns.')
    parser.add_argument('-d', '--dist', type=int,
        help='NM edit distance threshold. Default: %(default)s')
    parser.add_argument('-m', '--mapq', type=int,
        help='MAPQ threshold. Default: %(default)s')
    parser.add_argument('-p', '--pos', type=int,
        help='POS tolerance. Alignments will be ignored if abs(POS - 1) is greater than this value. '
             'Set to greater than the barcode length for no threshold. Default: %(default)s')
    parser.add_argument('-t', '--tag-len', type=int,
        help='Length of each half of the barcode. If not given, it will be determined from the first '
             'barcode in the families file.')
    parser.add_argument('-c', '--choose-by', choices=('reads', 'connectivity'))
    parser.add_argument('--limit', type=int,
        help='Limit the number of lines that will be read from each input file, for testing purposes.')

    # Reporting and visualization options.
    parser.add_argument('-S', '--structures', action='store_true',
        help='Print a list of the unique isoforms')
    parser.add_argument('--struct-human', action='store_true')
    parser.add_argument('-V', '--visualize', nargs='?',
        help='Produce a visualization of the unique structures write the image to this file. '
             'If you omit a filename, it will be displayed in a window.')
    parser.add_argument('-F', '--viz-format', choices=('dot', 'graphviz', 'png'))
    parser.add_argument('-n', '--no-output', dest='output', action='store_false')

    # Logging options. The three help-less volume flags are registered from a table, in the
    # same order as before, so the generated --help text is unchanged.
    parser.add_argument('-l', '--log', type=argparse.FileType('w'),
        help='Print log messages to this file instead of to stderr. Warning: Will overwrite the file.')
    volume_flags = (
        (('-q', '--quiet'), logging.CRITICAL),
        (('-i', '--info'), logging.INFO),
        (('-v', '--verbose'), VERBOSE),
    )
    for flag_names, level in volume_flags:
        parser.add_argument(*flag_names, dest='volume', action='store_const', const=level)
    parser.add_argument('-D', '--debug', dest='volume', action='store_const', const=logging.DEBUG,
        help='Print debug messages (very verbose).')

    args = parser.parse_args(argv[1:])

    logging.basicConfig(stream=args.log, level=args.volume, format='%(message)s')
    tone_down_logger()

    logging.info('Reading the fasta/q to map read names to barcodes..')
    names_to_barcodes = map_names_to_barcodes(args.reads, args.limit)

    logging.info('Reading the SAM to build the graph of barcode relationships..')
    graph, reversed_barcodes = read_alignments(
        args.sam, names_to_barcodes, args.pos, args.mapq, args.dist, args.limit
    )
    logging.info('{} reversed barcodes'.format(len(reversed_barcodes)))

    logging.info('Reading the families.tsv to get the counts of each family..')
    family_counts = get_family_counts(args.families, args.limit)

    if args.structures:
        logging.info('Counting the unique barcode networks..')
        structures = count_structures(graph, family_counts)
        print_structures(structures, args.struct_human)
        # 0 is the "option not given" default; None means "given without a filename".
        if args.visualize != 0:
            logging.info('Generating a visualization of barcode networks..')
            component_graphs = [structure['graph'] for structure in structures]
            visualize(component_graphs, args.visualize, args.viz_format)

    logging.info('Building the correction table from the graph..')
    corrections = make_correction_table(graph, family_counts, args.choose_by)

    # get_family_counts() consumed and closed the families file, so open it a second time.
    logging.info('Reading the families.tsv again to print corrected output..')
    families = open_as_text_or_gzip(args.families.name)
    print_corrected_output(
        families, corrections, reversed_barcodes, args.prepend, args.limit, args.output
    )

    usage = resource.getrusage(resource.RUSAGE_SELF)
    max_mem = usage.ru_maxrss/1024
    logging.info('Max memory usage: {:0.2f}MB'.format(max_mem))
103
104
def detect_format(reads_file, max_lines=7):
    """Guess whether reads_file holds FASTA or FASTQ records by inspecting its first lines.

    Only odd-numbered lines vote. Lines 1, 5, .. count toward FASTQ if they start with '@', and
    lines 3, 7, .. count toward FASTQ if they start with '+'. Any odd line starting with '>'
    counts toward FASTA. The file is rewound with seek(0) before returning.
    Returns 'fasta', 'fastq', or None when the votes are tied.
    """
    votes = {'fasta': 0, 'fastq': 0}
    for line_num, line in enumerate(reads_file, 1):
        if line_num % 2 == 1:
            # Header lines of a 4-line FASTQ record start with '@', separator lines with '+'.
            fastq_marker = '@' if line_num % 4 == 1 else '+'
            if line.startswith(fastq_marker):
                votes['fastq'] += 1
            elif line.startswith('>'):
                votes['fasta'] += 1
        if line_num >= max_lines:
            break
    reads_file.seek(0)
    if votes['fasta'] > votes['fastq']:
        return 'fasta'
    if votes['fastq'] > votes['fasta']:
        return 'fastq'
    return None
131
132
def read_fastaq(reads_file):
    """Return a generator of (read name, sequence) pairs from a FASTA or FASTQ file.

    The format is taken from the filename extension when it is .fa/.fasta or .fq/.fastq, and
    otherwise guessed from the content with detect_format().
    NOTE: Returns None (not a generator) when the format cannot be determined.
    """
    filename = reads_file.name
    if filename.endswith(('.fa', '.fasta')):
        file_format = 'fasta'
    elif filename.endswith(('.fq', '.fastq')):
        file_format = 'fastq'
    else:
        file_format = detect_format(reads_file)
    readers = {'fasta': read_fasta, 'fastq': read_fastq}
    reader = readers.get(file_format)
    if reader is None:
        return None
    return reader(reads_file)
145
146
def read_fasta(reads_file):
    """Generate (read name, sequence) pairs from a FASTA file.

    NOTE: Only handles records whose sequence occupies exactly one line, so the input must
    strictly alternate between a '>' header line and a sequence line.
    """
    read_name = None
    for index, raw_line in enumerate(reads_file):
        text = raw_line.rstrip('\r\n')
        if index % 2 == 0:
            # Header line: everything after the '>' is the read name.
            assert text.startswith('>'), text
            read_name = text[1:]
        else:
            yield read_name, text
160
161
def read_fastq(reads_file):
    """Generate (read name, sequence) pairs from a FASTQ file.

    NOTE: Only handles 4-line records (the sequence and the qualities each on a single line).
    The '+' and quality lines are skipped.
    """
    read_name = None
    for index, line in enumerate(reads_file):
        position = index % 4
        if position == 0:
            # Header line: everything after the '@' is the read name.
            assert line.startswith('@'), line
            read_name = line[1:].rstrip('\r\n')
        elif position == 1:
            yield read_name, line.rstrip('\r\n')
174
175
def map_names_to_barcodes(reads_file, limit=None):
    """Map integer read names to their barcode sequences.

    reads_file: an open FASTA/FASTQ file whose read names are integers. It is closed on return.
    limit: if not None, stop after this many reads.
    Returns a dict of int(read name) -> sequence string.
    Raises ValueError (after logging a critical message) if a read name is not an integer.
    """
    names_to_barcodes = {}
    read_num = 0
    for read_name, read_seq in read_fastaq(reads_file):
        read_num += 1
        if limit is not None and read_num > limit:
            break
        try:
            name = int(read_name)
        except ValueError:
            # Report the offending raw name. `name` is not assigned when int() fails, so
            # formatting it here would raise a NameError and hide the real error.
            logging.critical('non-int read name "{}"'.format(read_name))
            raise
        names_to_barcodes[name] = read_seq
    reads_file.close()
    return names_to_barcodes
192
193
def parse_alignment(sam_file, pos_thres, mapq_thres, dist_thres):
    """Parse the SAM file and yield the alignments that pass the filters.

    sam_file: an iterable of SAM lines. It is closed when the generator finishes or is closed.
    pos_thres: alignments are skipped if abs(POS - 1) is greater than this.
    mapq_thres: alignments with a MAPQ below this are skipped. None disables the filter.
    dist_thres: alignments with an NM edit distance greater than this are skipped.
    Yields (qname, rname, reversed): the integer query and reference names, and whether the
    reference name carried a ':rev' suffix.
    Raises ValueError on non-integer read names (other than an unmapped '*' reference) or an
    invalid NM tag, and AssertionError if a passing alignment has no NM tag.
    """
    line_num = 0
    try:
        for line in sam_file:
            line_num += 1
            if line.startswith('@'):
                logging.debug('Header line ({})'.format(line_num))
                continue
            fields = line.rstrip('\r\n').split('\t')
            logging.debug('read {} -> ref {} (read seq {}):'.format(fields[2], fields[0], fields[9]))
            qname_str = fields[0]
            rname_str = fields[2]
            # A reference named like "123:rev" is the reversed form of barcode 123.
            rname_fields = rname_str.split(':')
            if len(rname_fields) == 2 and rname_fields[1] == 'rev':
                is_reversed = True
                rname_str = rname_fields[0]
            else:
                is_reversed = False
            try:
                qname = int(qname_str)
                rname = int(rname_str)
            except ValueError:
                if fields[2] == '*':
                    logging.debug('\tRead unmapped (reference == "*")')
                    continue
                # Log the raw strings: the integer versions are unbound (or stale from the
                # previous line) when the conversion fails.
                logging.error('Non-integer read name(s) on line {}: "{}", "{}".'
                              .format(line_num, qname_str, rname_str))
                raise
            # Apply alignment quality filters.
            try:
                flags = int(fields[1])
                pos = int(fields[3])
                mapq = int(fields[4])
            except ValueError:
                logging.warning('\tNon-integer flag ({}), pos ({}), or mapq ({})'
                                .format(fields[1], fields[3], fields[4]))
                continue
            if flags & 4:
                logging.debug('\tRead unmapped (flag & 4 == True)')
                continue
            if abs(pos - 1) > pos_thres:
                logging.debug('\tAlignment failed pos filter: abs({} - 1) > {}'.format(pos, pos_thres))
                continue
            # A threshold of None means "no MAPQ filter". Comparing an int to None only worked
            # by accident on Python 2 (always False) and raises TypeError on Python 3.
            if mapq_thres is not None and mapq < mapq_thres:
                logging.debug('\tAlignment failed mapq filter: {} < {}'.format(mapq, mapq_thres))
                continue
            nm = None
            for tag in fields[11:]:
                if tag.startswith('NM:i:'):
                    try:
                        nm = int(tag[5:])
                    except ValueError:
                        logging.error('Invalid NM tag "{}" on line {}.'.format(tag, line_num))
                        raise
                    break
            assert nm is not None, line_num
            if nm > dist_thres:
                logging.debug('\tAlignment failed NM distance filter: {} > {}'.format(nm, dist_thres))
                continue
            yield qname, rname, is_reversed
    finally:
        # Also runs when the consumer stops early and the generator is closed.
        sam_file.close()
257
258
def read_alignments(sam_file, names_to_barcodes, pos_thres, mapq_thres, dist_thres, limit=None):
    """Build a graph of barcode relationships from the alignments in the SAM file.

    Each alignment that passes the filters in parse_alignment() and is not a self-alignment adds
    an edge between the reference barcode sequence and the query barcode sequence.
    limit: if not None, stop after this many passing alignments.
    Returns (graph, reversed_barcodes): a networkx.Graph whose nodes are barcode sequences, and
    the set of every barcode sequence (query or reference) that took part in an alignment to a
    reversed reference.
    """
    graph = networkx.Graph()
    reversed_barcodes = set()
    alignments = parse_alignment(sam_file, pos_thres, mapq_thres, dist_thres)
    for count, (qname, rname, is_reversed) in enumerate(alignments, 1):
        if limit is not None and count > limit:
            break
        if rname == qname:
            # A barcode aligning to itself carries no information.
            continue
        rseq = names_to_barcodes[rname]
        qseq = names_to_barcodes[qname]
        if is_reversed:
            reversed_barcodes.update((rseq, qseq))
        graph.add_node(rseq)
        graph.add_node(qseq)
        graph.add_edge(rseq, qseq)
    return graph, reversed_barcodes
286
287
def get_family_counts(families_file, limit=None):
    """For each family (barcode), count how many read pairs exist for each strand (order).

    families_file: an iterable of tab-delimited lines whose first two columns are the barcode
      and the order ('ab' or 'ba'). Lines of one family must be adjacent. It is closed on return.
    limit: if not None, stop after this many lines.
    Returns a dict mapping each barcode to {'ab': count, 'ba': count, 'all': ab + ba}.
    Returns an empty dict if no lines were read.
    Raises KeyError if an order is something other than 'ab' or 'ba'.
    """
    family_counts = {}
    last_barcode = None
    this_family_counts = None
    line_num = 0
    try:
        for line in families_file:
            line_num += 1
            if limit is not None and line_num > limit:
                break
            fields = line.rstrip('\r\n').split('\t')
            barcode = fields[0]
            order = fields[1]
            if barcode != last_barcode:
                # A new family begins: finalize and store the one we just finished.
                if this_family_counts is not None:
                    this_family_counts['all'] = this_family_counts['ab'] + this_family_counts['ba']
                    family_counts[last_barcode] = this_family_counts
                this_family_counts = {'ab': 0, 'ba': 0}
                last_barcode = barcode
            this_family_counts[order] += 1
    finally:
        families_file.close()
    # Store the final family. There is none if the file was empty or limit was 0.
    if this_family_counts is not None:
        this_family_counts['all'] = this_family_counts['ab'] + this_family_counts['ba']
        family_counts[last_barcode] = this_family_counts
    return family_counts
312
313
def make_correction_table(meta_graph, family_counts, choose_by='reads'):
    """Build a dict mapping each erroneous barcode sequence to its corrected barcode.

    Every connected component of meta_graph is treated as one group of related barcodes. The
    barcode ranked highest in the group is taken as correct, and every other barcode in the
    group maps to it. With choose_by 'reads' the ranking is the family's total read pair count;
    with 'connectivity' it is the node's degree in the graph.
    """
    corrections = {}
    for component in networkx.connected_component_subgraphs(meta_graph):
        if choose_by == 'reads':
            def rank(bar):
                return family_counts[bar]['all']
        elif choose_by == 'connectivity':
            degrees = component.degree()
            def rank(bar):
                return degrees[bar]
        ranked = sorted(component.nodes(), key=rank, reverse=True)
        correct = ranked[0]
        for barcode in ranked:
            if barcode == correct:
                continue
            logging.debug('Correcting {} ->\n           {}\n'.format(barcode, correct))
            corrections[barcode] = correct
    return corrections
333
334
def _tally_family(barcode, reads, corrections_in_family, corrected):
    """Log one finished family at VERBOSE level and add its corrections to the running totals."""
    family_info = '{}\t{}\t{}'.format(barcode, reads[0], reads[1])
    if corrections_in_family:
        corrected['reads'] += corrections_in_family
        corrected['barcodes'] += 1
        family_info += '\tCORRECTED!'
    else:
        family_info += '\tuncorrected'
    logging.log(VERBOSE, family_info)


def print_corrected_output(families_file, corrections, reversed_barcodes, prepend=False, limit=None,
                           output=True):
    """Print the families file to stdout with the barcodes (and orders) corrected.

    families_file: an iterable of tab-delimited lines whose first two columns are the barcode
      and the order. It is closed on return.
    corrections: dict mapping original barcodes to their corrected barcodes.
    reversed_barcodes: the barcodes involved in an alignment to a reversed reference. Only
      these are re-aligned to decide whether the order must flip.
    prepend: insert the corrected barcode as a new first column instead of replacing the original.
    limit: if not None, stop after this many lines.
    output: if False, print nothing (the summary is still logged).
    Each family is logged at VERBOSE level as it is finished, and a summary at INFO level.
    """
    line_num = 0
    barcode_last = None
    corrected = {'reads': 0, 'barcodes': 0, 'reversed': 0}
    reads = [0, 0]
    corrections_in_this_family = 0
    for line in families_file:
        line_num += 1
        if limit is not None and line_num > limit:
            break
        fields = line.rstrip('\r\n').split('\t')
        raw_barcode = fields[0]
        order = fields[1]
        if raw_barcode != barcode_last:
            # We just started a new family. Report the previous one, if there was one
            # (on the first line there is nothing to report yet).
            if barcode_last is not None:
                _tally_family(barcode_last, reads, corrections_in_this_family, corrected)
            reads = [0, 0]
            corrections_in_this_family = 0
            barcode_last = raw_barcode
        if order == 'ab':
            reads[0] += 1
        elif order == 'ba':
            reads[1] += 1
        if raw_barcode in corrections:
            correct_barcode = corrections[raw_barcode]
            corrections_in_this_family += 1
            # Check if the order of the barcode reverses in the correct version.
            # First, we check in reversed_barcodes whether either barcode was involved in a
            # reversed alignment, to save time (is_alignment_reversed() does a full
            # smith-waterman alignment).
            if ((raw_barcode in reversed_barcodes or correct_barcode in reversed_barcodes) and
                    is_alignment_reversed(raw_barcode, correct_barcode)):
                # If so, then switch the order field.
                corrected['reversed'] += 1
                if order == 'ab':
                    fields[1] = 'ba'
                else:
                    fields[1] = 'ab'
        else:
            correct_barcode = raw_barcode
        if prepend:
            fields.insert(0, correct_barcode)
        else:
            fields[0] = correct_barcode
        if output:
            print(*fields, sep='\t')
    families_file.close()
    # The last family is never followed by a barcode change, so report it here.
    if barcode_last is not None:
        _tally_family(barcode_last, reads, corrections_in_this_family, corrected)
    logging.info('Corrected {barcodes} barcodes on {reads} read pairs, with {reversed} reversed.'
                 .format(**corrected))
396
397
def is_alignment_reversed(barcode1, barcode2):
    """Tell whether two barcodes are reversed relative to each other.

    "Reversed" here means the two halves (alpha + beta) are swapped. barcode1 is aligned with
    smith-waterman to barcode2 as given, and again to barcode2 with its halves swapped.
    Returns True if the swapped alignment scores strictly higher, False otherwise.
    """
    midpoint = len(barcode2) // 2
    barcode2_swapped = barcode2[midpoint:] + barcode2[:midpoint]
    forward = swalign.smith_waterman(barcode1, barcode2)
    swapped = swalign.smith_waterman(barcode1, barcode2_swapped)
    return bool(swapped.score > forward.score)
412
413
def count_structures(meta_graph, family_counts):
    """Group the connected components of meta_graph into unique (isomorphic) structures.

    Returns a list of dicts, one per distinct structure, with the keys:
      'graph':   the first component seen with this structure,
      'size':    its number of nodes,
      'count':   how many components share this structure,
      'central': how many of those components is_centralized() returned True for.
    """
    structures = []
    for component in networkx.connected_component_subgraphs(meta_graph):
        central = int(is_centralized(component, family_counts))
        for structure in structures:
            if networkx.is_isomorphic(component, structure['graph']):
                structure['count'] += 1
                structure['central'] += central
                break
        else:
            # No known structure matched: this component becomes a new archetype.
            structures.append(
                {'graph': component, 'size': len(component), 'count': 1, 'central': central}
            )
    return structures
431
432
def is_centralized(graph, family_counts):
    """Check whether the reads of a barcode network are concentrated on a single barcode.

    In a centralized graph, the highest-degree node is the only one which may have more than one
    read pair associated with its barcode.
    Returns True if that is the case, False otherwise.
    """
    if len(graph) == 2:
        # Two-node graphs need their own rule: both nodes have degree 1, so sorting by degree
        # cannot tell which of them holds more read pairs.
        first_bar, second_bar = graph.nodes()
        first_counts = family_counts[first_bar]
        second_counts = family_counts[second_bar]
        first_total = first_counts['all']
        second_total = second_counts['all']
        logging.debug('{}: {:3d} ({}/{})\n{}: {:3d} ({}/{})\n'
                      .format(first_bar, first_total, first_counts['ab'], first_counts['ba'],
                              second_bar, second_total, second_counts['ab'], second_counts['ba']))
        return ((first_total >= 1 and second_total == 1) or
                (first_total == 1 and second_total >= 1))
    degrees = graph.degree()
    by_degree = sorted(graph.nodes(), key=lambda bar: degrees[bar], reverse=True)
    # Every barcode except the highest-degree one must have at most one read pair.
    for barcode in by_degree[1:]:
        counts = family_counts[barcode]
        try:
            if counts['all'] > 1:
                return False
        except TypeError:
            logging.critical('barcode: {}, counts: {}'.format(barcode, counts))
            raise
    return True
469
470
def print_structures(structures, human=True):
    """Print one line per unique structure: size, letter label, count, central count, degrees.

    structures: the list of dicts produced by count_structures().
    human: if True, print space-separated degrees in aligned columns; otherwise print
      tab-delimited fields with comma-separated degrees.
    Structures are printed in ascending order of size, then descending order of count. The
    label comes from num_to_letters() of the structure's index within its size group.
    """
    width = None
    last_size = None
    i = 0
    # A key function gives the same ordering as the old cmp function, and also works on
    # Python 3, where sorted() no longer accepts cmp=.
    ordered = sorted(structures, key=lambda structure: (structure['size'], -structure['count']))
    for structure in ordered:
        size = structure['size']
        graph = structure['graph']
        if size == last_size:
            i += 1
        else:
            i = 0
        if width is None:
            # The column width is taken from the count of the first structure printed.
            width = str(len(str(structure['count'])))
        letters = num_to_letters(i)
        # dict() accepts both the dict returned by NetworkX 1.x and the DegreeView of 2.x.
        degrees = sorted(dict(graph.degree()).values(), reverse=True)
        separator = ' ' if human else ','
        degrees_str = separator.join(map(str, degrees))
        if human:
            format_str = '{:2d}{:<3s} {count:<'+width+'d} {central:<'+width+'d} {}'
            print(format_str.format(size, letters+':', degrees_str, **structure))
        else:
            print(size, letters, structure['count'], structure['central'], degrees_str, sep='\t')
        last_size = size
502
503
def num_to_letters(i):
    """Convert a positive integer to its spreadsheet-column-style letters.

    E.g. 1 -> A, 10 -> J, 26 -> Z, 27 -> AA, 100 -> CV. Returns '' for 0 or less.
    """
    letters = ''
    while i > 0:
        # Shift to 0-based so that 26 maps to 'Z' without carrying into the next position.
        i, offset = divmod(i - 1, 26)
        letters = chr(ord('A') + offset) + letters
    return letters
514
515
def visualize(graphs, viz_path, args_viz_format):
    """Draw all the given graphs together and write the image to a file or show it in a window.

    graphs: the graphs to draw. They are merged into one graph first.
    viz_path: output filename, or None to display the image in a window ('png' format only).
    args_viz_format: 'png', 'graphviz', or 'dot' (treated the same as 'graphviz'). Used when
      the extension of viz_path is not '.dot' or '.png', or when there is no viz_path.
    """
    # Import the submodule explicitly: a bare "import matplotlib" does not guarantee that
    # matplotlib.pyplot is loaded.
    import matplotlib.pyplot
    from networkx.drawing.nx_agraph import graphviz_layout
    meta_graph = networkx.Graph()
    for graph in graphs:
        add_graph(meta_graph, graph)
    pos = graphviz_layout(meta_graph)
    networkx.draw(meta_graph, pos)
    # Start from the requested format so viz_format is always defined, then let a recognized
    # file extension override it.
    viz_format = args_viz_format
    if viz_path:
        ext = os.path.splitext(viz_path)[1]
        if ext == '.dot':
            viz_format = 'graphviz'
        elif ext == '.png':
            viz_format = 'png'
    if viz_format in ('graphviz', 'dot'):
        from networkx.drawing.nx_pydot import write_dot
        assert viz_path is not None, 'Must provide a filename to --visualize if using --viz-format "graphviz".'
        # splitext() returns (root, ext); only the root is wanted here.
        base_path = os.path.splitext(viz_path)[0]
        write_dot(meta_graph, base_path+'.dot')
        exit_status = run_command('dot', '-T', 'png', '-o', base_path+'.png', base_path+'.dot')
        if exit_status != 0:
            logging.warning('Failed to render {} with "dot" (exit status {}).'
                            .format(base_path+'.dot', exit_status))
        logging.info('Wrote image of graph to '+base_path+'.dot')
    elif viz_format == 'png':
        if viz_path is None:
            matplotlib.pyplot.show()
        else:
            matplotlib.pyplot.savefig(viz_path)
544
545
def add_graph(graph, subgraph):
    """Copy every node and edge of subgraph into graph, and return graph.

    Node and edge attributes are not copied.
    """
    for member in subgraph.nodes():
        graph.add_node(member)
    for endpoints in subgraph.edges():
        graph.add_edge(*endpoints)
    return graph
553
554
def open_as_text_or_gzip(path):
    """Return an open file-like object reading the path as a text file or a gzip file, depending on
    which it looks like (as judged by detect_gzip()).

    On Python 3 both kinds are opened in text mode, so iterating yields str lines either way.
    """
    is_gzip = detect_gzip(path)
    if sys.version_info[0] >= 3:
        # 'rU' was removed in Python 3.11 (universal newlines are the text-mode default), and
        # gzip.open() needs an explicit 't' to decode to str instead of returning bytes.
        if is_gzip:
            return gzip.open(path, 'rt')
        return open(path, 'r')
    if is_gzip:
        return gzip.open(path, 'r')
    return open(path, 'rU')
562
563
def detect_gzip(path):
    """Return True if the file looks like a gzip file: ends with .gz or contains non-ASCII bytes.

    Files ending in .txt, .tsv, or .csv are assumed to be text without looking at the content.
    For any other extension, the first 100 bytes are examined. Always returns a bool.
    """
    ext = os.path.splitext(path)[1]
    if ext == '.gz':
        return True
    if ext in ('.txt', '.tsv', '.csv'):
        return False
    # Read raw bytes: decoding compressed data as text can fail before it is ever examined.
    with open(path, 'rb') as fh:
        head = fh.read(100)
    # latin-1 maps every byte to the character with the same code point, so the helper sees
    # one character per byte on both Python 2 and 3.
    return bool(detect_non_ascii(head.decode('latin-1')))
575
576
def detect_non_ascii(bytes, max_test=100):
    """Return True if any of the first "max_test" bytes are non-ASCII (the high bit set to 1).
    Return False otherwise.

    bytes may be a byte string or a text string. For text, any character with a code point
    above 127 counts as non-ASCII.
    """
    for i, char in enumerate(bytes):
        # Stop before examining item number max_test + 1.
        if i >= max_test:
            break
        # Iterating a Python 3 bytes object yields ints; str (and Python 2 bytes) yield
        # 1-character strings.
        code = char if isinstance(char, int) else ord(char)
        if code > 127:
            return True
    return False
587
588
def run_command(*command):
    """Run a command (given as separate arguments) and return its exit status.

    Returns the returncode of a CalledProcessError if one is raised, and None if the command
    could not be started (OSError).
    """
    try:
        return subprocess.call(command)
    except subprocess.CalledProcessError as error:
        return error.returncode
    except OSError:
        return None
597
598
def tone_down_logger():
    """Rename the standard logging levels from all-caps to capitalized lowercase.

    E.g. "WARNING" becomes "Warning", so log files shout a little less.
    """
    standard_levels = (
        logging.CRITICAL,
        logging.ERROR,
        logging.WARNING,
        logging.INFO,
        logging.DEBUG,
    )
    for level in standard_levels:
        current_name = logging.getLevelName(level)
        logging.addLevelName(level, current_name.capitalize())
605
606
# Run main() only when executed as a script, and use its return value as the exit status.
if __name__ == '__main__':
    exit_status = main(sys.argv)
    sys.exit(exit_status)