Mercurial > repos > nick > duplex
comparison correct.py @ 18:e4d75f9efb90 draft
planemo upload commit b'4303231da9e48b2719b4429a29b72421d24310f4\n'-dirty
author | nick |
---|---|
date | Thu, 02 Feb 2017 18:44:31 -0500 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
17:836fa4fe9494 | 18:e4d75f9efb90 |
---|---|
1 #!/usr/bin/env python | |
2 from __future__ import division | |
3 from __future__ import print_function | |
4 import os | |
5 import sys | |
6 import gzip | |
7 import logging | |
8 import argparse | |
9 import resource | |
10 import subprocess | |
11 import networkx | |
12 import swalign | |
13 | |
# Custom log level halfway between DEBUG (10) and INFO (20); selected by -v/--verbose.
VERBOSE = (logging.DEBUG + logging.INFO) // 2

# Command-line defaults, installed with parser.set_defaults() in main().
# NOTE(review): 'qual' is never read anywhere in this file, and '--mapq' has no default here, so
# args.mapq is None unless given on the command line -- confirm whether 'qual' was meant to be
# 'mapq'.
ARG_DEFAULTS = {
    'sam': sys.stdin,
    'qual': 20,
    'pos': 2,
    'dist': 1,
    'choose_by': 'reads',
    'output': True,
    'visualize': 0,
    'viz_format': 'png',
    'log': sys.stderr,
    'volume': logging.WARNING,
}

# NOTE(review): USAGE is not passed to the ArgumentParser in main(); it appears unused.
USAGE = "%(prog)s [options]"

DESCRIPTION = """Correct barcodes using an alignment of all barcodes to themselves. Reads the
alignment in SAM format and corrects the barcodes in an input "families" file (the output of
make-barcodes.awk). It will print the "families" file to stdout with barcodes (and orders)
corrected."""
22 | |
23 | |
def _build_parser():
    """Build the command-line parser, with ARG_DEFAULTS installed as the option defaults."""
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.set_defaults(**ARG_DEFAULTS)
    parser.add_argument(
        'families',
        type=open_as_text_or_gzip,
        help='The sorted output of make-barcodes.awk. The important part is that it\'s a tab-delimited '
             'file with at least 2 columns: the barcode sequence and order, and it must be sorted in '
             'the same order as the "reads" in the SAM file.',
    )
    parser.add_argument(
        'reads',
        type=open_as_text_or_gzip,
        help='The fasta/q file given to the aligner. Used to get barcode sequences from read names.',
    )
    parser.add_argument(
        'sam',
        type=argparse.FileType('r'),
        nargs='?',
        help='Barcode alignment, in SAM format. Omit to read from stdin. The read names must be '
             'integers, representing the (1-based) order they appear in the families file.',
    )
    parser.add_argument(
        '-P', '--prepend',
        action='store_true',
        help='Prepend the corrected barcodes and orders to the original columns.',
    )
    parser.add_argument(
        '-d', '--dist',
        type=int,
        help='NM edit distance threshold. Default: %(default)s',
    )
    parser.add_argument(
        '-m', '--mapq',
        type=int,
        help='MAPQ threshold. Default: %(default)s',
    )
    parser.add_argument(
        '-p', '--pos',
        type=int,
        help='POS tolerance. Alignments will be ignored if abs(POS - 1) is greater than this value. '
             'Set to greater than the barcode length for no threshold. Default: %(default)s',
    )
    parser.add_argument(
        '-t', '--tag-len',
        type=int,
        help='Length of each half of the barcode. If not given, it will be determined from the first '
             'barcode in the families file.',
    )
    parser.add_argument('-c', '--choose-by', choices=('reads', 'connectivity'))
    parser.add_argument(
        '--limit',
        type=int,
        help='Limit the number of lines that will be read from each input file, for testing purposes.',
    )
    parser.add_argument(
        '-S', '--structures',
        action='store_true',
        help='Print a list of the unique isoforms',
    )
    parser.add_argument('--struct-human', action='store_true')
    parser.add_argument(
        '-V', '--visualize',
        nargs='?',
        help='Produce a visualization of the unique structures write the image to this file. '
             'If you omit a filename, it will be displayed in a window.',
    )
    parser.add_argument('-F', '--viz-format', choices=('dot', 'graphviz', 'png'))
    parser.add_argument('-n', '--no-output', dest='output', action='store_false')
    parser.add_argument(
        '-l', '--log',
        type=argparse.FileType('w'),
        help='Print log messages to this file instead of to stderr. Warning: Will overwrite the file.',
    )
    parser.add_argument('-q', '--quiet', dest='volume', action='store_const', const=logging.CRITICAL)
    parser.add_argument('-i', '--info', dest='volume', action='store_const', const=logging.INFO)
    parser.add_argument('-v', '--verbose', dest='volume', action='store_const', const=VERBOSE)
    parser.add_argument(
        '-D', '--debug',
        dest='volume',
        action='store_const',
        const=logging.DEBUG,
        help='Print debug messages (very verbose).',
    )
    return parser


def main(argv):
    """Correct the barcodes in a families file using an alignment of the barcodes to themselves.

    argv is the full argument vector; argv[0] (the program name) is skipped. Corrected families
    lines are printed to stdout; log messages go to --log (stderr by default)."""
    args = _build_parser().parse_args(argv[1:])

    logging.basicConfig(stream=args.log, level=args.volume, format='%(message)s')
    tone_down_logger()

    logging.info('Reading the fasta/q to map read names to barcodes..')
    barcodes_by_name = map_names_to_barcodes(args.reads, args.limit)

    logging.info('Reading the SAM to build the graph of barcode relationships..')
    barcode_graph, reversed_barcodes = read_alignments(
        args.sam, barcodes_by_name, args.pos, args.mapq, args.dist, args.limit
    )
    logging.info('{} reversed barcodes'.format(len(reversed_barcodes)))

    logging.info('Reading the families.tsv to get the counts of each family..')
    family_counts = get_family_counts(args.families, args.limit)

    if args.structures:
        logging.info('Counting the unique barcode networks..')
        structures = count_structures(barcode_graph, family_counts)
        print_structures(structures, args.struct_human)
        # ARG_DEFAULTS sets visualize to 0; a bare -V gives None (show in a window).
        if args.visualize != 0:
            logging.info('Generating a visualization of barcode networks..')
            structure_graphs = [structure['graph'] for structure in structures]
            visualize(structure_graphs, args.visualize, args.viz_format)

    logging.info('Building the correction table from the graph..')
    corrections = make_correction_table(barcode_graph, family_counts, args.choose_by)

    logging.info('Reading the families.tsv again to print corrected output..')
    # get_family_counts() consumed and closed the first handle, so reopen the file by name.
    families_again = open_as_text_or_gzip(args.families.name)
    print_corrected_output(families_again, corrections, reversed_barcodes, args.prepend,
                           args.limit, args.output)

    # Dividing by 1024 reports MB assuming ru_maxrss is in kilobytes (true on Linux).
    peak_mem = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
    logging.info('Max memory usage: {:0.2f}MB'.format(peak_mem))
103 | |
104 | |
def detect_format(reads_file, max_lines=7):
    """Guess whether `reads_file` holds FASTA or FASTQ records by sniffing its first lines.

    Only lines 1 and 3 of every 4-line group are examined. A leading '@' (line 1) or '+' (line 3)
    is a vote for FASTQ; a leading '>' on either is a vote for FASTA. Reading stops once
    `max_lines` lines have been examined (the first line is always examined). The file is then
    rewound with seek(0), so it must be seekable.
    Returns 'fasta' or 'fastq' for whichever got more votes, or None on a tie."""
    fastq_markers = {1: '@', 3: '+'}
    tally = {'fasta': 0, 'fastq': 0}
    for position, text in enumerate(reads_file, 1):
        marker = fastq_markers.get(position % 4)
        if marker is not None:
            if text.startswith(marker):
                tally['fastq'] += 1
            elif text.startswith('>'):
                tally['fasta'] += 1
        if position >= max_lines:
            break
    reads_file.seek(0)
    if tally['fasta'] == tally['fastq']:
        return None
    return 'fasta' if tally['fasta'] > tally['fastq'] else 'fastq'
131 | |
132 | |
def read_fastaq(reads_file):
    """Return a generator of (name, sequence) pairs read from a FASTA or FASTQ file.

    The format is taken from the file name's extension (.fa/.fasta or .fq/.fastq, optionally
    followed by .gz). Otherwise it is sniffed from the content with detect_format().
    Raises ValueError if the format cannot be determined."""
    filename = reads_file.name
    # Look past a compression suffix so e.g. "barcodes.fq.gz" is recognised by name.
    if filename.endswith('.gz'):
        filename = filename[:-len('.gz')]
    if filename.endswith(('.fa', '.fasta')):
        file_format = 'fasta'
    elif filename.endswith(('.fq', '.fastq')):
        file_format = 'fastq'
    else:
        file_format = detect_format(reads_file)
    if file_format == 'fasta':
        return read_fasta(reads_file)
    elif file_format == 'fastq':
        return read_fastq(reads_file)
    raise ValueError('Could not determine whether reads file {!r} is FASTA or FASTQ.'
                     .format(reads_file.name))
145 | |
146 | |
def read_fasta(reads_file):
    """Yield (name, sequence) pairs from a FASTA file laid out as strict two-line records.

    Odd lines must be '>' headers (checked with an assert); the name is the header minus the '>'.
    Each even line is taken whole as the sequence. Line endings are stripped.
    NOTE: Multi-line sequences are not supported. A trailing header with no sequence line
    yields nothing."""
    name = None
    for index, raw in enumerate(reads_file):
        text = raw.rstrip('\r\n')
        if index % 2 == 0:
            assert text.startswith('>'), text
            name = text[1:]
        else:
            yield name, text
160 | |
161 | |
def read_fastq(reads_file):
    """Yield (name, sequence) pairs from a FASTQ file laid out as strict four-line records.

    The first line of each record must start with '@' (checked with an assert); the name is the
    rest of that line. The second line is the sequence. The '+' and quality lines are skipped.
    Line endings are stripped.
    NOTE: Multi-line sequences are not supported."""
    name = None
    for index, raw in enumerate(reads_file):
        slot = index % 4
        if slot == 0:
            assert raw.startswith('@'), raw
            name = raw[1:].rstrip('\r\n')
        elif slot == 1:
            yield name, raw.rstrip('\r\n')
174 | |
175 | |
def map_names_to_barcodes(reads_file, limit=None):
    """Map integer read names to their barcode sequences.

    Reads (name, sequence) pairs from reads_file via read_fastaq(), stopping after `limit` reads
    if given. Closes reads_file afterward, including when an error is raised.
    Returns a dict {int(read name): sequence}.
    Raises ValueError (after logging it) if a read name is not an integer."""
    names_to_barcodes = {}
    try:
        for read_num, (read_name, read_seq) in enumerate(read_fastaq(reads_file), 1):
            if limit is not None and read_num > limit:
                break
            try:
                name = int(read_name)
            except ValueError:
                # Report the raw name: `name` was not (re)assigned because int() failed.
                logging.critical('non-int read name "{}"'.format(read_name))
                raise
            names_to_barcodes[name] = read_seq
    finally:
        reads_file.close()
    return names_to_barcodes
192 | |
193 | |
def parse_alignment(sam_file, pos_thres, mapq_thres, dist_thres):
    """Parse a SAM file and yield the alignments that pass all filters.

    Yields (qname, rname, reversed) tuples. QNAME and RNAME are converted to ints. reversed is True
    when RNAME carried a ":rev" suffix, which is stripped before conversion.

    Skipped lines and alignments:
      - header lines (starting with "@") and blank lines;
      - alignments whose reference is "*";
      - non-integer FLAG/POS/MAPQ (logged as a warning);
      - unmapped reads (FLAG & 4);
      - abs(POS - 1) > pos_thres;
      - MAPQ < mapq_thres (this filter is skipped when mapq_thres is None);
      - NM > dist_thres.

    Raises ValueError (after logging it) for a non-integer read name or NM tag. Raises
    AssertionError if an alignment that reached the NM check has no NM tag. sam_file is closed
    when iteration finishes or the generator is closed."""
    try:
        for line_num, line in enumerate(sam_file, 1):
            if line.startswith('@'):
                logging.debug('Header line (%s)', line_num)
                continue
            if not line.strip():
                continue
            fields = line.rstrip('\r\n').split('\t')
            qname_str = fields[0]
            rname_str = fields[2]
            # Lazy %-style args: this runs for every SAM line, even when debug logging is off.
            logging.debug('read %s -> ref %s (read seq %s):', qname_str, rname_str, fields[9])
            rname_fields = rname_str.split(':')
            if len(rname_fields) == 2 and rname_fields[1] == 'rev':
                is_reversed = True
                rname_str = rname_fields[0]
            else:
                is_reversed = False
            try:
                qname = int(qname_str)
                rname = int(rname_str)
            except ValueError:
                if fields[2] == '*':
                    logging.debug('\tRead unmapped (reference == "*")')
                    continue
                # Log the raw strings: qname/rname may be unbound if int() failed.
                logging.error('Non-integer read name(s) on line {}: "{}", "{}".'
                              .format(line_num, qname_str, fields[2]))
                raise
            # Apply alignment quality filters.
            try:
                flags = int(fields[1])
                pos = int(fields[3])
                mapq = int(fields[4])
            except ValueError:
                logging.warning('\tNon-integer flag ({}), pos ({}), or mapq ({})'
                                .format(fields[1], fields[3], fields[4]))
                continue
            if flags & 4:
                logging.debug('\tRead unmapped (flag & 4 == True)')
                continue
            if abs(pos - 1) > pos_thres:
                logging.debug('\tAlignment failed pos filter: abs(%s - 1) > %s', pos, pos_thres)
                continue
            if mapq_thres is not None and mapq < mapq_thres:
                logging.debug('\tAlignment failed mapq filter: %s < %s', mapq, mapq_thres)
                continue
            nm = None
            for tag in fields[11:]:
                if tag.startswith('NM:i:'):
                    try:
                        nm = int(tag[5:])
                    except ValueError:
                        logging.error('Invalid NM tag "{}" on line {}.'.format(tag, line_num))
                        raise
                    break
            assert nm is not None, line_num
            if nm > dist_thres:
                logging.debug('\tAlignment failed NM distance filter: %s > %s', nm, dist_thres)
                continue
            yield qname, rname, is_reversed
    finally:
        sam_file.close()
257 | |
258 | |
def read_alignments(sam_file, names_to_barcodes, pos_thres, mapq_thres, dist_thres, limit=None):
    """Build a graph of barcode sequences linked by passing SAM alignments.

    Each alignment yielded by parse_alignment() (i.e. one that passed the pos/mapq/NM filters)
    adds an edge between the two barcode sequences, looked up by read name in names_to_barcodes.
    Alignments of a read to itself are skipped. `limit`, if given, caps the number of passing
    alignments considered (self-alignments included), not the number of SAM lines read.

    Returns (graph, reversed_barcodes):
      - graph: a networkx.Graph whose nodes are barcode sequences;
      - reversed_barcodes: the set of barcode sequences involved, as query or reference, in any
        alignment to a ":rev" reference.

    Raises KeyError if an aligned read name is missing from names_to_barcodes."""
    graph = networkx.Graph()
    reversed_barcodes = set()
    alignments = parse_alignment(sam_file, pos_thres, mapq_thres, dist_thres)
    for align_num, (qname, rname, is_reversed) in enumerate(alignments, 1):
        if limit is not None and align_num > limit:
            break
        # Skip self-alignments.
        if rname == qname:
            continue
        rseq = names_to_barcodes[rname]
        qseq = names_to_barcodes[qname]
        # Is this an alignment to a reversed barcode?
        if is_reversed:
            reversed_barcodes.add(rseq)
            reversed_barcodes.add(qseq)
        # add_edge() also adds both endpoints as nodes, in this same order.
        graph.add_edge(rseq, qseq)
    return graph, reversed_barcodes
286 | |
287 | |
def get_family_counts(families_file, limit=None):
    """Count read pairs per family (barcode) and per strand (order).

    Reads tab-delimited lines whose first two columns are the barcode and the order ('ab' or
    'ba'), stopping after `limit` lines if given. Closes families_file afterward, including when
    an error is raised. Lines for the same barcode are combined even if they are not adjacent.

    Returns {barcode: {'ab': n_ab, 'ba': n_ba, 'all': n_ab + n_ba}}; an empty input gives {}.
    Raises KeyError for an order value other than 'ab' or 'ba'."""
    family_counts = {}
    try:
        for line_num, line in enumerate(families_file, 1):
            if limit is not None and line_num > limit:
                break
            fields = line.rstrip('\r\n').split('\t')
            barcode = fields[0]
            order = fields[1]
            counts = family_counts.get(barcode)
            if counts is None:
                counts = family_counts[barcode] = {'ab':0, 'ba':0}
            counts[order] += 1
    finally:
        families_file.close()
    # Totals are filled in once every line has been counted.
    for counts in family_counts.values():
        counts['all'] = counts['ab'] + counts['ba']
    return family_counts
312 | |
313 | |
def make_correction_table(meta_graph, family_counts, choose_by='reads'):
    """Make a table mapping original barcode sequences to the barcodes they should become.

    In every connected component of meta_graph one barcode is chosen as correct, and every other
    barcode in that component maps to it:
      - choose_by='reads': the barcode with the most read pairs (family_counts[barcode]['all']);
      - choose_by='connectivity': the barcode with the highest degree.
    Ties are broken by taking the lexicographically smallest barcode, so the result does not
    depend on graph iteration order.

    Returns a dict {original barcode: correct barcode}; correct barcodes are not keys.
    Raises ValueError for an unknown choose_by, and KeyError (choose_by='reads') if a barcode has
    no entry in family_counts."""
    if choose_by == 'reads':
        def score(barcode):
            return family_counts[barcode]['all']
    elif choose_by == 'connectivity':
        # A node's degree in the whole graph equals its degree within its own connected
        # component, so one table serves every component. dict() accepts both the networkx 1.x
        # dict and the 2.x DegreeView returned by degree().
        degrees = dict(meta_graph.degree())

        def score(barcode):
            return degrees[barcode]
    else:
        raise ValueError('Invalid choose_by value {!r}: expected "reads" or "connectivity".'
                         .format(choose_by))
    corrections = {}
    # connected_components() is available in every networkx version, unlike
    # connected_component_subgraphs() (removed in 2.4); only the node sets are needed here.
    for component in networkx.connected_components(meta_graph):
        correct = min(component, key=lambda barcode: (-score(barcode), barcode))
        for barcode in component:
            if barcode != correct:
                logging.debug('Correcting {} ->\n {}\n'.format(barcode, correct))
                corrections[barcode] = correct
    return corrections
333 | |
334 | |
def _finish_family(barcode, reads, corrections_in_family, corrected):
    """Add one completed family's corrections to the `corrected` tallies and log its summary."""
    family_info = '{}\t{}\t{}'.format(barcode, reads[0], reads[1])
    if corrections_in_family:
        corrected['reads'] += corrections_in_family
        corrected['barcodes'] += 1
        family_info += '\tCORRECTED!'
    else:
        family_info += '\tuncorrected'
    logging.log(VERBOSE, family_info)


def print_corrected_output(families_file, corrections, reversed_barcodes, prepend=False, limit=None,
                           output=True):
    """Print the families file with its barcodes (and, where needed, orders) corrected.

    Each tab-delimited line starts with (barcode, order, ...). Reading stops after `limit` lines
    if given. For each line:
      - If the barcode is a key of `corrections`, the corrected barcode replaces it (or, with
        prepend=True, is inserted as a new first column).
      - The order column is swapped ('ab' -> 'ba', anything else -> 'ab') when either barcode is in
        reversed_barcodes and is_alignment_reversed() reports the pair as swapped.
      - Uncorrected lines keep their own barcode (prepend=True still inserts it as a new column).
    Lines are printed tab-joined to stdout only when `output` is true. families_file is closed
    afterward.

    Consecutive lines sharing a barcode form a family. Each family is summarized at VERBOSE
    level, and overall totals are logged at INFO level."""
    corrected = {'reads':0, 'barcodes':0, 'reversed':0}
    barcode_last = None
    reads = [0, 0]
    corrections_in_this_family = 0
    try:
        for line_num, line in enumerate(families_file, 1):
            if limit is not None and line_num > limit:
                break
            fields = line.rstrip('\r\n').split('\t')
            raw_barcode = fields[0]
            order = fields[1]
            if raw_barcode != barcode_last:
                # A new family starts here, so the previous one (if there was one) is complete.
                if barcode_last is not None:
                    _finish_family(barcode_last, reads, corrections_in_this_family, corrected)
                reads = [0, 0]
                corrections_in_this_family = 0
                barcode_last = raw_barcode
            if order == 'ab':
                reads[0] += 1
            elif order == 'ba':
                reads[1] += 1
            if raw_barcode in corrections:
                correct_barcode = corrections[raw_barcode]
                corrections_in_this_family += 1
                # Check if the order of the barcode reverses in the correct version. The cheap
                # reversed_barcodes membership test runs first because is_alignment_reversed()
                # performs two swalign.smith_waterman() alignments.
                if ((raw_barcode in reversed_barcodes or correct_barcode in reversed_barcodes) and
                        is_alignment_reversed(raw_barcode, correct_barcode)):
                    corrected['reversed'] += 1
                    fields[1] = 'ba' if order == 'ab' else 'ab'
            else:
                correct_barcode = raw_barcode
            if prepend:
                fields.insert(0, correct_barcode)
            else:
                fields[0] = correct_barcode
            if output:
                print(*fields, sep='\t')
    finally:
        families_file.close()
    # The last family has no following line to trigger its summary, so finish it here.
    if barcode_last is not None:
        _finish_family(barcode_last, reads, corrections_in_this_family, corrected)
    logging.info('Corrected {barcodes} barcodes on {reads} read pairs, with {reversed} reversed.'
                 .format(**corrected))
396 | |
397 | |
def is_alignment_reversed(barcode1, barcode2):
    """Decide whether barcode2's two halves appear swapped relative to barcode1.

    barcode2 is split at len(barcode2) // 2 and its halves are exchanged. barcode1 is then aligned
    with swalign.smith_waterman() against both the original barcode2 and the swapped version.
    Returns True only if the swapped version scores strictly higher; otherwise False."""
    midpoint = len(barcode2) // 2
    swapped = barcode2[midpoint:] + barcode2[:midpoint]
    forward_score = swalign.smith_waterman(barcode1, barcode2).score
    swapped_score = swalign.smith_waterman(barcode1, swapped).score
    return swapped_score > forward_score
412 | |
413 | |
def count_structures(meta_graph, family_counts):
    """Group the connected components of meta_graph into isomorphism classes and count them.

    Returns a list with one dict per class, in order of first appearance:
      'graph':   the first component seen with that shape (an independent copy);
      'size':    its number of nodes;
      'count':   how many components fall in the class;
      'central': how many of those components is_centralized() judged centralized."""
    structures = []
    # subgraph(...).copy() over connected_components() matches what
    # connected_component_subgraphs() yielded, but also works on networkx >= 2.4, which removed it.
    for nodes in networkx.connected_components(meta_graph):
        graph = meta_graph.subgraph(nodes).copy()
        central = int(is_centralized(graph, family_counts))
        for structure in structures:
            if networkx.is_isomorphic(graph, structure['graph']):
                structure['count'] += 1
                structure['central'] += central
                break
        else:
            structures.append({'graph':graph, 'size':len(graph), 'count':1, 'central':central})
    return structures
431 | |
432 | |
def is_centralized(graph, family_counts):
    """Report whether a barcode network keeps its extra read pairs on one central node.

    The graph counts as centralized when only its highest-degree node (may) have more than one
    read pair (family_counts[barcode]['all'] > 1).
    Two-node graphs are special-cased: both nodes have degree 1, so degree cannot rank them.
    Instead such a graph is centralized when at least one node has exactly one read pair and the
    other has at least one. Returns True or False."""
    if len(graph) == 2:
        barcode1, barcode2 = graph.nodes()
        counts1 = family_counts[barcode1]
        counts2 = family_counts[barcode2]
        total1 = counts1['all']
        total2 = counts2['all']
        logging.debug('{}: {:3d} ({}/{})\n{}: {:3d} ({}/{})\n'
                      .format(barcode1, total1, counts1['ab'], counts1['ba'],
                              barcode2, total2, counts2['ab'], counts2['ba']))
        return (total1 >= 1 and total2 == 1) or (total1 == 1 and total2 >= 1)
    degrees = graph.degree()
    ranked = sorted(graph.nodes(), key=lambda node: degrees[node], reverse=True)
    # Every node except the top-ranked one must have at most one read pair.
    for barcode in ranked[1:]:
        counts = family_counts[barcode]
        try:
            too_many = counts['all'] > 1
        except TypeError:
            logging.critical('barcode: {}, counts: {}'.format(barcode, counts))
            raise
        if too_many:
            return False
    return True
469 | |
470 | |
def print_structures(structures, human=True):
    """Print one line per structure, sorted by ascending size, then descending count.

    Within each size the structures are labelled '', 'A', 'B', ... (via num_to_letters()). Each
    line gives the size and label, the count, the number centralized, and the node degrees in
    descending order. human=True prints aligned, space-separated columns. Otherwise the fields
    are tab-separated and the degrees comma-joined."""
    if not structures:
        return
    # Pad count/central to the widest count so the human-readable columns line up.
    width = str(max(len(str(structure['count'])) for structure in structures))
    # A key function replaces sorted(cmp=...), which Python 3 does not accept; the order is the
    # same: ascending size, then descending count.
    ordered = sorted(structures, key=lambda structure: (structure['size'], -structure['count']))
    last_size = None
    index = 0
    for structure in ordered:
        size = structure['size']
        index = index + 1 if size == last_size else 0
        letters = num_to_letters(index)
        # dict() accepts both the networkx 1.x dict and the 2.x DegreeView.
        degrees = sorted(dict(structure['graph'].degree()).values(), reverse=True)
        if human:
            degrees_str = ' '.join(map(str, degrees))
            format_str = '{:2d}{:<3s} {count:<'+width+'d} {central:<'+width+'d} {}'
            print(format_str.format(size, letters+':', degrees_str, **structure))
        else:
            degrees_str = ','.join(map(str, degrees))
            print(size, letters, structure['count'], structure['central'], degrees_str, sep='\t')
        last_size = size
502 | |
503 | |
def num_to_letters(i):
    """Convert a positive integer to bijective base-26 letters: 1 -> A, 26 -> Z, 27 -> AA, 100 -> CV.
    Zero or negative input gives the empty string."""
    digits = []
    remaining = i
    while remaining > 0:
        remaining, offset = divmod(remaining - 1, 26)
        digits.append(chr(ord('A') + offset))
    return ''.join(reversed(digits))
514 | |
515 | |
def visualize(graphs, viz_path, args_viz_format):
    """Draw the given graphs together as one combined graph.

    The output format comes from viz_path's extension ('.dot' -> graphviz, '.png' -> png). For
    any other extension, or when viz_path is empty/None, args_viz_format is used; 'dot' is treated
    like 'graphviz'.
      - graphviz: the graph is written to <base>.dot and rendered to <base>.png with the external
        `dot` command.
      - png: the drawing is saved via matplotlib to viz_path, or shown in a window when viz_path
        is empty/None.
    Raises ValueError if the graphviz format is requested without a path."""
    import matplotlib.pyplot
    from networkx.drawing.nx_agraph import graphviz_layout
    meta_graph = networkx.Graph()
    for graph in graphs:
        add_graph(meta_graph, graph)
    pos = graphviz_layout(meta_graph)
    networkx.draw(meta_graph, pos)
    # Start from the command-line choice so viz_format is always bound, even without a path.
    viz_format = args_viz_format
    if viz_path:
        ext = os.path.splitext(viz_path)[1]
        if ext == '.dot':
            viz_format = 'graphviz'
        elif ext == '.png':
            viz_format = 'png'
    if viz_format in ('graphviz', 'dot'):
        if not viz_path:
            raise ValueError('Must provide a filename to --visualize if using --viz-format "graphviz".')
        from networkx.drawing.nx_pydot import write_dot
        base_path = os.path.splitext(viz_path)[0]
        dot_path = base_path + '.dot'
        png_path = base_path + '.png'
        write_dot(meta_graph, dot_path)
        exit_status = run_command('dot', '-T', 'png', '-o', png_path, dot_path)
        logging.info('Wrote image of graph to '+dot_path)
        if exit_status != 0:
            logging.warning('Rendering {} with the "dot" command failed (exit status {}).'
                            .format(dot_path, exit_status))
    elif viz_format == 'png':
        if viz_path:
            matplotlib.pyplot.savefig(viz_path)
        else:
            matplotlib.pyplot.show()
544 | |
545 | |
def add_graph(graph, subgraph):
    """Copy every node and edge of `subgraph` into `graph`, modifying `graph` in place.

    Only the bare nodes and (u, v) edges returned by nodes()/edges() are copied; attribute data
    is not. Returns `graph` for convenience."""
    graph.add_nodes_from(subgraph.nodes())
    graph.add_edges_from(subgraph.edges())
    return graph
553 | |
554 | |
def open_as_text_or_gzip(path):
    """Open `path` for reading text lines, decompressing it if it looks like gzip.

    The decision is made by detect_gzip(). Returns an open file-like object that yields str lines
    on both Python 2 and Python 3."""
    python2 = sys.version_info[0] < 3
    if detect_gzip(path):
        # Python 3's gzip.open() yields bytes unless text mode ('rt') is requested.
        return gzip.open(path, 'r' if python2 else 'rt')
    else:
        # 'U' (universal newlines) is Python 2 only: Python 3 text mode already does this, and
        # Python 3.11+ rejects 'U' as an invalid mode.
        return open(path, 'rU' if python2 else 'r')
562 | |
563 | |
def detect_gzip(path):
    """Return True if the file looks like a gzip file, False otherwise.

    A .gz extension means gzip, and a .txt/.tsv/.csv extension means plain text. For any other
    extension, the first 100 bytes are read and the file counts as gzip if any byte is non-ASCII
    (per detect_non_ascii())."""
    ext = os.path.splitext(path)[1]
    if ext == '.gz':
        return True
    elif ext in ('.txt', '.tsv', '.csv'):
        return False
    # Read raw bytes: decoding compressed data as text can fail on Python 3. latin-1 maps each
    # byte to the character with the same code point, so the non-ASCII test is unaffected.
    with open(path, 'rb') as fh:
        head = fh.read(100).decode('latin-1')
    return bool(detect_non_ascii(head))
575 | |
576 | |
def detect_non_ascii(bytes, max_test=100):
    """Return True if any of the first `max_test` items of `bytes` is non-ASCII (code point > 127).
    Return False otherwise, including for empty input.

    Accepts a byte string (Python 2 str, or Python 3 bytes, whose items are ints) or a text string.
    The parameter keeps its original name, which shadows the builtin, so keyword callers still
    work."""
    for i, char in enumerate(bytes):
        # Stop before examining item number max_test, so exactly max_test items are checked.
        if i >= max_test:
            break
        code = char if isinstance(char, int) else ord(char)
        # "> 127" rather than "& 128" so text code points above 255 also count as non-ASCII.
        if code > 127:
            return True
    return False
587 | |
588 | |
def run_command(*command):
    """Run `command` (the program followed by its arguments, no shell) and wait for it to finish.

    Returns the process's exit status, or None if the program could not be started at all
    (OSError, e.g. the executable was not found)."""
    # subprocess.call() reports failures through its return value and never raises
    # CalledProcessError (only check_call() does), so OSError is the only error to handle.
    try:
        return subprocess.call(command)
    except OSError:
        return None
597 | |
598 | |
def tone_down_logger():
    """Rename the five standard logging levels from ALL-CAPS to Capitalized form.

    For example "WARNING" becomes "Warning", which makes log files read a little less shouty.
    This changes the global level-name registry of the logging module."""
    standard_levels = [logging.CRITICAL, logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG]
    renamed = [(level, logging.getLevelName(level).capitalize()) for level in standard_levels]
    for level, pretty_name in renamed:
        logging.addLevelName(level, pretty_name)
605 | |
606 | |
if __name__ == '__main__':
    # main() has no return statement, so it returns None, which sys.exit() treats as status 0.
    exit_status = main(sys.argv)
    sys.exit(exit_status)