view correct.py @ 18:e4d75f9efb90 draft

planemo upload commit b'4303231da9e48b2719b4429a29b72421d24310f4\n'-dirty
author nick
date Thu, 02 Feb 2017 18:44:31 -0500
parents
children
line wrap: on
line source

#!/usr/bin/env python
from __future__ import division
from __future__ import print_function
import os
import sys
import gzip
import logging
import argparse
import resource
import subprocess
import networkx
import swalign

# Log level halfway between DEBUG (10) and INFO (20), used for -v/--verbose.
VERBOSE = (logging.DEBUG+logging.INFO)//2
# Defaults applied to the parser via parser.set_defaults(**ARG_DEFAULTS) in main().
# Bug fix: the MAPQ threshold default was stored only under 'qual', a key no
# command-line option uses, which left args.mapq defaulting to None and the
# MAPQ filter effectively disabled. Store it under 'mapq' so -m/--mapq gets its
# intended default; 'qual' is kept for backward compatibility.
ARG_DEFAULTS = {'sam':sys.stdin, 'qual':20, 'mapq':20, 'pos':2, 'dist':1, 'choose_by':'reads',
                'output':True, 'visualize':0, 'viz_format':'png', 'log':sys.stderr,
                'volume':logging.WARNING}
USAGE = "%(prog)s [options]"
DESCRIPTION = """Correct barcodes using an alignment of all barcodes to themselves. Reads the
alignment in SAM format and corrects the barcodes in an input "families" file (the output of
make-barcodes.awk). It will print the "families" file to stdout with barcodes (and orders)
corrected."""


def main(argv):
  """Command-line entry point: parse arguments, read the inputs, build the
  barcode-correction table, and print the corrected families file to stdout."""

  parser = argparse.ArgumentParser(description=DESCRIPTION)
  parser.set_defaults(**ARG_DEFAULTS)

  parser.add_argument('families', type=open_as_text_or_gzip,
    help='The sorted output of make-barcodes.awk. The important part is that it\'s a tab-delimited '
         'file with at least 2 columns: the barcode sequence and order, and it must be sorted in '
         'the same order as the "reads" in the SAM file.')
  parser.add_argument('reads', type=open_as_text_or_gzip,
    help='The fasta/q file given to the aligner. Used to get barcode sequences from read names.')
  parser.add_argument('sam', type=argparse.FileType('r'), nargs='?',
    help='Barcode alignment, in SAM format. Omit to read from stdin. The read names must be '
         'integers, representing the (1-based) order they appear in the families file.')
  parser.add_argument('-P', '--prepend', action='store_true',
    help='Prepend the corrected barcodes and orders to the original columns.')
  parser.add_argument('-d', '--dist', type=int,
    help='NM edit distance threshold. Default: %(default)s')
  parser.add_argument('-m', '--mapq', type=int,
    help='MAPQ threshold. Default: %(default)s')
  parser.add_argument('-p', '--pos', type=int,
    help='POS tolerance. Alignments will be ignored if abs(POS - 1) is greater than this value. '
         'Set to greater than the barcode length for no threshold. Default: %(default)s')
  parser.add_argument('-t', '--tag-len', type=int,
    help='Length of each half of the barcode. If not given, it will be determined from the first '
         'barcode in the families file.')
  parser.add_argument('-c', '--choose-by', choices=('reads', 'connectivity'))
  parser.add_argument('--limit', type=int,
    help='Limit the number of lines that will be read from each input file, for testing purposes.')
  parser.add_argument('-S', '--structures', action='store_true',
    help='Print a list of the unique isoforms')
  parser.add_argument('--struct-human', action='store_true')
  # nargs='?': flag absent -> default 0; flag with no value -> None (display in
  # a window); flag with a value -> the output filename.
  parser.add_argument('-V', '--visualize', nargs='?',
    help='Produce a visualization of the unique structures write the image to this file. '
         'If you omit a filename, it will be displayed in a window.')
  parser.add_argument('-F', '--viz-format', choices=('dot', 'graphviz', 'png'))
  parser.add_argument('-n', '--no-output', dest='output', action='store_false')
  parser.add_argument('-l', '--log', type=argparse.FileType('w'),
    help='Print log messages to this file instead of to stderr. Warning: Will overwrite the file.')
  # The four flags below all store different levels into the same dest, 'volume'.
  parser.add_argument('-q', '--quiet', dest='volume', action='store_const', const=logging.CRITICAL)
  parser.add_argument('-i', '--info', dest='volume', action='store_const', const=logging.INFO)
  parser.add_argument('-v', '--verbose', dest='volume', action='store_const', const=VERBOSE)
  parser.add_argument('-D', '--debug', dest='volume', action='store_const', const=logging.DEBUG,
    help='Print debug messages (very verbose).')

  args = parser.parse_args(argv[1:])

  logging.basicConfig(stream=args.log, level=args.volume, format='%(message)s')
  tone_down_logger()

  logging.info('Reading the fasta/q to map read names to barcodes..')
  names_to_barcodes = map_names_to_barcodes(args.reads, args.limit)

  logging.info('Reading the SAM to build the graph of barcode relationships..')
  graph, reversed_barcodes = read_alignments(args.sam, names_to_barcodes, args.pos, args.mapq,
                                             args.dist, args.limit)
  logging.info('{} reversed barcodes'.format(len(reversed_barcodes)))

  logging.info('Reading the families.tsv to get the counts of each family..')
  family_counts = get_family_counts(args.families, args.limit)

  if args.structures:
    logging.info('Counting the unique barcode networks..')
    structures = count_structures(graph, family_counts)
    print_structures(structures, args.struct_human)
    # 0 is the "flag not given" sentinel (see the -V option above).
    if args.visualize != 0:
      logging.info('Generating a visualization of barcode networks..')
      visualize([s['graph'] for s in structures], args.visualize, args.viz_format)

  logging.info('Building the correction table from the graph..')
  corrections = make_correction_table(graph, family_counts, args.choose_by)

  logging.info('Reading the families.tsv again to print corrected output..')
  # Re-open the families file by name for a second pass (get_family_counts()
  # consumed and closed the handle from argparse).
  families = open_as_text_or_gzip(args.families.name)
  print_corrected_output(families, corrections, reversed_barcodes, args.prepend, args.limit,
                         args.output)

  # NOTE(review): ru_maxrss is in kilobytes on Linux but bytes on macOS, so this
  # MB figure is platform-dependent -- confirm the target platform.
  max_mem = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024
  logging.info('Max memory usage: {:0.2f}MB'.format(max_mem))


def detect_format(reads_file, max_lines=7):
  """Guess whether an open file is FASTA or FASTQ from its first few lines.
  Votes on header/separator characters at the positions a 4-line FASTQ record
  would put them. Rewinds the file to the start before returning.
  Returns 'fasta', 'fastq', or None if the votes tie."""
  votes = {'fasta': 0, 'fastq': 0}
  for line_num, line in enumerate(reads_file, 1):
    position = line_num % 4
    if position == 1:
      # A record's first line: '@' suggests FASTQ, '>' suggests FASTA.
      if line.startswith('@'):
        votes['fastq'] += 1
      elif line.startswith('>'):
        votes['fasta'] += 1
    elif position == 3:
      # FASTQ's '+' separator line; a '>' here means 2-line FASTA records.
      if line.startswith('+'):
        votes['fastq'] += 1
      elif line.startswith('>'):
        votes['fasta'] += 1
    if line_num >= max_lines:
      break
  reads_file.seek(0)
  if votes['fasta'] == votes['fastq']:
    return None
  return max(votes, key=votes.get)


def read_fastaq(reads_file):
  """Dispatch to read_fasta() or read_fastq() based on the filename extension,
  falling back to sniffing the content with detect_format().
  Returns a generator of (name, sequence) pairs, or None if the format could
  not be determined."""
  filename = reads_file.name
  if filename.endswith(('.fa', '.fasta')):
    format = 'fasta'
  elif filename.endswith(('.fq', '.fastq')):
    format = 'fastq'
  else:
    format = detect_format(reads_file)
  parsers = {'fasta': read_fasta, 'fastq': read_fastq}
  if format in parsers:
    return parsers[format](reads_file)


def read_fasta(reads_file):
  """Read a FASTA file, yielding (read name, sequence) pairs.
  NOTE: This assumes each record's sequence occupies exactly one line!"""
  read_name = None
  for line_num, raw_line in enumerate(reads_file):
    line = raw_line.rstrip('\r\n')
    if line_num % 2 == 0:
      # Even lines (0-based) are headers.
      assert line.startswith('>'), line
      read_name = line[1:]
    else:
      # Odd lines are the (single-line) sequences.
      yield read_name, line


def read_fastq(reads_file):
  """Read a FASTQ file, yielding (read name, sequence) pairs.
  NOTE: This assumes each record's sequence occupies exactly one line!"""
  read_name = None
  for line_num, line in enumerate(reads_file):
    position = line_num % 4
    if position == 0:
      # First line of each 4-line record: '@' followed by the read name.
      assert line.startswith('@'), line
      read_name = line[1:].rstrip('\r\n')
    elif position == 1:
      # Second line: the (single-line) sequence. Quality lines are ignored.
      yield read_name, line.rstrip('\r\n')


def map_names_to_barcodes(reads_file, limit=None):
  """Map integer read names to their barcode sequences.
  Reads up to `limit` records (all of them if limit is None) and closes the
  file before returning. Raises ValueError (after logging) if a read name is
  not an integer."""
  names_to_barcodes = {}
  read_num = 0
  for read_name, read_seq in read_fastaq(reads_file):
    read_num += 1
    if limit is not None and read_num > limit:
      break
    try:
      name = int(read_name)
    except ValueError:
      # Bug fix: this used to log the variable `name`, which is unbound when
      # int() fails (raising a NameError instead of the intended message).
      logging.critical('non-int read name "{}"'.format(read_name))
      raise
    names_to_barcodes[name] = read_seq
  reads_file.close()
  return names_to_barcodes


def parse_alignment(sam_file, pos_thres, mapq_thres, dist_thres):
  """Parse the SAM file and yield reads that pass the filters.
  Yields (qname, rname, reversed) tuples: the integer query and reference
  names, and whether the reference name carried a ':rev' suffix.
  Filters applied: unmapped reads (RNAME '*' or FLAG bit 4), POS farther than
  pos_thres from 1, MAPQ below mapq_thres, and NM edit distance above
  dist_thres. Closes sam_file when the generator is exhausted.
  Raises ValueError on non-integer read names or a malformed NM tag."""
  line_num = 0
  for line in sam_file:
    line_num += 1
    if line.startswith('@'):
      logging.debug('Header line ({})'.format(line_num))
      continue
    fields = line.split('\t')
    logging.debug('read {} -> ref {} (read seq {}):'.format(fields[2], fields[0], fields[9]))
    qname_str = fields[0]
    rname_str = fields[2]
    # A reference name like "123:rev" marks an alignment to a reversed barcode.
    rname_fields = rname_str.split(':')
    if len(rname_fields) == 2 and rname_fields[1] == 'rev':
      reversed = True
      rname_str = rname_fields[0]
    else:
      reversed = False
    try:
      qname = int(qname_str)
      rname = int(rname_str)
    except ValueError:
      if fields[2] == '*':
        logging.debug('\tRead unmapped (reference == "*")')
        continue
      else:
        # Bug fix: log the raw strings; qname/rname may be unbound here (the
        # old code raised NameError instead of the intended error message).
        logging.error('Non-integer read name(s) on line {}: "{}", "{}".'
                      .format(line_num, qname_str, rname_str))
        raise
    # Apply alignment quality filters.
    try:
      flags = int(fields[1])
      pos = int(fields[3])
      mapq = int(fields[4])
    except ValueError:
      logging.warning('\tNon-integer flag ({}), pos ({}), or mapq ({})'
                      .format(fields[1], fields[3], fields[4]))
      continue
    if flags & 4:
      logging.debug('\tRead unmapped (flag & 4 == True)')
      continue
    if abs(pos - 1) > pos_thres:
      logging.debug('\tAlignment failed pos filter: abs({} - 1) > {}'.format(pos, pos_thres))
      continue
    if mapq < mapq_thres:
      # Bug fix: the message used to print "{} > {}" although the test is <.
      logging.debug('\tAlignment failed mapq filter: {} < {}'.format(mapq, mapq_thres))
      continue
    nm = None
    for tag in fields[11:]:
      if tag.startswith('NM:i:'):
        try:
          nm = int(tag[5:])
        except ValueError:
          logging.error('Invalid NM tag "{}" on line {}.'.format(tag, line_num))
          raise
        break
    assert nm is not None, line_num
    if nm > dist_thres:
      logging.debug('\tAlignment failed NM distance filter: {} > {}'.format(nm, dist_thres))
      continue
    yield qname, rname, reversed
  sam_file.close()


def read_alignments(sam_file, names_to_barcodes, pos_thres, mapq_thres, dist_thres, limit=None):
  """Build a graph of barcode relationships from the SAM alignments.
  Returns (graph, reversed_barcodes):
  - graph: undirected networkx graph with one node per barcode sequence and an
    edge between each pair of barcodes that align to each other.
  - reversed_barcodes: set of every barcode (query or reference) involved in
    any alignment whose target was reversed."""
  graph = networkx.Graph()
  reversed_barcodes = set()
  aln_num = 0
  for qname, rname, is_reversed in parse_alignment(sam_file, pos_thres, mapq_thres, dist_thres):
    aln_num += 1
    if limit is not None and aln_num > limit:
      break
    if rname == qname:
      continue  # Skip self-alignments.
    rseq = names_to_barcodes[rname]
    qseq = names_to_barcodes[qname]
    if is_reversed:
      # Mark both ends of the alignment as involved in a reversal.
      reversed_barcodes.update((rseq, qseq))
    graph.add_node(rseq)
    graph.add_node(qseq)
    graph.add_edge(rseq, qseq)
  return graph, reversed_barcodes


def get_family_counts(families_file, limit=None):
  """For each family (barcode), count how many read pairs exist for each strand (order).
  Expects the file to be sorted by barcode (as produced by make-barcodes.awk).
  Returns {barcode: {'ab': n, 'ba': m, 'all': n+m}}. Closes the file."""
  family_counts = {}
  last_barcode = None
  this_family_counts = None
  line_num = 0
  for line in families_file:
    line_num += 1
    if limit is not None and line_num > limit:
      break
    fields = line.rstrip('\r\n').split('\t')
    barcode = fields[0]
    order = fields[1]
    if barcode != last_barcode:
      # Finish the previous family (if any) and start a new one.
      # Bug fix: the old code stored a spurious {None: None} entry here on the
      # first family.
      if this_family_counts is not None:
        this_family_counts['all'] = this_family_counts['ab'] + this_family_counts['ba']
        family_counts[last_barcode] = this_family_counts
      this_family_counts = {'ab':0, 'ba':0}
      last_barcode = barcode
    this_family_counts[order] += 1
  # Finish the final family.
  # Bug fix: the old code raised a TypeError here when the input was empty.
  if this_family_counts is not None:
    this_family_counts['all'] = this_family_counts['ab'] + this_family_counts['ba']
    family_counts[last_barcode] = this_family_counts
  families_file.close()
  return family_counts


def make_correction_table(meta_graph, family_counts, choose_by='reads'):
  """Map original barcode sequences to their corrected barcode.
  Within each connected component, the node with the most reads (choose_by
  'reads') or the highest degree (choose_by 'connectivity') is assumed to be
  the correct barcode; every other node in the component maps to it."""
  corrections = {}
  for component in networkx.connected_component_subgraphs(meta_graph):
    if choose_by == 'reads':
      rank = lambda bar: family_counts[bar]['all']
    elif choose_by == 'connectivity':
      component_degrees = component.degree()
      rank = lambda bar: component_degrees[bar]
    ordered = sorted(component.nodes(), key=rank, reverse=True)
    correct = ordered[0]
    for barcode in ordered[1:]:
      logging.debug('Correcting {} ->\n           {}\n'.format(barcode, correct))
      corrections[barcode] = correct
  return corrections


def print_corrected_output(families_file, corrections, reversed_barcodes, prepend=False, limit=None,
                           output=True):
  """Print the families file to stdout with barcodes (and orders) corrected.
  families_file: open file of tab-delimited lines (barcode, order, ...), sorted
    by barcode.
  corrections: dict mapping original barcode -> correct barcode.
  reversed_barcodes: set of barcodes involved in any reversed alignment.
  prepend: insert the corrected barcode as a new first column instead of
    replacing the original.
  limit: stop after this many input lines (for testing).
  output: if False, suppress stdout output (stats are still logged).
  Closes families_file. Logs per-family stats at VERBOSE level and totals at
  INFO level."""
  line_num = 0
  barcode_last = None
  corrected = {'reads':0, 'barcodes':0, 'reversed':0}
  reads = [0, 0]
  corrections_in_this_family = 0
  for line in families_file:
    line_num += 1
    if limit is not None and line_num > limit:
      break
    fields = line.rstrip('\r\n').split('\t')
    raw_barcode = fields[0]
    order = fields[1]
    if raw_barcode != barcode_last:
      # We just started a new family: tally and log the one that just ended.
      # Bug fix: the old code logged a bogus "None 0 0" family before the first
      # real one, and never logged the final family (now handled after the loop).
      if barcode_last is not None:
        _tally_family(corrected, barcode_last, reads, corrections_in_this_family)
      reads = [0, 0]
      corrections_in_this_family = 0
      barcode_last = raw_barcode
    if order == 'ab':
      reads[0] += 1
    elif order == 'ba':
      reads[1] += 1
    if raw_barcode in corrections:
      correct_barcode = corrections[raw_barcode]
      corrections_in_this_family += 1
      # Check if the order of the barcode reverses in the correct version.
      # First, check reversed_barcodes to see whether either barcode was involved
      # in a reversed alignment, to save time (is_alignment_reversed() does a
      # full smith-waterman alignment).
      if ((raw_barcode in reversed_barcodes or correct_barcode in reversed_barcodes) and
          is_alignment_reversed(raw_barcode, correct_barcode)):
        # If so, then switch the order field.
        corrected['reversed'] += 1
        if order == 'ab':
          fields[1] = 'ba'
        else:
          fields[1] = 'ab'
    else:
      correct_barcode = raw_barcode
    if prepend:
      fields.insert(0, correct_barcode)
    else:
      fields[0] = correct_barcode
    if output:
      print(*fields, sep='\t')
  families_file.close()
  # Tally the final family the same way as the others.
  if barcode_last is not None:
    _tally_family(corrected, barcode_last, reads, corrections_in_this_family)
  logging.info('Corrected {barcodes} barcodes on {reads} read pairs, with {reversed} reversed.'
               .format(**corrected))


def _tally_family(corrected, barcode, reads, num_corrections):
  """Add one finished family to the running totals and log its stats at VERBOSE level."""
  family_info = '{}\t{}\t{}'.format(barcode, reads[0], reads[1])
  if num_corrections:
    corrected['reads'] += num_corrections
    corrected['barcodes'] += 1
    family_info += '\tCORRECTED!'
  else:
    family_info += '\tuncorrected'
  logging.log(VERBOSE, family_info)


def is_alignment_reversed(barcode1, barcode2):
  """Return True if the barcodes are reversed with respect to each other, False otherwise.
  "Reversed" here means the alpha + beta halves are swapped.
  Both orientations are aligned with Smith-Waterman (original barcode2, and
  barcode2 with its halves swapped); the higher-scoring one wins."""
  half = len(barcode2)//2
  swapped = barcode2[half:] + barcode2[:half]
  score_fwd = swalign.smith_waterman(barcode1, barcode2).score
  score_rev = swalign.smith_waterman(barcode1, swapped).score
  return score_rev > score_fwd


def count_structures(meta_graph, family_counts):
  """Tally the unique (isomorphic) subgraph shapes in the meta graph.
  Returns a list of dicts, one per unique shape: {'graph': archetype subgraph,
  'size': node count, 'count': occurrences, 'central': number of occurrences
  that were centralized per is_centralized()}."""
  structures = []
  for subgraph in networkx.connected_component_subgraphs(meta_graph):
    for structure in structures:
      if networkx.is_isomorphic(subgraph, structure['graph']):
        # Seen this shape before: bump its counters.
        structure['count'] += 1
        structure['central'] += int(is_centralized(subgraph, family_counts))
        break
    else:
      # No isomorphic match: record a new archetype.
      structures.append({
        'graph': subgraph,
        'size': len(subgraph),
        'count': 1,
        'central': int(is_centralized(subgraph, family_counts)),
      })
  return structures


def is_centralized(graph, family_counts):
  """Checks if the graph is centralized in terms of where the reads are located.
  In a centralized graph, the node with the highest degree is the only one which
  (may) have more than one read pair associated with that barcode.
  Returns True if that's the case, False otherwise."""
  if len(graph) == 2:
    # Special-case graphs with 2 nodes, since the other algorithm doesn't work
    # for them: both nodes have degree 1, so sorting by degree is meaningless
    # and could put the barcode with more read pairs second.
    barcode1, barcode2 = graph.nodes()
    counts1, counts2 = family_counts[barcode1], family_counts[barcode2]
    total1, total2 = counts1['all'], counts2['all']
    logging.debug('{}: {:3d} ({}/{})\n{}: {:3d} ({}/{})\n'
                  .format(barcode1, total1, counts1['ab'], counts1['ba'],
                          barcode2, total2, counts2['ab'], counts2['ba']))
    return (total1 >= 1 and total2 == 1) or (total1 == 1 and total2 >= 1)
  else:
    degrees = graph.degree()
    by_degree = sorted(graph.nodes(), key=lambda bar: degrees[bar], reverse=True)
    # Every node except the highest-degree one must have at most one read pair.
    for barcode in by_degree[1:]:
      counts = family_counts[barcode]
      try:
        if counts['all'] > 1:
          return False
      except TypeError:
        logging.critical('barcode: {}, counts: {}'.format(barcode, counts))
        raise
    return True


def print_structures(structures, human=True):
  """Print one line per unique structure: size, per-size letter ID, count,
  centralized count, and the sorted degree sequence. Human mode uses aligned
  columns; otherwise the fields are tab-delimited."""
  # Sort in ascending order of size, then descending order of count.
  # Bug fix: this replaces a cmp= comparison function, which Python 3's
  # sorted() no longer accepts, with an equivalent key.
  ordered = sorted(structures, key=lambda s: (s['size'], -s['count']))
  width = None
  last_size = None
  i = 0
  for structure in ordered:
    size = structure['size']
    graph = structure['graph']
    if size == last_size:
      i += 1
    else:
      # Bug fix: restart lettering at 1 for each size; the old code started at
      # 0, and num_to_letters(0) returns an empty string instead of 'A'.
      i = 1
    if width is None:
      width = str(len(str(structure['count'])))
    letters = num_to_letters(i)
    degrees = sorted(graph.degree().values(), reverse=True)
    separator = ' ' if human else ','
    degrees_str = separator.join(map(str, degrees))
    if human:
      format_str = '{:2d}{:<3s} {count:<'+width+'d} {central:<'+width+'d} {}'
      print(format_str.format(size, letters+':', degrees_str, **structure))
    else:
      print(size, letters, structure['count'], structure['central'], degrees_str, sep='\t')
    last_size = size


def num_to_letters(i):
  """Translate numbers to letters (bijective base 26), e.g. 1 -> A, 10 -> J, 100 -> CV"""
  letters = []
  while i > 0:
    # Shift to 0-based so 26 maps to 'Z' rather than rolling over to 'A0'.
    i, remainder = divmod(i - 1, 26)
    letters.append(chr(65 + remainder))
  return ''.join(reversed(letters))


def visualize(graphs, viz_path, args_viz_format):
  """Draw the given graphs as a single image, to a file or an interactive window.
  graphs: the networkx graphs to combine and draw.
  viz_path: output filename, or None/falsy to display in a window instead.
  args_viz_format: 'dot'/'graphviz' or 'png'; used when viz_path's extension
    doesn't determine the format."""
  # Bug fix: `import matplotlib` alone does not make matplotlib.pyplot
  # available; import the submodule explicitly.
  from matplotlib import pyplot
  from networkx.drawing.nx_agraph import graphviz_layout
  meta_graph = networkx.Graph()
  for graph in graphs:
    add_graph(meta_graph, graph)
  pos = graphviz_layout(meta_graph)
  networkx.draw(meta_graph, pos)
  viz_format = None
  if viz_path:
    ext = os.path.splitext(viz_path)[1]
    if ext == '.dot':
      viz_format = 'graphviz'
    elif ext == '.png':
      viz_format = 'png'
  if viz_format is None:
    # Bug fix: viz_format was previously left undefined (NameError) when
    # viz_path had an unrecognized extension.
    viz_format = args_viz_format
  if viz_format == 'graphviz':
    from networkx.drawing.nx_pydot import write_dot
    assert viz_path is not None, 'Must provide a filename to --visualize if using --viz-format "graphviz".'
    # Bug fix: os.path.splitext() returns a (root, ext) tuple; take the root
    # (the old code tried to concatenate the tuple with a string).
    base_path = os.path.splitext(viz_path)[0]
    write_dot(meta_graph, base_path+'.dot')
    run_command('dot', '-T', 'png', '-o', base_path+'.png', base_path+'.dot')
    # Bug fix: point the message at the rendered image, not the .dot source.
    logging.info('Wrote image of graph to '+base_path+'.png')
  elif viz_format == 'png':
    if viz_path is None:
      pyplot.show()
    else:
      pyplot.savefig(viz_path)


def add_graph(graph, subgraph):
  """Merge subgraph's nodes and edges into graph, returning graph.
  (A quick stand-in for networkx's built-in graph-composition helpers.)"""
  for node in subgraph.nodes():
    graph.add_node(node)
  for source, target in subgraph.edges():
    graph.add_edge(source, target)
  return graph


def open_as_text_or_gzip(path):
  """Open path for reading, transparently handling gzip.
  Returns an open file-like object: a gzip reader if detect_gzip() says the
  file looks compressed, otherwise a plain text file handle."""
  if not detect_gzip(path):
    return open(path, 'rU')
  return gzip.open(path, 'r')


def detect_gzip(path):
  """Return True if the file looks like a gzip file: ends with .gz or contains non-ASCII bytes.
  Known text extensions (.txt, .tsv, .csv) short-circuit to False; otherwise
  the first 100 bytes are sniffed for non-ASCII content."""
  ext = os.path.splitext(path)[1]
  if ext == '.gz':
    return True
  elif ext in ('.txt', '.tsv', '.csv'):
    return False
  # Unrecognized extension: sniff the start of the file.
  with open(path) as fh:
    # Bug fix: always return a bool (the old code fell off the end and
    # returned None when the sample was pure ASCII).
    return detect_non_ascii(fh.read(100))


def detect_non_ascii(bytes, max_test=100):
  """Return True if any of the first "max_test" bytes are non-ASCII (the high bit set to 1).
  Return False otherwise."""
  # Bug fix: the old loop only broke *after* testing index max_test, so it
  # examined max_test+1 bytes. Slicing tests exactly the first max_test.
  for char in bytes[:max_test]:
    if ord(char) & 128:  # Is the high bit a 1?
      return True
  return False


def run_command(*command):
  """Run the given command line and return its exit status.
  Returns None if the command could not be executed at all (OSError, e.g.
  executable not found)."""
  try:
    return subprocess.call(command)
  except subprocess.CalledProcessError as cpe:
    # subprocess.call() doesn't normally raise this, but be defensive.
    return cpe.returncode
  except OSError:
    return None


def tone_down_logger():
  """Change the logging level names from all-caps to capitalized lowercase.
  E.g. "WARNING" -> "Warning" (turn down the volume a bit in your log files)"""
  levels = (logging.CRITICAL, logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG)
  for level in levels:
    logging.addLevelName(level, logging.getLevelName(level).capitalize())


# Script entry point: exit with whatever main() returns.
if __name__ == '__main__':
  sys.exit(main(sys.argv))