diff correct.py @ 18:e4d75f9efb90 draft

planemo upload commit 4303231da9e48b2719b4429a29b72421d24310f4-dirty
author nick
date Thu, 02 Feb 2017 18:44:31 -0500
parents
children
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/correct.py	Thu Feb 02 18:44:31 2017 -0500
@@ -0,0 +1,608 @@
+#!/usr/bin/env python
+from __future__ import division
+from __future__ import print_function
+import os
+import sys
+import gzip
+import logging
+import argparse
+import resource
+import subprocess
+import networkx
+import swalign
+
+VERBOSE = (logging.DEBUG+logging.INFO)//2
+ARG_DEFAULTS = {'sam':sys.stdin, 'mapq':20, 'pos':2, 'dist':1, 'choose_by':'reads', 'output':True,
+                'visualize':0, 'viz_format':'png', 'log':sys.stderr, 'volume':logging.WARNING}
+USAGE = "%(prog)s [options]"
+DESCRIPTION = """Correct barcodes using an alignment of all barcodes to themselves. Reads the
+alignment in SAM format and corrects the barcodes in an input "families" file (the output of
+make-barcodes.awk). It will print the "families" file to stdout with barcodes (and orders)
+corrected."""
+
+
+def main(argv):
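+  """Parse arguments, read the inputs, and print the corrected families file to stdout."""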
+
+  parser = argparse.ArgumentParser(description=DESCRIPTION)
+  parser.set_defaults(**ARG_DEFAULTS)
+
+  parser.add_argument('families', type=open_as_text_or_gzip,
+    help='The sorted output of make-barcodes.awk. The important part is that it\'s a tab-delimited '
+         'file with at least 2 columns: the barcode sequence and order, and it must be sorted in '
+         'the same order as the "reads" in the SAM file.')
+  parser.add_argument('reads', type=open_as_text_or_gzip,
+    help='The fasta/q file given to the aligner. Used to get barcode sequences from read names.')
+  parser.add_argument('sam', type=argparse.FileType('r'), nargs='?',
+    help='Barcode alignment, in SAM format. Omit to read from stdin. The read names must be '
+         'integers, representing the (1-based) order they appear in the families file.')
+  parser.add_argument('-P', '--prepend', action='store_true',
+    help='Prepend the corrected barcodes and orders to the original columns.')
+  parser.add_argument('-d', '--dist', type=int,
+    help='NM edit distance threshold. Default: %(default)s')
+  parser.add_argument('-m', '--mapq', type=int,
+    help='MAPQ threshold. Default: %(default)s')
+  parser.add_argument('-p', '--pos', type=int,
+    help='POS tolerance. Alignments will be ignored if abs(POS - 1) is greater than this value. '
+         'Set to greater than the barcode length for no threshold. Default: %(default)s')
+  parser.add_argument('-t', '--tag-len', type=int,
+    help='Length of each half of the barcode. If not given, it will be determined from the first '
+         'barcode in the families file.')
+  parser.add_argument('-c', '--choose-by', choices=('reads', 'connectivity'),
+    help='Choose the correct barcode by read pair count or by graph connectivity. Default: %(default)s')
+  parser.add_argument('--limit', type=int,
+    help='Limit the number of lines that will be read from each input file, for testing purposes.')
+  parser.add_argument('-S', '--structures', action='store_true',
+    help='Print a tally of the unique barcode network structures.')
+  parser.add_argument('--struct-human', action='store_true')
+  parser.add_argument('-V', '--visualize', nargs='?',
+    help='Produce a visualization of the unique structures and write the image to this file. '
+         'If you omit a filename, it will be displayed in a window.')
+  parser.add_argument('-F', '--viz-format', choices=('dot', 'graphviz', 'png'))
+  parser.add_argument('-n', '--no-output', dest='output', action='store_false')
+  parser.add_argument('-l', '--log', type=argparse.FileType('w'),
+    help='Print log messages to this file instead of to stderr. Warning: Will overwrite the file.')
+  parser.add_argument('-q', '--quiet', dest='volume', action='store_const', const=logging.CRITICAL)
+  parser.add_argument('-i', '--info', dest='volume', action='store_const', const=logging.INFO)
+  parser.add_argument('-v', '--verbose', dest='volume', action='store_const', const=VERBOSE)
+  parser.add_argument('-D', '--debug', dest='volume', action='store_const', const=logging.DEBUG,
+    help='Print debug messages (very verbose).')
+
+  args = parser.parse_args(argv[1:])
+
+  logging.basicConfig(stream=args.log, level=args.volume, format='%(message)s')
+  tone_down_logger()
+
+  logging.info('Reading the fasta/q to map read names to barcodes..')
+  names_to_barcodes = map_names_to_barcodes(args.reads, args.limit)
+
+  logging.info('Reading the SAM to build the graph of barcode relationships..')
+  graph, reversed_barcodes = read_alignments(args.sam, names_to_barcodes, args.pos, args.mapq,
+                                             args.dist, args.limit)
+  logging.info('{} reversed barcodes'.format(len(reversed_barcodes)))
+
+  logging.info('Reading the families.tsv to get the counts of each family..')
+  family_counts = get_family_counts(args.families, args.limit)
+
+  if args.structures:
+    logging.info('Counting the unique barcode networks..')
+    structures = count_structures(graph, family_counts)
+    print_structures(structures, args.struct_human)
+    if args.visualize != 0:
+      logging.info('Generating a visualization of barcode networks..')
+      visualize([s['graph'] for s in structures], args.visualize, args.viz_format)
+
+  logging.info('Building the correction table from the graph..')
+  corrections = make_correction_table(graph, family_counts, args.choose_by)
+
+  logging.info('Reading the families.tsv again to print corrected output..')
+  families = open_as_text_or_gzip(args.families.name)
+  print_corrected_output(families, corrections, reversed_barcodes, args.prepend, args.limit,
+                         args.output)
+
+  max_mem = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024
+  logging.info('Max memory usage: {:0.2f}MB'.format(max_mem))
+
+
+def detect_format(reads_file, max_lines=7):
+  """Detect whether a file is a fastq or a fasta, based on its content."""
+  fasta_votes = 0
+  fastq_votes = 0
+  line_num = 0
+  for line in reads_file:
+    line_num += 1
+    if line_num % 4 == 1:
+      if line.startswith('@'):
+        fastq_votes += 1
+      elif line.startswith('>'):
+        fasta_votes += 1
+    elif line_num % 4 == 3:
+      if line.startswith('+'):
+        fastq_votes += 1
+      elif line.startswith('>'):
+        fasta_votes += 1
+    if line_num >= max_lines:
+      break
+  reads_file.seek(0)
+  if fasta_votes > fastq_votes:
+    return 'fasta'
+  elif fastq_votes > fasta_votes:
+    return 'fastq'
+  else:
+    return None
+
+
+def read_fastaq(reads_file):
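+  """Read a FASTA or FASTQ file, returning a generator of (read name, sequence) pairs.
+  The format is determined from the filename extension, or from the file content if the extension
+  is unrecognized."""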
+  filename = reads_file.name
+  if filename.endswith('.fa') or filename.endswith('.fasta'):
+    format = 'fasta'
+  elif filename.endswith('.fq') or filename.endswith('.fastq'):
+    format = 'fastq'
+  else:
+    format = detect_format(reads_file)
+  if format == 'fasta':
+    return read_fasta(reads_file)
+  elif format == 'fastq':
+    return read_fastq(reads_file)
+  else:
+    raise ValueError('Could not determine the format of reads file "{}".'.format(filename))
+
+
+def read_fasta(reads_file):
+  """Read a FASTA file, yielding read names and sequences.
+  NOTE: This assumes sequences are only one line!"""
+  line_num = 0
+  for line_raw in reads_file:
+    line = line_raw.rstrip('\r\n')
+    line_num += 1
+    if line_num % 2 == 1:
+      assert line.startswith('>'), line
+      read_name = line[1:]
+    elif line_num % 2 == 0:
+      read_seq = line
+      yield read_name, read_seq
+
+
+def read_fastq(reads_file):
+  """Read a FASTQ file, yielding read names and sequences.
+  NOTE: This assumes sequences are only one line!"""
+  line_num = 0
+  for line in reads_file:
+    line_num += 1
+    if line_num % 4 == 1:
+      assert line.startswith('@'), line
+      read_name = line[1:].rstrip('\r\n')
+    elif line_num % 4 == 2:
+      read_seq = line.rstrip('\r\n')
+      yield read_name, read_seq
+
+
+def map_names_to_barcodes(reads_file, limit=None):
+  """Map barcode names to their sequences."""
+  names_to_barcodes = {}
+  read_num = 0
+  for read_name, read_seq in read_fastaq(reads_file):
+    read_num += 1
+    if limit is not None and read_num > limit:
+      break
+    try:
+      name = int(read_name)
+    except ValueError:
+      logging.critical('non-int read name "{}"'.format(read_name))
+      raise
+    names_to_barcodes[name] = read_seq
+  reads_file.close()
+  return names_to_barcodes
+
+
+def parse_alignment(sam_file, pos_thres, mapq_thres, dist_thres):
+  """Parse the SAM file and yield reads that pass the filters.
+  Returns (qname, rname, reversed)."""
+  line_num = 0
+  for line in sam_file:
+    line_num += 1
+    if line.startswith('@'):
+      logging.debug('Header line ({})'.format(line_num))
+      continue
+    fields = line.split('\t')
+    logging.debug('read {} -> ref {} (read seq {}):'.format(fields[2], fields[0], fields[9]))
+    qname_str = fields[0]
+    rname_str = fields[2]
+    rname_fields = rname_str.split(':')
+    if len(rname_fields) == 2 and rname_fields[1] == 'rev':
+      reversed = True
+      rname_str = rname_fields[0]
+    else:
+      reversed = False
+    try:
+      qname = int(qname_str)
+      rname = int(rname_str)
+    except ValueError:
+      if fields[2] == '*':
+        logging.debug('\tRead unmapped (reference == "*")')
+        continue
+      else:
+        logging.error('Non-integer read name(s) on line {}: "{}", "{}".'
+                      .format(line_num, qname_str, rname_str))
+        raise
+    # Apply alignment quality filters.
+    try:
+      flags = int(fields[1])
+      pos = int(fields[3])
+      mapq = int(fields[4])
+    except ValueError:
+      logging.warning('\tNon-integer flag ({}), pos ({}), or mapq ({})'
+                      .format(fields[1], fields[3], fields[4]))
+      continue
+    if flags & 4:
+      logging.debug('\tRead unmapped (flag & 4 == True)')
+      continue
+    if abs(pos - 1) > pos_thres:
+      logging.debug('\tAlignment failed pos filter: abs({} - 1) > {}'.format(pos, pos_thres))
+      continue
+    if mapq < mapq_thres:
+      logging.debug('\tAlignment failed mapq filter: {} < {}'.format(mapq, mapq_thres))
+      continue
+    nm = None
+    for tag in fields[11:]:
+      if tag.startswith('NM:i:'):
+        try:
+          nm = int(tag[5:])
+        except ValueError:
+          logging.error('Invalid NM tag "{}" on line {}.'.format(tag, line_num))
+          raise
+        break
+    assert nm is not None, line_num
+    if nm > dist_thres:
+      logging.debug('\tAlignment failed NM distance filter: {} > {}'.format(nm, dist_thres))
+      continue
+    yield qname, rname, reversed
+  sam_file.close()
+
+
+def read_alignments(sam_file, names_to_barcodes, pos_thres, mapq_thres, dist_thres, limit=None):
+  """Read the alignments from the SAM file.
+  Returns a dict mapping each reference sequence (RNAME) to sets of sequences (QNAMEs) that align to
+  it."""
+  graph = networkx.Graph()
+  # This is the set of all barcodes which are involved in an alignment where the target is reversed.
+  # Whether it's a query or reference sequence in the alignment, it's marked here.
+  reversed_barcodes = set()
+  line_num = 0
+  for qname, rname, reversed in parse_alignment(sam_file, pos_thres, mapq_thres, dist_thres):
+    line_num += 1
+    if limit is not None and line_num > limit:
+      break
+    # Skip self-alignments.
+    if rname == qname:
+      continue
+    rseq = names_to_barcodes[rname]
+    qseq = names_to_barcodes[qname]
+    # Is this an alignment to a reversed barcode?
+    if reversed:
+      reversed_barcodes.add(rseq)
+      reversed_barcodes.add(qseq)
+    graph.add_node(rseq)
+    graph.add_node(qseq)
+    graph.add_edge(rseq, qseq)
+  return graph, reversed_barcodes
+
+
+def get_family_counts(families_file, limit=None):
+  """For each family (barcode), count how many read pairs exist for each strand (order)."""
+  family_counts = {}
+  last_barcode = None
+  this_family_counts = None
+  line_num = 0
+  for line in families_file:
+    line_num += 1
+    if limit is not None and line_num > limit:
+      break
+    fields = line.rstrip('\r\n').split('\t')
+    barcode = fields[0]
+    order = fields[1]
+    if barcode != last_barcode:
+      if this_family_counts:
+        this_family_counts['all'] = this_family_counts['ab'] + this_family_counts['ba']
+      family_counts[last_barcode] = this_family_counts
+      this_family_counts = {'ab':0, 'ba':0}
+      last_barcode = barcode
+    this_family_counts[order] += 1
+  this_family_counts['all'] = this_family_counts['ab'] + this_family_counts['ba']
+  family_counts[last_barcode] = this_family_counts
+  families_file.close()
+  return family_counts
+
+
+def make_correction_table(meta_graph, family_counts, choose_by='reads'):
+  """Make a table mapping original barcode sequences to correct barcodes.
+  Assumes the most connected node in the graph as the correct barcode."""
+  corrections = {}
+  for graph in networkx.connected_component_subgraphs(meta_graph):
+    if choose_by == 'reads':
+      def key(bar):
+        return family_counts[bar]['all']
+    elif choose_by == 'connectivity':
+      degrees = graph.degree()
+      def key(bar):
+        return degrees[bar]
+    barcodes = sorted(graph.nodes(), key=key, reverse=True)
+    correct = barcodes[0]
+    for barcode in barcodes:
+      if barcode != correct:
+        logging.debug('Correcting {} ->\n           {}\n'.format(barcode, correct))
+        corrections[barcode] = correct
+  return corrections
+
+
+def print_corrected_output(families_file, corrections, reversed_barcodes, prepend=False, limit=None,
+                           output=True):
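+  """Print the families file to stdout with corrected barcodes, either replacing the original
+  column or (with prepend=True) inserted as a new first column. When the corrected barcode is
+  reversed with respect to the original, the order field (ab/ba) is swapped to match. Logs
+  per-family statistics and a final summary of the corrections."""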
+  line_num = 0
+  barcode_num = 0
+  barcode_last = None
+  corrected = {'reads':0, 'barcodes':0, 'reversed':0}
+  reads = [0, 0]
+  corrections_in_this_family = 0
+  for line in families_file:
+    line_num += 1
+    if limit is not None and line_num > limit:
+      break
+    fields = line.rstrip('\r\n').split('\t')
+    raw_barcode = fields[0]
+    order = fields[1]
+    if raw_barcode != barcode_last:
+      # We just started a new family.
+      barcode_num += 1
+      family_info = '{}\t{}\t{}'.format(barcode_last, reads[0], reads[1])
+      if corrections_in_this_family:
+        corrected['reads'] += corrections_in_this_family
+        corrected['barcodes'] += 1
+        family_info += '\tCORRECTED!'
+      else:
+        family_info += '\tuncorrected'
+      logging.log(VERBOSE, family_info)
+      reads = [0, 0]
+      corrections_in_this_family = 0
+      barcode_last = raw_barcode
+    if order == 'ab':
+      reads[0] += 1
+    elif order == 'ba':
+      reads[1] += 1
+    if raw_barcode in corrections:
+      correct_barcode = corrections[raw_barcode]
+      corrections_in_this_family += 1
+      # Check if the order of the barcode reverses in the correct version.
+      # First, we check in reversed_barcodes whether either barcode was involved in a reversed
+      # alignment, to save time (is_alignment_reversed() does a full smith-waterman alignment).
+      if ((raw_barcode in reversed_barcodes or correct_barcode in reversed_barcodes) and
+          is_alignment_reversed(raw_barcode, correct_barcode)):
+        # If so, then switch the order field.
+        corrected['reversed'] += 1
+        if order == 'ab':
+          fields[1] = 'ba'
+        else:
+          fields[1] = 'ab'
+    else:
+      correct_barcode = raw_barcode
+    if prepend:
+      fields.insert(0, correct_barcode)
+    else:
+      fields[0] = correct_barcode
+    if output:
+      print(*fields, sep='\t')
+  families_file.close()
+  if corrections_in_this_family:
+    corrected['reads'] += corrections_in_this_family
+    corrected['barcodes'] += 1
+  logging.info('Corrected {barcodes} barcodes on {reads} read pairs, with {reversed} reversed.'
+               .format(**corrected))
+
+
+def is_alignment_reversed(barcode1, barcode2):
+  """Return True if the barcodes are reversed with respect to each other, False otherwise.
+  "reversed" in this case meaning the alpha + beta halves are swapped.
+  Determine by aligning the two to each other, once in their original forms, and once with the
+  second barcode reversed. If the smith-waterman score is higher in the reversed form, return True.
+  """
+  half = len(barcode2)//2
+  barcode2_rev = barcode2[half:] + barcode2[:half]
+  fwd_align = swalign.smith_waterman(barcode1, barcode2)
+  rev_align = swalign.smith_waterman(barcode1, barcode2_rev)
+  if rev_align.score > fwd_align.score:
+    return True
+  else:
+    return False
+
+
+def count_structures(meta_graph, family_counts):
+  """Count the number of unique (isomorphic) subgraphs in the main graph."""
+  structures = []
+  for graph in networkx.connected_component_subgraphs(meta_graph):
+    match = False
+    for structure in structures:
+      archetype = structure['graph']
+      if networkx.is_isomorphic(graph, archetype):
+        match = True
+        structure['count'] += 1
+        structure['central'] += int(is_centralized(graph, family_counts))
+        break
+    if not match:
+      size = len(graph)
+      central = is_centralized(graph, family_counts)
+      structures.append({'graph':graph, 'size':size, 'count':1, 'central':int(central)})
+  return structures
+
+
+def is_centralized(graph, family_counts):
+  """Checks if the graph is centralized in terms of where the reads are located.
+  In a centralized graph, the node with the highest degree is the only one which (may) have more
+  than one read pair associated with that barcode.
+  This returns True if that's the case, False otherwise."""
+  if len(graph) == 2:
+    # Special-case graphs with 2 nodes, since the other algorithm doesn't work for them.
+    # - When both nodes have a degree of 1, sorting by degree doesn't work and can result in the
+    #   barcode with more read pairs coming second.
+    barcode1, barcode2 = graph.nodes()
+    counts1 = family_counts[barcode1]
+    counts2 = family_counts[barcode2]
+    total1 = counts1['all']
+    total2 = counts2['all']
+    logging.debug('{}: {:3d} ({}/{})\n{}: {:3d} ({}/{})\n'
+                  .format(barcode1, total1, counts1['ab'], counts1['ba'],
+                          barcode2, total2, counts2['ab'], counts2['ba']))
+    if (total1 >= 1 and total2 == 1) or (total1 == 1 and total2 >= 1):
+      return True
+    else:
+      return False
+  else:
+    degrees = graph.degree()
+    first = True
+    for barcode in sorted(graph.nodes(), key=lambda barcode: degrees[barcode], reverse=True):
+      if not first:
+        counts = family_counts[barcode]
+        # How many read pairs are associated with this barcode (how many times did we see this barcode)?
+        try:
+          if counts['all'] > 1:
+            return False
+        except TypeError:
+          logging.critical('barcode: {}, counts: {}'.format(barcode, counts))
+          raise
+      first = False
+    return True
+
+
+def print_structures(structures, human=True):
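+  """Print a summary of the unique structures: size, an identifying letter, the number of
+  occurrences, how many of those were centralized, and the sorted node degrees."""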
+  # Sort the structures in ascending order of size, but descending order of count.
+  def sort_key(structure):
+    return (structure['size'], -structure['count'])
+  width = None
+  last_size = None
+  for structure in sorted(structures, key=sort_key):
+    size = structure['size']
+    graph = structure['graph']
+    if size == last_size:
+      i += 1
+    else:
+      i = 0
+    if width is None:
+      width = str(len(str(structure['count'])))
+    letters = num_to_letters(i)
+    degrees = sorted(graph.degree().values(), reverse=True)
+    if human:
+      degrees_str = ' '.join(map(str, degrees))
+    else:
+      degrees_str = ','.join(map(str, degrees))
+    if human:
+      format_str = '{:2d}{:<3s} {count:<'+width+'d} {central:<'+width+'d} {}'
+      print(format_str.format(size, letters+':', degrees_str, **structure))
+    else:
+      print(size, letters, structure['count'], structure['central'], degrees_str, sep='\t')
+    last_size = size
+
+
+def num_to_letters(i):
+  """Translate numbers to letters, e.g. 1 -> A, 10 -> J, 100 -> CV"""
+  letters = ''
+  while i > 0:
+    n = (i-1) % 26
+    i = i // 26
+    if n == 25:
+      i -= 1
+    letters = chr(65+n) + letters
+  return letters
+
+
+def visualize(graphs, viz_path, args_viz_format):
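+  """Draw the given graphs as a single image, written to viz_path or displayed in a window if no
+  path is given. The output format is inferred from the filename extension, falling back to
+  args_viz_format."""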
+  import matplotlib.pyplot
+  from networkx.drawing.nx_agraph import graphviz_layout
+  meta_graph = networkx.Graph()
+  for graph in graphs:
+    add_graph(meta_graph, graph)
+  pos = graphviz_layout(meta_graph)
+  networkx.draw(meta_graph, pos)
+  if viz_path:
+    ext = os.path.splitext(viz_path)[1]
+    if ext == '.dot':
+      viz_format = 'graphviz'
+    elif ext == '.png':
+      viz_format = 'png'
+    else:
+      viz_format = args_viz_format
+  else:
+    viz_format = args_viz_format
+  if viz_format in ('dot', 'graphviz'):
+    from networkx.drawing.nx_pydot import write_dot
+    assert viz_path is not None, 'Must provide a filename to --visualize if using --viz-format "graphviz".'
+    base_path = os.path.splitext(viz_path)[0]
+    write_dot(meta_graph, base_path+'.dot')
+    run_command('dot', '-T', 'png', '-o', base_path+'.png', base_path+'.dot')
+    logging.info('Wrote graph to '+base_path+'.dot and rendered it to '+base_path+'.png')
+  elif viz_format == 'png':
+    if viz_path is None:
+      matplotlib.pyplot.show()
+    else:
+      matplotlib.pyplot.savefig(viz_path)
+
+
+def add_graph(graph, subgraph):
+  # I'm sure there's a function in the library for this, but just cause I need it quick..
+  for node in subgraph.nodes():
+    graph.add_node(node)
+  for edge in subgraph.edges():
+    graph.add_edge(*edge)
+  return graph
+
+
+def open_as_text_or_gzip(path):
+  """Return an open file-like object reading the path as a text file or a gzip file, depending on
+  which it looks like."""
+  if detect_gzip(path):
+    return gzip.open(path, 'r')
+  else:
+    return open(path, 'rU')
+
+
+def detect_gzip(path):
+  """Return True if the file looks like a gzip file: ends with .gz or contains non-ASCII bytes."""
+  ext = os.path.splitext(path)[1]
+  if ext == '.gz':
+    return True
+  elif ext in ('.txt', '.tsv', '.csv'):
+    return False
+  with open(path) as fh:
+    return detect_non_ascii(fh.read(100))
+
+
+def detect_non_ascii(bytes, max_test=100):
+  """Return True if any of the first "max_test" bytes are non-ASCII (the high bit set to 1).
+  Return False otherwise."""
+  for i, char in enumerate(bytes):
+    # Is the high bit a 1?
+    if ord(char) & 128:
+      return True
+    if i >= max_test:
+      return False
+  return False
+
+
+def run_command(*command):
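+  """Run the command and return its exit status, or None if it could not be executed."""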
+  try:
+    exit_status = subprocess.call(command)
+  except subprocess.CalledProcessError as cpe:
+    exit_status = cpe.returncode
+  except OSError:
+    exit_status = None
+  return exit_status
+
+
+def tone_down_logger():
+  """Change the logging level names from all-caps to capitalized lowercase.
+  E.g. "WARNING" -> "Warning" (turn down the volume a bit in your log files)"""
+  for level in (logging.CRITICAL, logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG):
+    level_name = logging.getLevelName(level)
+    logging.addLevelName(level, level_name.capitalize())
+
+
+if __name__ == '__main__':
+  sys.exit(main(sys.argv))