comparison utils/sim.py @ 18:e4d75f9efb90 draft

planemo upload commit b'4303231da9e48b2719b4429a29b72421d24310f4\n'-dirty
author nick
date Thu, 02 Feb 2017 18:44:31 -0500
parents
children
comparison
equal deleted inserted replaced
17:836fa4fe9494 18:e4d75f9efb90
1 #!/usr/bin/env python
2 from __future__ import division
3 from __future__ import print_function
4 import re
5 import os
6 import sys
7 import copy
8 import numpy
9 import bisect
10 import random
11 import string
12 import numbers
13 import tempfile
14 import argparse
15 import subprocess
16 import fastqreader
17
# Complement table for get_revcomp(), covering the IUPAC nucleotide codes in both cases. Characters
# not listed (e.g. 'N') are left unchanged by str.translate().
# str.maketrans() only exists on Python 3 and string.maketrans() only on Python 2; use whichever this
# interpreter provides so the module imports under both.
try:
    _maketrans = str.maketrans
except AttributeError:
    _maketrans = string.maketrans
REVCOMP_TABLE = _maketrans('acgtrymkbdhvACGTRYMKBDHV', 'tgcayrkmvhdbTGCAYRKMVHDB')
# Read names written by wgsim. Groups (as unpacked by parse_read_id()): reference name, start, stop,
# and the hexadecimal read id. The trailing /1 or /2 mate suffix is matched but not captured.
WGSIM_ID_REGEX = r'^(.+)_(\d+)_(\d+)_\d+:\d+:\d+_\d+:\d+:\d+_([0-9a-f]+)/[12]$'
ARG_DEFAULTS = {'read_len':100, 'frag_len':400, 'n_frags':1000, 'out_format':'fasta',
                'seq_error':0.001, 'pcr_error':0.001, 'cycles':25, 'indel_rate':0.15,
                'ext_rate':0.3, 'seed':None, 'invariant':'TGACT', 'bar_len':12, 'fastq_qual':'I'}
USAGE = "%(prog)s [options]"
DESCRIPTION = """Simulate a duplex sequencing experiment."""

# Relative weights of PCR family sizes, indexed by family size (index 0 = size 0). main() extends
# this with extend_dist(), normalizes it with compile_dist(), and picks a size with bisect.
RAW_DISTRIBUTION = (
    #  0   1   2   3   4   5   6   7   8   9
    # Low singletons, but then constant drop-off. From pML113 (see 2015-09-28 report).
    #  0, 100, 36, 31, 27, 22, 17, 12, 7, 4.3,
    #  2.4, 1.2, 0.6, 0.3, 0.2, 0.15, 0.1, 0.07, 0.05, 0.03,
    # High singletons, but then a second peak around 10. From Christine plasmid (2015-10-06 report).
    #  0, 100, 5.24, 3.67, 3.50, 3.67, 3.85, 4.02, 4.11, 4.20,
    #  4.17, 4.10, 4.00, 3.85, 3.69, 3.55, 3.38, 3.15, 2.92, 2.62,
    #  2.27, 2.01, 1.74, 1.56, 1.38, 1.20, 1.02, 0.85,
    # Same as above, but low singletons, 2's, and 3's (rely on errors to fill out those).
    0, 1, 2, 3, 3.50, 3.67, 3.85, 4.02, 4.11, 4.20,
    4.17, 4.10, 4.00, 3.85, 3.69, 3.55, 3.38, 3.15, 2.92, 2.62,
    2.27, 2.01, 1.74, 1.56, 1.38, 1.20, 1.02, 0.85,
)
40
41
def _parse_fastq_qual(value):
    """Convert a --fastq-qual argument into the single quality character to write.

    A single character is used as-is (so a lone digit like "5" is the character '5'). An integer, or a
    string of two or more digits (e.g. "30", or "05" for PHRED 5), is a PHRED score converted with the
    Sanger offset (33). Raises ValueError for anything else or for a score outside 0-93.
    """
    if isinstance(value, numbers.Integral):
        phred = value
    elif len(value) == 1:
        return value
    elif value.isdigit():
        phred = int(value)
    else:
        raise ValueError('--fastq-qual must be a positive integer or single character.')
    # 93 + 33 = 126 ('~'), the highest printable Sanger quality character.
    if not 0 <= phred <= 93:
        raise ValueError('--fastq-qual PHRED score must be between 0 and 93.')
    return chr(phred + 33)


def main(argv):
    """Run the duplex sequencing simulation. argv[0] is the program name; the rest are options.

    Writes mate 1 and mate 2 reads to the two output files (or both to stdout with --stdout), and
    optionally logs barcodes (--barcodes) and introduced errors (--mutations).
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.set_defaults(**ARG_DEFAULTS)

    parser.add_argument('ref', metavar='ref.fa', nargs='?',
        help='Reference sequence. Omit if giving --frag-file.')
    parser.add_argument('out1', type=argparse.FileType('w'),
        help='Write final mate 1 reads to this file.')
    parser.add_argument('out2', type=argparse.FileType('w'),
        help='Write final mate 2 reads to this file.')
    parser.add_argument('-o', '--out-format', choices=('fastq', 'fasta'))
    parser.add_argument('--stdout', action='store_true',
        help='Print interleaved output reads to stdout.')
    parser.add_argument('-m', '--mutations', type=argparse.FileType('w'),
        help='Write a log of the PCR and sequencing errors introduced to this file. Will overwrite any '
             'existing file at this path.')
    parser.add_argument('-b', '--barcodes', type=argparse.FileType('w'),
        help='Write a log of which barcodes were ligated to which fragments. Will overwrite any '
             'existing file at this path.')
    parser.add_argument('--frag-file',
        help='The path of the FASTQ file of fragments. If --ref is given, these will be generated with '
             'wgsim and kept (normally a temporary file is used, then deleted). Note: the file will be '
             'overwritten! If --ref is not given, then this should be a file of already generated '
             'fragments, and they will be used instead of generating new ones.')
    parser.add_argument('-Q', '--fastq-qual',
        help='The quality score to assign to all bases in FASTQ output. Give a single character or a '
             'PHRED score (integer; zero-pad single-digit scores, e.g. "05"). A PHRED score will be '
             'converted using the Sanger offset (33). Default: "%(default)s"')
    parser.add_argument('-S', '--seed', type=int,
        help='Random number generator seed. By default, a random, 32-bit seed will be generated and '
             'logged to stderr.')
    params = parser.add_argument_group('simulation parameters')
    params.add_argument('-n', '--n-frags', type=int,
        help='The number of original fragment molecules to simulate. The final number of reads will be '
             'this multiplied by the average number of reads per family. If you provide fragments with '
             '--frag-file, the script will still only read in the number specified here. Default: '
             '%(default)s')
    params.add_argument('-r', '--read-len', type=int,
        help='Default: %(default)s')
    params.add_argument('-f', '--frag-len', type=int,
        help='Default: %(default)s')
    params.add_argument('-s', '--seq-error', type=float,
        help='Sequencing error rate per base (0-1 proportion, not percent). Default: %(default)s')
    params.add_argument('-p', '--pcr-error', type=float,
        help='PCR error rate per base (0-1 proportion, not percent). Default: %(default)s')
    params.add_argument('-c', '--cycles', type=int,
        help='Number of PCR cycles to simulate. Default: %(default)s')
    params.add_argument('-i', '--indel-rate', type=float,
        help='Fraction of errors which are indels. Default: %(default)s')
    params.add_argument('-E', '--extension-rate', dest='ext_rate', type=float,
        help='Probability an indel is extended. Default: %(default)s')
    params.add_argument('-B', '--bar-len', type=int,
        help='Length of the barcodes to generate. Default: %(default)s')
    params.add_argument('-I', '--invariant',
        help='The invariant linker sequence between the barcode and sample sequence in each read. '
             'Default: %(default)s')

    # Parse and interpret arguments. parser.error() prints usage and exits with status 2.
    args = parser.parse_args(argv[1:])
    if not (args.ref or args.frag_file):
        parser.error('You must provide either a reference or fragments file.')
    try:
        fastq_qual = _parse_fastq_qual(args.fastq_qual)
    except ValueError as error:
        parser.error(str(error))
    if args.seed is None:
        seed = random.randint(0, 2**31-1)
        sys.stderr.write('seed: {}\n'.format(seed))
    else:
        seed = args.seed
    random.seed(seed)
    # build_pcr_tree() draws from numpy's generator, which random.seed() does not touch. Seed it too
    # so a given --seed reproduces the whole run. numpy only accepts seeds in [0, 2**32).
    numpy.random.seed(seed % 2**32)
    if args.stdout:
        out1 = out2 = sys.stdout
    else:
        out1 = args.out1
        out2 = args.out2

    invariant_rc = get_revcomp(args.invariant)

    # Reserve a unique temporary path for the wgsim fragments. NamedTemporaryFile() deletes the file
    # when closed, so only its name is kept. Work inside a try so the path is removed in the finally
    # clause no matter what exceptions are encountered.
    tmpfile = tempfile.NamedTemporaryFile(prefix='wgdsim.frags.')
    tmpfile.close()
    try:
        # Step 1: Use wgsim to create fragments from the reference.
        if args.frag_file:
            frag_file = args.frag_file
        else:
            frag_file = tmpfile.name
        if args.ref and os.path.isfile(args.ref) and os.path.getsize(args.ref):
            # Set error and mutation rates to 0 to just slice sequences out of the reference without
            # modification.
            exit_status = run_command('wgsim', '-e', '0', '-r', '0', '-d', '0', '-R', args.indel_rate,
                                      '-S', seed, '-N', args.n_frags, '-X', args.ext_rate,
                                      '-1', args.frag_len, args.ref, frag_file, os.devnull)
            # run_command() returns None when wgsim could not be launched at all.
            if exit_status != 0:
                fail('Error: wgsim failed (exit status: {}). Is it installed and on your PATH?'
                     .format(exit_status))

        # NOTE: Coordinates here are 0-based (0 is the first base in the sequence).
        extended_dist = extend_dist(RAW_DISTRIBUTION)
        proportional_dist = compile_dist(extended_dist)
        for frag_num, raw_fragment in enumerate(fastqreader.FastqReadGenerator(frag_file)):
            # Only use the first --n-frags fragments, even if the file holds more.
            if frag_num >= args.n_frags:
                break
            chrom, id_num, start, stop = parse_read_id(raw_fragment.id)
            barcode1 = get_rand_seq(args.bar_len)
            barcode2 = get_rand_seq(args.bar_len)
            barcode2_rc = get_revcomp(barcode2)
            raw_frag_full = barcode1 + args.invariant + raw_fragment.seq + invariant_rc + barcode2

            # Step 2: Determine how many reads to produce from each fragment.
            # - Use random.random() and divide the range 0-1 into segments of sizes proportional to
            #   the likelihood of each family size.
            # bisect.bisect() finds where an element belongs in a sorted list, returning the index.
            # proportional_dist is just such a sorted list, with values from 0 to 1.
            n_reads = bisect.bisect(proportional_dist, random.random())

            # Step 3: Introduce PCR errors.
            # - Determine the mutations and their frequencies.
            #   - Could get frequency from the cycle of PCR it occurs in.
            #   - Important to have PCR errors shared between reads.
            # - For each read, determine which mutations it contains.
            #   - Use random.random() < mut_freq.
            tree = get_good_pcr_tree(n_reads, args.cycles, 1000, max_diff=1)
            # Add errors to all children of original fragment.
            subtree1 = tree.get('child1')
            subtree2 = tree.get('child2')
            #TODO: Only simulate errors on portions of fragment that will become reads.
            add_pcr_errors(subtree1, '+', len(raw_frag_full), args.pcr_error, args.indel_rate,
                           args.ext_rate)
            add_pcr_errors(subtree2, '-', len(raw_frag_full), args.pcr_error, args.indel_rate,
                           args.ext_rate)
            apply_pcr_errors(tree, raw_frag_full)
            fragments = get_final_fragments(tree)
            add_mutation_lists(tree, fragments, [])

            # Step 4: Introduce sequencing errors.
            # NOTE(review): error coordinates only span the first read_len bases of each fragment, so
            # the read taken from the end of the fragment gets no sequencing errors unless the two read
            # regions overlap. Confirm whether that is intended.
            for fragment in fragments.values():
                for mutation in generate_mutations(args.read_len, args.seq_error, args.indel_rate,
                                                   args.ext_rate):
                    fragment['mutations'].append(mutation)
                    fragment['seq'] = apply_mutation(mutation, fragment['seq'])

            # Print barcodes to log file.
            if args.barcodes:
                args.barcodes.write('{}-{}\t{}\t{}\n'.format(chrom, id_num, barcode1, barcode2_rc))
            # Print family.
            for frag_id in sorted(fragments):
                fragment = fragments[frag_id]
                read_id = '{}-{}-{}'.format(chrom, id_num, frag_id)
                # Print mutations to log file.
                if args.mutations:
                    read1_muts = get_mutations_subset(fragment['mutations'], 0, args.read_len)
                    read2_muts = get_mutations_subset(fragment['mutations'], 0, args.read_len,
                                                      revcomp=True, seqlen=len(fragment['seq']))
                    if fragment['strand'] == '-':
                        read1_muts, read2_muts = read2_muts, read1_muts
                    log_mutations(args.mutations, read1_muts, read_id+'/1', chrom, start, stop)
                    log_mutations(args.mutations, read2_muts, read_id+'/2', chrom, start, stop)
                frag_seq = fragment['seq']
                read1_seq = frag_seq[:args.read_len]
                read2_seq = get_revcomp(frag_seq[len(frag_seq)-args.read_len:])
                if fragment['strand'] == '-':
                    read1_seq, read2_seq = read2_seq, read1_seq
                if args.out_format == 'fasta':
                    out1.write('>{}\n{}\n'.format(read_id, read1_seq))
                    out2.write('>{}\n{}\n'.format(read_id, read2_seq))
                elif args.out_format == 'fastq':
                    # Size each quality line to its own read so every record stays valid even when a
                    # fragment is shorter than --read-len.
                    out1.write('@{}\n{}\n+\n{}\n'.format(read_id, read1_seq,
                                                         fastq_qual * len(read1_seq)))
                    out2.write('@{}\n{}\n+\n{}\n'.format(read_id, read2_seq,
                                                         fastq_qual * len(read2_seq)))

    finally:
        # Best-effort cleanup: the path may never have been created (e.g. with --frag-file).
        try:
            os.remove(tmpfile.name)
        except OSError:
            pass
224
225
def run_command(*command, **kwargs):
    """Execute a command and return its exit status.

    Example: run_command('echo', 'hello')
    Every argument is converted with str() first. Unless the keyword argument silent=True is given,
    the command line is echoed to stderr (prefixed with "$ ") before it runs. The command's own stderr
    is discarded. Returns None if the command could not be launched (OSError).
    """
    arg_strings = [str(part) for part in command]
    if not kwargs.get('silent'):
        sys.stderr.write('$ ' + ' '.join(arg_strings) + '\n')
    with open(os.devnull, 'w') as sink:
        try:
            return subprocess.call(arg_strings, stderr=sink)
        except OSError:
            return None
241
242
def extend_dist(raw_dist, exponent=1.25, min_prob=0.00001, max_len_mult=2):
    """Return a copy of raw_dist with an exponentially decreasing tail appended.

    Starting from the last value, each new tail value is the previous one divided by "exponent".
    Extension stops once a new value would be less than "min_prob" of the running total, or once the
    list reaches "max_len_mult" times the original length.
    """
    extended = list(raw_dist)
    length_cap = len(raw_dist) * max_len_mult
    running_total = sum(raw_dist)
    candidate = raw_dist[-1] / exponent
    while candidate / running_total >= min_prob:
        if not len(extended) < length_cap:
            break
        extended.append(candidate)
        running_total += candidate
        candidate /= exponent
    return extended
258
259
def compile_dist(raw_dist):
    """Convert the human-readable weights defined at the top into cumulative proportions.

    Each output element is the running sum of the weights up to that point divided by the total, so
    the list is non-decreasing and ends at 1.0.
    E.g. [10, 5, 5] -> [0.5, 0.75, 1.0]
    """
    total = sum(raw_dist)

    def running_sums():
        acc = 0
        for weight in raw_dist:
            acc += weight
            yield acc

    return [partial / total for partial in running_sums()]
271
272
def parse_read_id(read_id):
    """Split a wgsim read name into (chrom, id_num, start, stop).

    All parts are returned as strings. A name that doesn't match WGSIM_ID_REGEX is returned whole as
    "chrom", with the other three values set to None.
    """
    match = re.search(WGSIM_ID_REGEX, read_id)
    if not match:
        return read_id, None, None, None
    chrom, start, stop, id_num = match.groups()
    return chrom, id_num, start, stop
283
284
#TODO: Clean up "mutation" vs "error" terminology.
def generate_mutations(seq_len, error_rate, indel_rate, extension_rate):
    """Generate all the mutations that occur over the length of a sequence.

    Yields dicts {'coord': int, 'type': 'snv'|'ins'|'del', 'alt': ...} as built by make_mutation().
    Each coordinate refers to the sequence as it stands *after* all previously yielded mutations have
    been applied, so the mutations must be applied in the order they are yielded.
    The walk covers positions 0..seq_len of that progressively mutated sequence, rolling each position
    once against error_rate. Bases added by an insertion are skipped, and position seq_len (just past
    the last base) can only receive an insertion.
    """
    i = 0
    while i <= seq_len:
        if random.random() < error_rate:
            mtype, alt = make_mutation(indel_rate, extension_rate)
            # Allow mutation after the last base only if it's an insertion.
            if i < seq_len or mtype == 'ins':
                yield {'coord':i, 'type':mtype, 'alt':alt}
                # Only a mutation that was actually yielded changes the sequence, so only then does i
                # need adjusting.
                if mtype == 'ins':
                    # The base that was at i now sits after the inserted bases; skip past them all.
                    i += len(alt)
                elif mtype == 'del':
                    # The first base after the deleted run now sits at i and hasn't been rolled yet, so
                    # stay on i (this cancels the increment below), whatever the deletion's length.
                    i -= 1
        i += 1
301
302
def make_mutation(indel_rate, extension_rate):
    """Simulate one random mutation and return it as (mtype, alt).

    With probability indel_rate it is an indel, split evenly between deletions and insertions;
    otherwise it is an SNV. For 'snv', alt is the new base. For 'ins', alt is the inserted bases. For
    'del', alt is the number of bases deleted. An indel keeps growing by one base while
    random.random() < extension_rate.
    """
    roll = random.random()
    if not roll < indel_rate:
        # SNV: just pick the new base.
        return 'snv', get_rand_base()
    # Re-use the same roll to choose the indel kind: the lower half of [0, indel_rate) is a deletion,
    # the upper half an insertion.
    is_deletion = roll < indel_rate/2
    alt = 1 if is_deletion else get_rand_base()
    while random.random() < extension_rate:
        alt += 1 if is_deletion else get_rand_base()
    mtype = 'del' if is_deletion else 'ins'
    return mtype, alt
328
329
def get_rand_base(bases='ACGT'):
    """Return one character picked uniformly at random (via random.choice) from "bases"."""
    picked = random.choice(bases)
    return picked
332
333
def get_rand_seq(seq_len):
    """Return a string of seq_len bases, each drawn independently with get_rand_base()."""
    bases = []
    for _ in range(seq_len):
        bases.append(get_rand_base())
    return ''.join(bases)
336
337
def get_revcomp(seq):
    """Return the reverse complement of seq.

    Characters listed in REVCOMP_TABLE are complemented; any other character is kept as-is.
    """
    complement = seq.translate(REVCOMP_TABLE)
    return complement[::-1]
340
341
def apply_mutation(mut, seq):
    """Return seq with one mutation dict ({'coord', 'type', 'alt'}) applied.

    An 'snv' replaces the base at coord with alt. Indels act *before* the base at coord: 'ins' inserts
    the alt bases there, and any other type deletes alt bases starting at coord. This goes against the
    VCF convention, but it allows deleting the first and last base, as well as inserting before and
    after the sequence, without as much special-casing.
    Examples: 'ACGTACGT' + ins 'GC' at 4 = 'ACGTGCACGT'; 'ACGTACGT' + del 2 at 4 = 'ACGTGT'.
    """
    coord = mut['coord']
    mtype = mut['type']
    head = seq[:coord]
    if mtype == 'snv':
        return head + mut['alt'] + seq[coord+1:]
    if mtype == 'ins':
        return head + mut['alt'] + seq[coord:]
    return head + seq[coord+mut['alt']:]
358
359
def get_mutations_subset(mutations_old, start, length, revcomp=False, seqlen=None):
    """Return the mutations from mutations_old that fall inside a region, keeping their input order.

    "start" is the 0-based start of the region and "length" its length. A mutation is kept when
    start <= coord < start+length, or when it is an insertion exactly at start+length.
    "revcomp" makes each mutation be converted with get_mutation_revcomp() before the test. The output
    then holds the converted mutations, and "start" is in the reverse complement's coordinates. The
    order is still the input order.
    "seqlen" is the length of the sequence the mutations occurred in. It is only needed when revcomp
    is True, for the coordinate conversion.
    """
    stop = start + length

    def in_region(mut):
        coord = mut['coord']
        # Insertions at the last coordinate are allowed too.
        return start <= coord < stop or (coord == stop and mut['type'] == 'ins')

    if revcomp:
        candidates = (get_mutation_revcomp(mut, seqlen) for mut in mutations_old)
    else:
        candidates = iter(mutations_old)
    return [mut for mut in candidates if in_region(mut)]
384
385
def get_mutation_revcomp(mut, seqlen):
    """Return a new mutation dict equivalent to "mut" on the reverse complement strand.

    "seqlen" is the length of the sequence the mutation applies to; it is needed to re-express the
    coordinate from the other end of the sequence. Alt bases of SNVs and insertions are reverse
    complemented, and a deletion keeps its length. For an unrecognized type only {'type': ...} is
    returned.
    """
    mtype = mut['type']
    if mtype == 'del':
        # The run coord..coord+alt-1 starts at seqlen-coord-alt on the other strand.
        return {'type': mtype, 'coord': seqlen - mut['coord'] - mut['alt'], 'alt': mut['alt']}
    if mtype == 'snv':
        new_coord = seqlen - mut['coord'] - 1
    elif mtype == 'ins':
        # An insertion sits between two bases, so there is no "- 1" as for an SNV.
        new_coord = seqlen - mut['coord']
    else:
        return {'type': mtype}
    return {'type': mtype, 'coord': new_coord, 'alt': get_revcomp(mut['alt'])}
402
403
def log_mutations(mutfile, mutations, read_id, chrom, start, stop):
    """Write one tab-separated line per mutation to mutfile.

    Columns: read_id, chrom, start, stop, then the mutation's coord, type, and alt.
    """
    template = '{read_id}\t{chrom}\t{start}\t{stop}\t{coord}\t{type}\t{alt}\n'
    for mut in mutations:
        line = template.format(read_id=read_id, chrom=chrom, start=start, stop=stop, **mut)
        mutfile.write(line)
408
409
def add_pcr_errors(subtree, strand, read_len, error_rate, indel_rate, extension_rate):
    """Add simulated PCR errors to a node in a tree and all its descendants.

    Every visited node gets 'strand' set to "strand" and 'errors' set to a list of mutations from
    generate_mutations(read_len, ...). The errors are meant as "errors made in creating this
    molecule", so don't apply this to the root node, which stands for the original, unaltered
    molecule.
    """
    # An explicit stack replaces recursion. child1 is pushed before child2, so each node's whole
    # child2 subtree is finished before its child1 is visited. That keeps the visiting order, and
    # therefore the order of random draws, identical to walking the child1 chain while recursing into
    # each child2.
    pending = [subtree]
    while pending:
        node = pending.pop()
        if not node:
            continue
        node['strand'] = strand
        node['errors'] = list(generate_mutations(read_len, error_rate, indel_rate, extension_rate))
        pending.append(node.get('child1'))
        pending.append(node.get('child2'))
423
424
def apply_pcr_errors(subtree, seq):
    """Propagate "seq" down the tree, applying each node's 'errors' in order.

    Each node's errors are applied on top of the sequence inherited from its parent. The resulting
    sequence is stored as 'seq' only on nodes without a 'child1' key (the leaves). Returns None; the
    tree is modified in place.
    """
    pending = [(subtree, seq)]
    while pending:
        node, current = pending.pop()
        if not node:
            continue
        for error in node.get('errors', ()):
            current = apply_mutation(error, current)
        if 'child1' not in node:
            node['seq'] = current
        pending.append((node.get('child1'), current))
        pending.append((node.get('child2'), current))
434
435
def get_final_fragments(tree):
    """Collect the post-PCR fragments found at the leaves of the tree.

    A leaf is a node without a (truthy) 'child1'. Returns a dict mapping each leaf's 'id' to a new dict
    with exactly two keys: 'seq' (the final sequence) and 'strand' ('+' or '-').
    """
    fragments = {}
    stack = [tree]
    while stack:
        node = stack.pop()
        first = node.get('child1')
        if not first:
            fragments[node['id']] = {'seq': node['seq'], 'strand': node['strand']}
        else:
            stack.append(first)
        second = node.get('child2')
        if second:
            stack.append(second)
    return fragments
453
454
def add_mutation_lists(subtree, fragments, mut_list1):
    """Record, for every leaf, the PCR errors its lineage accumulated, in chronological order.

    To call from the root, pass [] as "mut_list1" and, as "fragments", a dict mapping every leaf id to
    a dict. Nothing is returned. Instead each leaf's entry in "fragments" gets a 'mutations' list.
    "mut_list1" itself is extended in place and ends up as the list of the leaf reached by following
    child1 links. Each child2 branch works on a deep copy taken at the branch point.
    """
    pending = [(subtree, mut_list1)]
    while pending:
        node, lineage = pending.pop()
        if not node:
            continue
        lineage.extend(node.get('errors', ()))
        if 'child1' not in node:
            fragments[node['id']]['mutations'] = lineage
        # Push child1 first so the child2 branch is handled next, as in a recursive walk. Its copy is
        # taken now, before child1's errors are added to "lineage".
        pending.append((node.get('child1'), lineage))
        if 'child2' in node:
            pending.append((node.get('child2'), copy.deepcopy(lineage)))
470
471
def check_tree_balance(subtree):
    """Return the largest 'cycle' difference between any pair of sibling nodes in the tree.

    Only nodes that have both a child1 and a child2 contribute. Returns 0 for an empty tree or a tree
    without such nodes.
    """
    worst = 0
    stack = [subtree]
    while stack:
        node = stack.pop()
        if not node:
            continue
        first = node.get('child1')
        second = node.get('child2')
        if first and second:
            worst = max(worst, abs(first['cycle'] - second['cycle']))
        stack.extend((first, second))
    return worst
488
489
def get_good_pcr_tree(n_reads, n_cycles, max_tries, max_diff=1):
    """Return a single, balanced PCR tree from build_pcr_tree(), or fail after max_tries attempts.

    Compensates for bugs in build_pcr_tree() that sometimes result in multiple trees, or in trees
    with siblings from different cycles. A tree is accepted when build_pcr_tree() returns exactly one
    root and check_tree_balance() reports a sibling cycle difference of at most max_diff.
    Raises AssertionError if no acceptable tree is found in max_tries attempts.
    """
    # Exactly max_tries attempts. The old `while tries <= max_tries` loop made one more attempt than
    # the error message reported.
    for _ in range(max_tries):
        trees = build_pcr_tree(n_reads, n_cycles)
        if len(trees) == 1 and check_tree_balance(trees[0]) <= max_diff:
            return trees[0]
    raise AssertionError('Could not generate a single, balanced tree! (tried {} times)'
                         .format(max_tries))
503
504
def build_pcr_tree(n_reads, n_cycles):
    """Create a simulated descent lineage of how all the final PCR fragments are related.
    Each node represents a fragment molecule at one stage of PCR. Each node is a dict containing the
    fragment's children (other nodes) ('child1' and 'child2'), the PCR cycle number ('cycle'), and,
    at the leaves, a unique id number for each final fragment.
    Returns a list of root nodes. Usually there will only be one, but about 1-3% of the time it fails
    to unify the subtrees and results in a broken tree.
    Randomness comes from both numpy.random (the binomial draws) and the random module (the choice of
    relative), so reproducing a tree requires seeding both generators.
    """
    #TODO: Make it always return a single tree.
    # Begin a branch for each of the fragments. These are the leaf nodes. We'll work backward from
    # these, simulating the points at which they share ancestors, eventually coalescing into the
    # single (root) ancestor.
    branches = []
    for frag_id in range(n_reads):
        branches.append({'cycle':n_cycles-1, 'id':frag_id})
    # Build up all the branches in parallel. Start from the second-to-last PCR cycle.
    # NOTE(review): a node created in the current cycle that already has a 'child2' is excluded as a
    # candidate below, so each lineage can merge at most once per cycle. If more than two lineages
    # remain going into cycle 0, they cannot all coalesce. That would explain the multiple roots
    # mentioned in the docstring.
    for cycle in reversed(range(n_cycles-1)):
        # Probability of 2 fragments sharing an ancestor at cycle c is 1/2^c.
        # (True division: the file imports division from __future__.)
        prob = 1/2**cycle
        frag_i = 0
        while frag_i < len(branches):
            current_root = branches[frag_i]
            # Does the current fragment share this ancestor with any of the other fragments?
            # numpy.random.binomial() is a fast way to simulate going through every other fragment and
            # asking if random.random() < prob.
            shared = numpy.random.binomial(len(branches)-1, prob)
            if shared == 0:
                # No branch point here. Just add another level to the lineage.
                branches[frag_i] = {'cycle':cycle, 'child1':current_root}
            else:
                # Pick a random other fragment to share this ancestor with.
                # Make a list of candidates to pick from.
                candidates = []
                for candidate_i, candidate in enumerate(branches):
                    # Don't include ourselves.
                    if candidate is current_root:
                        continue
                    # If it was already given a node for this cycle (it's at a cycle above us) and that
                    # node already joins two lineages, skip it.
                    if candidate['cycle'] == cycle and candidate.get('child2'):
                        continue
                    candidates.append(candidate_i)
                # NOTE(review): when there are no candidates, nothing below replaces branches[frag_i], so
                # this lineage gets no node for this cycle and stays one level shallower than the
                # others. This is likely one source of the depth mismatches in the TODO below.
                if candidates:
                    relative_i = random.choice(candidates)
                    relative = branches[relative_i]
                    # Have we already passed this fragment on this cycle?
                    if relative['cycle'] == cycle:
                        # If we've already passed it, we're looking at the fragment's parent. We want the child.
                        relative = relative['child1']
                    # Join the lineages of our current fragment and the relative to a new parent.
                    #TODO: Sometimes, we end up matching up subtrees of different depths. But the discrepancy
                    # is rarely greater than 1. Figure out why.
                    # assert abs(current_root['cycle'] - relative['cycle']) < 3, ('cycle: {}, current_root: {},'
                    #   ' relative: {}, frag_i: {}, relative_i: {}, branches: {}, candidates: {}, shared: {}'
                    #   .format(cycle, current_root['cycle'], relative['cycle'], frag_i, relative_i,
                    #           len(branches), len(candidates), shared))
                    branches[frag_i] = {'cycle':cycle, 'child1':current_root, 'child2':relative}
                    # Remove the relative from the list of lineages.
                    del(branches[relative_i])
                    # Keep frag_i pointing at the node we just created if the deletion shifted it left.
                    if relative_i < frag_i:
                        frag_i -= 1
            frag_i += 1
    return branches
567
568
def get_depth(tree):
    """Count the nodes on the chain that starts at "tree" and follows 'child1' links (0 if tree is falsy)."""
    length = 0
    current = tree
    while current:
        current, length = current.get('child1'), length + 1
    return length
576
577
def convert_tree(tree_orig):
    """Group the nodes of a PCR tree into levels (one list per generation) for display.

    Works on a deep copy; tree_orig is not modified. In the copy, every child gets a 'parent' key and
    a 'branch' number. A child1 inherits its parent's branch, and a child2 gets the parent's branch + 1.
    The root keeps its existing 'branch', or gets 1. Returns a list of levels starting with [root].
    Each level is sorted by branch, with a stable sort so ties keep discovery order. The final level
    is always empty, because the childless level is appended before the loop stops.
    """
    # Let's operate on a copy only.
    tree = copy.deepcopy(tree_orig)
    # Turn the tree vertical. ('line' and 'children' are not read anywhere in this function.)
    tree['line'] = 1
    tree['children'] = 0
    # Children inherit their parent's 'branch', so the root needs one. Without it, any tree with
    # children raised KeyError unless label_branches() had already been run on it.
    tree.setdefault('branch', 1)
    levels = [[tree]]
    done = False
    while not done:
        last_level = levels[-1]
        this_level = []
        done = True
        for node in last_level:
            for child_name in ('child1', 'child2'):
                child = node.get(child_name)
                if child:
                    done = False
                    child['parent'] = node
                    child['branch'] = node['branch']
                    if child_name == 'child2':
                        child['branch'] += 1
                    this_level.append(child)
        this_level.sort(key=lambda level_node: level_node['branch'])
        levels.append(this_level)
    return levels
603
604
def print_levels(levels):
    """Print one row per level to stdout.

    Each node is shown as '| ', or as a backslash followed by a space when it is the 'child2' of a node
    in the previous level. Every level, including an empty one, ends with a newline.
    """
    previous = []
    for level in levels:
        for node in level:
            is_second_child = any(parent.get('child2') is node for parent in previous)
            sys.stdout.write('\\ ' if is_second_child else '| ')
        previous = level
        print()
619
620
def label_branches(tree):
    """Label each vertical branch (line of 'child1's) with an id number, in place.

    The root's chain is branch 1. Every child2 starts a new branch, numbered 2, 3, ... in
    breadth-first order, and a child1 inherits its parent's number. Returns None.
    """
    next_label = 1
    tree['branch'] = next_label
    # Breadth-first walk with a read index instead of list.pop(0).
    queue = [tree]
    head = 0
    while head < len(queue):
        node = queue[head]
        head += 1
        first, second = node.get('child1'), node.get('child2')
        if first:
            first['branch'] = node['branch']
            queue.append(first)
        if second:
            next_label += 1
            second['branch'] = next_label
            queue.append(second)
637
638
def print_tree(tree_orig):
    """Print an ASCII drawing of a PCR tree to stdout, one line per chain of child1 links.

    Works on a deep copy (labeled with label_branches()), so tree_orig is not modified. Each line is
    indented by the depth of its first node and has one cell per node. The root is drawn as '*-', the
    first node of a child2 branch as a backslash followed by '-', and every following node as '=-'.
    The line ends with the chain's branch number. Vertical '| ' bars are then drawn upward from each
    branch start until they reach a '=-' cell.
    """
    # We "write" strings to an output buffer instead of directly printing, so we can post-process the
    # output. The buffer is a matrix of cells, each holding a string representing one element.
    lines = [[]]
    # Let's operate on a copy only.
    tree = copy.deepcopy(tree_orig)
    # Add some bookkeeping data.
    label_branches(tree)
    tree['level'] = 0
    branches = [tree]
    while branches:
        line = lines[-1]
        branch = branches.pop()
        # Indent the chain to the depth of its first node.
        level = branch['level']
        while level > 0:
            line.append(' ')
            level -= 1
        node = branch
        while node:
            # Is it the root node? (Have we written anything yet?)
            if lines[0]:
                # Are we at the start of the line? (Is it only spaces so far?)
                if line[-1] == ' ':
                    line.append('\\-')
                elif line[-1].endswith('-'):
                    line.append('=-')
            else:
                line.append('*-')
            child2 = node.get('child2')
            if child2:
                # Queue the child2 chain to be drawn on its own line.
                child2['level'] = node['level'] + 1
                branches.append(child2)
            parent = node
            node = node.get('child1')
            if node:
                node['level'] = parent['level'] + 1
            else:
                # End of the chain: label it and start a new output line.
                line.append(' {}'.format(parent['branch']))
                lines.append([])
    # Post-process output: Add lines connecting branches to parents, one column at a time.
    x = 0
    done = False
    while not done:
        # Draw vertical lines upward from branch points.
        drawing = False
        # Keep scanning columns while ANY line still reaches column x. (Resetting this flag for every
        # line meant only the first line's length decided when to stop.)
        done = True
        for line in reversed(lines):
            if x < len(line):
                done = False
                cell = line[x]
                if cell == '\\-':
                    drawing = True
                elif cell == ' ' and drawing:
                    line[x] = '| '
                elif cell == '=-' and drawing:
                    drawing = False
        x += 1
    # Print the final output.
    for line in lines:
        print(''.join(line))
699
700
def fail(message):
    """Write message plus a newline to stderr, then exit the process with status 1."""
    line = message + "\n"
    sys.stderr.write(line)
    sys.exit(1)


if __name__ == '__main__':
    sys.exit(main(sys.argv))