diff utils/sim.py @ 18:e4d75f9efb90 draft

planemo upload commit b'4303231da9e48b2719b4429a29b72421d24310f4\n'-dirty
author nick
date Thu, 02 Feb 2017 18:44:31 -0500
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/utils/sim.py	Thu Feb 02 18:44:31 2017 -0500
@@ -0,0 +1,706 @@
+#!/usr/bin/env python
+from __future__ import division
+from __future__ import print_function
+import re
+import os
+import sys
+import copy
+import numpy
+import bisect
+import random
+import string
+import numbers
+import tempfile
+import argparse
+import subprocess
+import fastqreader
+
+REVCOMP_TABLE = string.maketrans('acgtrymkbdhvACGTRYMKBDHV', 'tgcayrkmvhdbTGCAYRKMVHDB')
+WGSIM_ID_REGEX = r'^(.+)_(\d+)_(\d+)_\d+:\d+:\d+_\d+:\d+:\d+_([0-9a-f]+)/[12]$'
+ARG_DEFAULTS = {'read_len':100, 'frag_len':400, 'n_frags':1000, 'out_format':'fasta',
+                'seq_error':0.001, 'pcr_error':0.001, 'cycles':25, 'indel_rate':0.15,
+                'ext_rate':0.3, 'seed':None, 'invariant':'TGACT', 'bar_len':12, 'fastq_qual':'I'}
+USAGE = "%(prog)s [options]"
+DESCRIPTION = """Simulate a duplex sequencing experiment."""
+
+RAW_DISTRIBUTION = (
+  #  0     1     2     3     4     5     6     7     8     9
+  # Low singletons, but then constant drop-off. From pML113 (see 2015-09-28 report).
+  #  0,  100,   36,   31,   27,   22,   17,   12,    7,  4.3,
+  #2.4,  1.2,  0.6,  0.3,  0.2, 0.15,  0.1, 0.07, 0.05, 0.03,
+  # High singletons, but then a second peak around 10. From Christine plasmid (2015-10-06 report).
+  #    0,  100, 5.24, 3.67, 3.50, 3.67, 3.85, 4.02, 4.11, 4.20,
+  # 4.17, 4.10, 4.00, 3.85, 3.69, 3.55, 3.38, 3.15, 2.92, 2.62,
+  # 2.27, 2.01, 1.74, 1.56, 1.38, 1.20, 1.02, 0.85,
+  # Same as above, but low singletons, 2's, and 3's (rely on errors to fill out those).
+     0,    1,    2,    3, 3.50, 3.67, 3.85, 4.02, 4.11, 4.20,
+  4.17, 4.10, 4.00, 3.85, 3.69, 3.55, 3.38, 3.15, 2.92, 2.62,
+  2.27, 2.01, 1.74, 1.56, 1.38, 1.20, 1.02, 0.85,
+)
+
+
+def main(argv):
+
+  parser = argparse.ArgumentParser(description=DESCRIPTION)
+  parser.set_defaults(**ARG_DEFAULTS)
+
+  parser.add_argument('ref', metavar='ref.fa', nargs='?',
+    help='Reference sequence. Omit if giving --frag-file.')
+  parser.add_argument('out1', type=argparse.FileType('w'),
+    help='Write final mate 1 reads to this file.')
+  parser.add_argument('out2', type=argparse.FileType('w'),
+    help='Write final mate 2 reads to this file.')
+  parser.add_argument('-o', '--out-format', choices=('fastq', 'fasta'))
+  parser.add_argument('--stdout', action='store_true',
+    help='Print interleaved output reads to stdout.')
+  parser.add_argument('-m', '--mutations', type=argparse.FileType('w'),
+    help='Write a log of the PCR and sequencing errors introduced to this file. Will overwrite any '
+         'existing file at this path.')
+  parser.add_argument('-b', '--barcodes', type=argparse.FileType('w'),
+    help='Write a log of which barcodes were ligated to which fragments. Will overwrite any '
+         'existing file at this path.')
+  parser.add_argument('--frag-file',
+    help='The path of the FASTQ file of fragments. If --ref is given, these will be generated with '
+         'wgsim and kept (normally a temporary file is used, then deleted). Note: the file will be '
+         'overwritten! If --ref is not given, then this should be a file of already generated '
+         'fragments, and they will be used instead of generating new ones.')
+  parser.add_argument('-Q', '--fastq-qual',
+    help='The quality score to assign to all bases in FASTQ output. Give a character or PHRED '
+         'score (integer). A PHRED score will be converted using the Sanger offset (33). Default: '
+         '"%(default)s"')
+  parser.add_argument('-S', '--seed', type=int,
+    help='Random number generator seed. By default, a random, 32-bit seed will be generated and '
+         'logged to stdout.')
+  params = parser.add_argument_group('simulation parameters')
+  params.add_argument('-n', '--n-frags', type=int,
+    help='The number of original fragment molecules to simulate. The final number of reads will be '
+         'this multiplied by the average number of reads per family. If you provide fragments with '
+         '--frag-file, the script will still only read in the number specified here. Default: '
+         '%(default)s')
+  params.add_argument('-r', '--read-len', type=int,
+    help='Default: %(default)s')
+  params.add_argument('-f', '--frag-len', type=int,
+    help='Default: %(default)s')
+  params.add_argument('-s', '--seq-error', type=float,
+    help='Sequencing error rate per base (0-1 proportion, not percent). Default: %(default)s')
+  params.add_argument('-p', '--pcr-error', type=float,
+    help='PCR error rate per base (0-1 proportion, not percent). Default: %(default)s')
+  params.add_argument('-c', '--cycles', type=int,
+    help='Number of PCR cycles to simulate. Default: %(default)s')
+  params.add_argument('-i', '--indel-rate', type=float,
+    help='Fraction of errors which are indels. Default: %(default)s')
+  params.add_argument('-E', '--extension-rate', dest='ext_rate', type=float,
+    help='Probability an indel is extended. Default: %(default)s')
+  params.add_argument('-B', '--bar-len', type=int,
+    help='Length of the barcodes to generate. Default: %(default)s')
+  params.add_argument('-I', '--invariant',
+    help='The invariant linker sequence between the barcode and sample sequence in each read. '
+         'Default: %(default)s')
+
+  # Parse and interpret arguments.
+  args = parser.parse_args(argv[1:])
+  assert args.ref or args.frag_file, 'You must provide either a reference or fragments file.'
+  if args.seed is None:
+    seed = random.randint(0, 2**31-1)
+    sys.stderr.write('seed: {}\n'.format(seed))
+  else:
+    seed = args.seed
+  random.seed(seed)
+  if args.stdout:
+    out1 = sys.stdout
+    out2 = sys.stdout
+  else:
+    out1 = args.out1
+    out2 = args.out2
+  if isinstance(args.fastq_qual, numbers.Integral):
+    assert args.fastq_qual >= 0, '--fastq-qual cannot be negative.'
+    fastq_qual = chr(args.fastq_qual + 33)
+  elif isinstance(args.fastq_qual, basestring):
+    assert len(args.fastq_qual) == 1, '--fastq-qual cannot be more than a single character.'
+    fastq_qual = args.fastq_qual
+  else:
+    raise AssertionError('--fastq-qual must be a positive integer or single character.')
+  qual_line = fastq_qual * args.read_len
+
+  invariant_rc = get_revcomp(args.invariant)
+
+  # Create a temporary director to do our work in. Then work inside a try so we can finally remove
+  # the directory no matter what exceptions are encountered.
+  tmpfile = tempfile.NamedTemporaryFile(prefix='wgdsim.frags.')
+  tmpfile.close()
+  try:
+    # Step 1: Use wgsim to create fragments from the reference.
+    if args.frag_file:
+      frag_file = args.frag_file
+    else:
+      frag_file = tmpfile.name
+    if args.ref and os.path.isfile(args.ref) and os.path.getsize(args.ref):
+      #TODO: Check exit status
+      #TODO: Check for wgsim on the PATH.
+      # Set error and mutation rates to 0 to just slice sequences out of the reference without
+      # modification.
+      run_command('wgsim', '-e', '0', '-r', '0', '-d', '0', '-R', args.indel_rate, '-S', seed,
+                  '-N', args.n_frags, '-X', args.ext_rate, '-1', args.frag_len,
+                  args.ref, frag_file, os.devnull)
+
+    # NOTE: Coordinates here are 0-based (0 is the first base in the sequence).
+    extended_dist = extend_dist(RAW_DISTRIBUTION)
+    proportional_dist = compile_dist(extended_dist)
+    n_frags = 0
+    for raw_fragment in fastqreader.FastqReadGenerator(frag_file):
+      n_frags += 1
+      if n_frags > args.n_frags:
+        break
+      chrom, id_num, start, stop = parse_read_id(raw_fragment.id)
+      barcode1 = get_rand_seq(args.bar_len)
+      barcode2 = get_rand_seq(args.bar_len)
+      barcode2_rc = get_revcomp(barcode2)
+      raw_frag_full = barcode1 + args.invariant + raw_fragment.seq + invariant_rc + barcode2
+
+      # Step 2: Determine how many reads to produce from each fragment.
+      # - Use random.random() and divide the range 0-1 into segments of sizes proportional to
+      #   the likelihood of each family size.
+      # bisect.bisect() finds where an element belongs in a sorted list, returning the index.
+      # proportional_dist is just such a sorted list, with values from 0 to 1.
+      n_reads = bisect.bisect(proportional_dist, random.random())
+
+      # Step 3: Introduce PCR errors.
+      # - Determine the mutations and their frequencies.
+      #   - Could get frequency from the cycle of PCR it occurs in.
+      #     - Important to have PCR errors shared between reads.
+      # - For each read, determine which mutations it contains.
+      #   - Use random.random() < mut_freq.
+      tree = get_good_pcr_tree(n_reads, args.cycles, 1000, max_diff=1)
+      # Add errors to all children of original fragment.
+      subtree1 = tree.get('child1')
+      subtree2 = tree.get('child2')
+      #TODO: Only simulate errors on portions of fragment that will become reads.
+      add_pcr_errors(subtree1, '+', len(raw_frag_full), args.pcr_error, args.indel_rate, args.ext_rate)
+      add_pcr_errors(subtree2, '-', len(raw_frag_full), args.pcr_error, args.indel_rate, args.ext_rate)
+      apply_pcr_errors(tree, raw_frag_full)
+      fragments = get_final_fragments(tree)
+      add_mutation_lists(tree, fragments, [])
+
+      # Step 4: Introduce sequencing errors.
+      for fragment in fragments.values():
+        for mutation in generate_mutations(args.read_len, args.seq_error, args.indel_rate,
+                                           args.ext_rate):
+          fragment['mutations'].append(mutation)
+          fragment['seq'] = apply_mutation(mutation, fragment['seq'])
+
+      # Print barcodes to log file.
+      if args.barcodes:
+        args.barcodes.write('{}-{}\t{}\t{}\n'.format(chrom, id_num, barcode1, barcode2_rc))
+      # Print family.
+      for frag_id in sorted(fragments.keys()):
+        fragment = fragments[frag_id]
+        read_id = '{}-{}-{}'.format(chrom, id_num, frag_id)
+        # Print mutations to log file.
+        if args.mutations:
+          read1_muts = get_mutations_subset(fragment['mutations'], 0, args.read_len)
+          read2_muts = get_mutations_subset(fragment['mutations'], 0, args.read_len, revcomp=True,
+                                            seqlen=len(fragment['seq']))
+          if fragment['strand'] == '-':
+            read1_muts, read2_muts = read2_muts, read1_muts
+          log_mutations(args.mutations, read1_muts, read_id+'/1', chrom, start, stop)
+          log_mutations(args.mutations, read2_muts, read_id+'/2', chrom, start, stop)
+        frag_seq = fragment['seq']
+        read1_seq = frag_seq[:args.read_len]
+        read2_seq = get_revcomp(frag_seq[len(frag_seq)-args.read_len:])
+        if fragment['strand'] == '-':
+          read1_seq, read2_seq = read2_seq, read1_seq
+        if args.out_format == 'fasta':
+          out1.write('>{}\n{}\n'.format(read_id, read1_seq))
+          out2.write('>{}\n{}\n'.format(read_id, read2_seq))
+        elif args.out_format == 'fastq':
+          out1.write('@{}\n{}\n+\n{}\n'.format(read_id, read1_seq, qual_line))
+          out2.write('@{}\n{}\n+\n{}\n'.format(read_id, read2_seq, qual_line))
+
+  finally:
+    try:
+      os.remove(tmpfile.name)
+    except OSError:
+      pass
+
+
+def run_command(*command, **kwargs):
+  """Run a command and return the exit code.
+  run_command('echo', 'hello')
+  Will print the command to stderr before running, unless "silent" is set to True."""
+  command_strs = map(str, command)
+  if not kwargs.get('silent'):
+    sys.stderr.write('$ '+' '.join(command_strs)+'\n')
+  devnull = open(os.devnull, 'w')
+  try:
+    exit_status = subprocess.call(map(str, command), stderr=devnull)
+  except OSError:
+    exit_status = None
+  finally:
+    devnull.close()
+  return exit_status
+
+
+def extend_dist(raw_dist, exponent=1.25, min_prob=0.00001, max_len_mult=2):
+  """Add an exponentially decreasing tail to the distribution.
+  It takes the final value in the distribution and keeps dividing it by
+  "exponent", adding each new value to the end. It will not add probabilities
+  smaller than "min_prob" or extend the length of the list by more than
+  "max_len_mult" times."""
+  extended_dist = list(raw_dist)
+  final_sum = sum(raw_dist)
+  value = raw_dist[-1]
+  value /= exponent
+  while value/final_sum >= min_prob and len(extended_dist) < len(raw_dist)*max_len_mult:
+    extended_dist.append(value)
+    final_sum += value
+    value /= exponent
+  return extended_dist
+
+
+def compile_dist(raw_dist):
+  """Turn the human-readable list of probabilities defined at the top into
+  proportional probabilities.
+  E.g. [10, 5, 5] -> [0.5, 0.75, 1.0]"""
+  proportional_dist = []
+  final_sum = sum(raw_dist)
+  current_sum = 0
+  for magnitude in raw_dist:
+    current_sum += magnitude
+    proportional_dist.append(current_sum/final_sum)
+  return proportional_dist
+
+
+def parse_read_id(read_id):
+  match = re.search(WGSIM_ID_REGEX, read_id)
+  if match:
+    chrom = match.group(1)
+    start = match.group(2)
+    stop = match.group(3)
+    id_num = match.group(4)
+  else:
+    chrom, id_num, start, stop = read_id, None, None, None
+  return chrom, id_num, start, stop
+
+
+#TODO: Clean up "mutation" vs "error" terminology.
+def generate_mutations(seq_len, error_rate, indel_rate, extension_rate):
+  """Generate all the mutations that occur over the length of a sequence."""
+  i = 0
+  while i <= seq_len:
+    if random.random() < error_rate:
+      mtype, alt = make_mutation(indel_rate, extension_rate)
+      # Allow mutation after the last base only if it's an insertion.
+      if i < seq_len or mtype == 'ins':
+        yield {'coord':i, 'type':mtype, 'alt':alt}
+      # Compensate for length variations to keep i tracking the original read's base coordinates.
+      if mtype == 'ins':
+        i += len(alt)
+      elif mtype == 'del':
+        i -= alt
+    i += 1
+
+
+def make_mutation(indel_rate, extension_rate):
+  """Simulate a random mutation."""
+  # Is it an indel?
+  rand = random.random()
+  if rand < indel_rate:
+    # Is it an insertion or deletion? Decide, then initialize it.
+    # Re-use the random number from above. Just check if it's in the lower or upper half of the
+    # range from 0 to indel_rate.
+    if rand < indel_rate/2:
+      mtype = 'del'
+      alt = 1
+    else:
+      mtype = 'ins'
+      alt = get_rand_base()
+    # Extend the indel as long as the extension rate allows.
+    while random.random() < extension_rate:
+      if mtype == 'ins':
+        alt += get_rand_base()
+      else:
+        alt += 1
+  else:
+    # What is the new base for the SNV?
+    mtype = 'snv'
+    alt = get_rand_base()
+  return mtype, alt
+
+
+def get_rand_base(bases='ACGT'):
+  return random.choice(bases)
+
+
+def get_rand_seq(seq_len):
+  return ''.join([get_rand_base() for i in range(seq_len)])
+
+
+def get_revcomp(seq):
+  return seq.translate(REVCOMP_TABLE)[::-1]
+
+
+def apply_mutation(mut, seq):
+  i = mut['coord']
+  if mut['type'] == 'snv':
+    # Replace the base at "coord".
+    new_seq = seq[:i] + mut['alt'] + seq[i+1:]
+  else:
+    # Indels are handled by inserting or deleting bases starting *before* the base at "coord".
+    # This goes agains the VCF convention, but it allows deleting the first and last base, as well
+    # as inserting before and after the sequence without as much special-casing.
+    if mut['type'] == 'ins':
+      # Example: 'ACGTACGT' + ins 'GC' at 4 = 'ACGTGCACGT'
+      new_seq = seq[:i] + mut['alt'] + seq[i:]
+    else:
+      # Example: 'ACGTACGT' + del 2 at 4 = 'ACGTGT'
+      new_seq = seq[:i] + seq[i+mut['alt']:]
+  return new_seq
+
+
+def get_mutations_subset(mutations_old, start, length, revcomp=False, seqlen=None):
+  """Get a list of the input mutations which are within a certain region.
+  The output list maintains the order in the input list, only filtering out
+  mutations outside the specified region.
+  "start" is the start of the region (0-based). If revcomp, this start should be
+  in the coordinate system of the reverse-complemented sequence.
+  "length" is the length of the region.
+  "revcomp" causes the mutations to be converted to their reverse complements, and
+  the "start" to refer to the reverse complement sequence's coordinates. The order
+  of the mutations is unchanged, though.
+  "seqlen" is the length of the sequence the mutations occurred in. This is only
+  needed when revcomp is True, to convert coordinates to the reverse complement
+  coordinate system."""
+  stop = start + length
+  mutations_new = []
+  for mutation in mutations_old:
+    if revcomp:
+      mutation = get_mutation_revcomp(mutation, seqlen)
+    if start <= mutation['coord'] < stop:
+      mutations_new.append(mutation)
+    elif mutation['coord'] == stop and mutation['type'] == 'ins':
+      # Allow insertions at the last coordinate.
+      mutations_new.append(mutation)
+  return mutations_new
+
+
+def get_mutation_revcomp(mut, seqlen):
+  """Convert a mutation to its reverse complement.
+  "seqlen" is the length of the sequence the mutation is being applied to. Needed
+  to convert the coordinate to a coordinate system starting at the end of the
+  sequence."""
+  mut_rc = {'type':mut['type']}
+  if mut['type'] == 'snv':
+    mut_rc['coord'] = seqlen - mut['coord'] - 1
+    mut_rc['alt'] = get_revcomp(mut['alt'])
+  elif mut['type'] == 'ins':
+    mut_rc['coord'] = seqlen - mut['coord']
+    mut_rc['alt'] = get_revcomp(mut['alt'])
+  elif mut['type'] == 'del':
+    mut_rc['coord'] = seqlen - mut['coord'] - mut['alt']
+    mut_rc['alt'] = mut['alt']
+  return mut_rc
+
+
+def log_mutations(mutfile, mutations, read_id, chrom, start, stop):
+  for mutation in mutations:
+    mutfile.write('{read_id}\t{chrom}\t{start}\t{stop}\t{coord}\t{type}\t{alt}\n'
+                  .format(read_id=read_id, chrom=chrom, start=start, stop=stop, **mutation))
+
+
+def add_pcr_errors(subtree, strand, read_len, error_rate, indel_rate, extension_rate):
+  """Add simulated PCR errors to a node in a tree and all its descendants."""
+  # Note: The errors are intended as "errors made in creating this molecule", so don't apply this to
+  # the root node, since that is supposed to be the original, unaltered molecule.
+  # Go down the subtree and simulate errors in creating each fragment.
+  # Process all the first-child descendants of the original node in a loop, and recursively call
+  # this function to process all second children.
+  node = subtree
+  while node:
+    node['strand'] = strand
+    node['errors'] = list(generate_mutations(read_len, error_rate, indel_rate, extension_rate))
+    add_pcr_errors(node.get('child2'), strand, read_len, error_rate, indel_rate, extension_rate)
+    node = node.get('child1')
+
+
+def apply_pcr_errors(subtree, seq):
+  node = subtree
+  while node:
+    for error in node.get('errors', ()):
+      seq = apply_mutation(error, seq)
+    if 'child1' not in node:
+      node['seq'] = seq
+    apply_pcr_errors(node.get('child2'), seq)
+    node = node.get('child1')
+
+
+def get_final_fragments(tree):
+  """Walk to the leaf nodes of the tree and get the post-PCR sequences of all the fragments.
+  Returns a dict mapping fragment id number to a dict representing the fragment. Its only two keys
+  are 'seq' (the final sequence) and 'strand' ('+' or '-')."""
+  fragments = {}
+  nodes = [tree]
+  while nodes:
+    node = nodes.pop()
+    child1 = node.get('child1')
+    if child1:
+      nodes.append(child1)
+    else:
+      fragments[node['id']] = {'seq':node['seq'], 'strand':node['strand']}
+    child2 = node.get('child2')
+    if child2:
+      nodes.append(child2)
+  return fragments
+
+
+def add_mutation_lists(subtree, fragments, mut_list1):
+  """Compile the list of mutations that each fragment has undergone in PCR.
+  To call from the root, give [] as "mut_list1" and a dict mapping all existing node id's to a dict
+  as "fragments". Instead of returning the data, this will add a 'mutations' key to the dict for
+  each fragment, mapping it to a list of PCR mutations that occurred in the lineage of the fragment,
+  in chronological order."""
+  node = subtree
+  while node:
+    mut_list1.extend(node.get('errors', ()))
+    if 'child1' not in node:
+      fragments[node['id']]['mutations'] = mut_list1
+    if 'child2' in node:
+      mut_list2 = copy.deepcopy(mut_list1)
+      add_mutation_lists(node.get('child2'), fragments, mut_list2)
+    node = node.get('child1')
+
+
+def check_tree_balance(subtree):
+  """Find all points in the tree where the cycles of sibling nodes is unequal, and
+  return the maximum difference."""
+  node = subtree
+  if node:
+    child1 = node.get('child1')
+    child2 = node.get('child2')
+    if child1 and child2:
+      diff = abs(child1['cycle'] - child2['cycle'])
+    else:
+      diff = 0
+    diff_child1 = check_tree_balance(child1)
+    diff_child2 = check_tree_balance(child2)
+    return max(diff, diff_child1, diff_child2)
+  else:
+    return 0
+
+
+def get_good_pcr_tree(n_reads, n_cycles, max_tries, max_diff=1):
+  """Return a single, balanced PCR tree from build_pcr_tree(), or fail if one cannot
+  be found in max_tries.
+  Compensate for bugs in build_pcr_tree() that sometimes result in multiple trees,
+  or trees with siblings from different cycles."""
+  tries = 0
+  while tries <= max_tries:
+    trees = build_pcr_tree(n_reads, n_cycles)
+    if len(trees) == 1 and check_tree_balance(trees[0]) <= max_diff:
+      return trees[0]
+    tries += 1
+  raise AssertionError('Could not generate a single, balanced tree! (tried {} times)'
+                       .format(max_tries))
+
+
+def build_pcr_tree(n_reads, n_cycles):
+  """Create a simulated descent lineage of how all the final PCR fragments are related.
+  Each node represents a fragment molecule at one stage of PCR. Each node is a dict containing the
+  fragment's children (other nodes) ('child1' and 'child2'), the PCR cycle number ('cycle'), and,
+  at the leaves, a unique id number for each final fragment.
+  Returns a list of root nodes. Usually there will only be one, but about 1-3% of the time it fails
+  to unify the subtrees and results in a broken tree.
+  """
+  #TODO: Make it always return a single tree.
+  # Begin a branch for each of the fragments. These are the leaf nodes. We'll work backward from
+  # these, simulating the points at which they share ancestors, eventually coalescing into the
+  # single (root) ancestor.
+  branches = []
+  for frag_id in range(n_reads):
+    branches.append({'cycle':n_cycles-1, 'id':frag_id})
+  # Build up all the branches in parallel. Start from the second-to-last PCR cycle.
+  for cycle in reversed(range(n_cycles-1)):
+    # Probability of 2 fragments sharing an ancestor at cycle c is 1/2^c.
+    prob = 1/2**cycle
+    frag_i = 0
+    while frag_i < len(branches):
+      current_root = branches[frag_i]
+      # Does the current fragment share this ancestor with any of the other fragments?
+      # numpy.random.binomial() is a fast way to simulate going through every other fragment and
+      # asking if random.random() < prob.
+      shared = numpy.random.binomial(len(branches)-1, prob)
+      if shared == 0:
+        # No branch point here. Just add another level to the lineage.
+        branches[frag_i] = {'cycle':cycle, 'child1':current_root}
+      else:
+        # Pick a random other fragment to share this ancestor with.
+        # Make a list of candidates to pick from.
+        candidates = []
+        for candidate_i, candidate in enumerate(branches):
+          # Don't include ourselves.
+          if candidate is current_root:
+            continue
+          # If it's at a cycle above us and it already has a child, skip it.
+          if candidate['cycle'] == cycle and candidate.get('child2'):
+            continue
+          candidates.append(candidate_i)
+        if candidates:
+          relative_i = random.choice(candidates)
+          relative = branches[relative_i]
+          # Have we already passed this fragmentfragment on this cycle?
+          if relative['cycle'] == cycle:
+            # If we've already passed it, we're looking at the fragment's parent. We want the child.
+            relative = relative['child1']
+          # Join the lineages of our current fragment and the relative to a new parent.
+          #TODO: Sometimes, we end up matching up subtrees of different depths. But the discrepancy
+          #      is rarely greater than 1. Figure out why.
+          # assert abs(current_root['cycle'] - relative['cycle']) < 3, ('cycle: {}, current_root: {},'
+          #   ' relative: {}, frag_i: {}, relative_i: {}, branches: {}, candidates: {}, shared: {}'
+          #   .format(cycle, current_root['cycle'], relative['cycle'], frag_i, relative_i,
+          #           len(branches), len(candidates), shared))
+          branches[frag_i] = {'cycle':cycle, 'child1':current_root, 'child2':relative}
+          # Remove the relative from the list of lineages.
+          del(branches[relative_i])
+          if relative_i < frag_i:
+            frag_i -= 1
+      frag_i += 1
+  return branches
+
+
+def get_depth(tree):
+  depth = 0
+  node = tree
+  while node:
+    depth += 1
+    node = node.get('child1')
+  return depth
+
+
+def convert_tree(tree_orig):
+  # Let's operate on a copy only.
+  tree = copy.deepcopy(tree_orig)
+  # Turn the tree vertical.
+  tree['line'] = 1
+  tree['children'] = 0
+  levels = [[tree]]
+  done = False
+  while not done:
+    last_level = levels[-1]
+    this_level = []
+    done = True
+    for node in last_level:
+      for child_name in ('child1', 'child2'):
+        child = node.get(child_name)
+        if child:
+          done = False
+          child['parent'] = node
+          child['branch'] = child['parent']['branch']
+          if child_name == 'child2':
+            child['branch'] += 1
+          this_level.append(child)
+    this_level.sort(key=lambda node: node['branch'])
+    levels.append(this_level)
+  return levels
+
+
+def print_levels(levels):
+  last_level = []
+  for level in levels:
+    for node in level:
+      child = 1
+      for parent in last_level:
+        if parent.get('child2') is node:
+          child = 2
+      if child == 1:
+        sys.stdout.write('| ')
+      else:
+        sys.stdout.write('\ ')
+    last_level = level
+    print()
+
+
+def label_branches(tree):
+  """Label each vertical branch (line of 'child1's) with an id number."""
+  counter = 1
+  tree['branch'] = counter
+  nodes = [tree]
+  while nodes:
+    node = nodes.pop(0)
+    child1 = node.get('child1')
+    if child1:
+      child1['branch'] = node['branch']
+      nodes.append(child1)
+    child2 = node.get('child2')
+    if child2:
+      counter += 1
+      child2['branch'] = counter
+      nodes.append(child2)
+
+
+def print_tree(tree_orig):
+  # We "write" strings to an output buffer instead of directly printing, so we can post-process the
+  # output. The buffer is a matrix of cells, each holding a string representing one element.
+  lines = [[]]
+  # Let's operate on a copy only.
+  tree = copy.deepcopy(tree_orig)
+  # Add some bookkeeping data.
+  label_branches(tree)
+  tree['level'] = 0
+  branches = [tree]
+  while branches:
+    line = lines[-1]
+    branch = branches.pop()
+    level = branch['level']
+    while level > 0:
+      line.append('  ')
+      level -= 1
+    node = branch
+    while node:
+      # Is it the root node? (Have we written anything yet?)
+      if lines[0]:
+        # Are we at the start of the line? (Is it only spaces so far?)
+        if line[-1] == '  ':
+          line.append('\-')
+        elif line[-1].endswith('-'):
+          line.append('=-')
+      else:
+        line.append('*-')
+      child2 = node.get('child2')
+      if child2:
+        child2['level'] = node['level'] + 1
+        branches.append(child2)
+      parent = node
+      node = node.get('child1')
+      if node:
+        node['level'] = parent['level'] + 1
+      else:
+        line.append(' {}'.format(parent['branch']))
+        lines.append([])
+  # Post-process output: Add lines connecting branches to parents.
+  x = 0
+  done = False
+  while not done:
+    # Draw vertical lines upward from branch points.
+    drawing = False
+    for line in reversed(lines):
+      done = True
+      if x < len(line):
+        done = False
+        cell = line[x]
+        if cell == '\-':
+          drawing = True
+        elif cell == '  ' and drawing:
+          line[x] = '| '
+        elif cell == '=-' and drawing:
+          drawing = False
+    x += 1
+  # Print the final output.
+  for line in lines:
+    print(''.join(line))
+
+
+def fail(message):
+  sys.stderr.write(message+"\n")
+  sys.exit(1)
+
+if __name__ == '__main__':
+  sys.exit(main(sys.argv))