Mercurial > repos > nick > duplex
comparison utils/sim.py @ 18:e4d75f9efb90 draft
planemo upload commit b'4303231da9e48b2719b4429a29b72421d24310f4\n'-dirty
author | nick |
---|---|
date | Thu, 02 Feb 2017 18:44:31 -0500 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
17:836fa4fe9494 | 18:e4d75f9efb90 |
---|---|
1 #!/usr/bin/env python | |
2 from __future__ import division | |
3 from __future__ import print_function | |
4 import re | |
5 import os | |
6 import sys | |
7 import copy | |
8 import numpy | |
9 import bisect | |
10 import random | |
11 import string | |
12 import numbers | |
13 import tempfile | |
14 import argparse | |
15 import subprocess | |
16 import fastqreader | |
17 | |
# Complement table used by get_revcomp(); covers the IUPAC ambiguity codes in both cases.
# string.maketrans() only exists in Python 2; Python 3 provides str.maketrans() instead.
try:
    REVCOMP_TABLE = string.maketrans('acgtrymkbdhvACGTRYMKBDHV', 'tgcayrkmvhdbTGCAYRKMVHDB')
except AttributeError:
    REVCOMP_TABLE = str.maketrans('acgtrymkbdhvACGTRYMKBDHV', 'tgcayrkmvhdbTGCAYRKMVHDB')
# wgsim read names. parse_read_id() reads the four groups as (chrom, start, stop, id_num).
WGSIM_ID_REGEX = r'^(.+)_(\d+)_(\d+)_\d+:\d+:\d+_\d+:\d+:\d+_([0-9a-f]+)/[12]$'
ARG_DEFAULTS = {'read_len':100, 'frag_len':400, 'n_frags':1000, 'out_format':'fasta',
                'seq_error':0.001, 'pcr_error':0.001, 'cycles':25, 'indel_rate':0.15,
                'ext_rate':0.3, 'seed':None, 'invariant':'TGACT', 'bar_len':12, 'fastq_qual':'I'}
USAGE = "%(prog)s [options]"
DESCRIPTION = """Simulate a duplex sequencing experiment."""

# Relative weights of family sizes: entry k is the relative likelihood that a fragment yields k
# reads (main() samples it with bisect() over the cumulative distribution). main() first appends a
# decaying tail with extend_dist().
RAW_DISTRIBUTION = (
    # 0  1    2    3    4    5    6    7    8    9
    # Low singletons, but then constant drop-off. From pML113 (see 2015-09-28 report).
    #  0, 100,  36,  31,  27,  22,  17,  12,   7, 4.3,
    #2.4, 1.2, 0.6, 0.3, 0.2, 0.15, 0.1, 0.07, 0.05, 0.03,
    # High singletons, but then a second peak around 10. From Christine plasmid (2015-10-06 report).
    #     0,  100, 5.24, 3.67, 3.50, 3.67, 3.85, 4.02, 4.11, 4.20,
    #  4.17, 4.10, 4.00, 3.85, 3.69, 3.55, 3.38, 3.15, 2.92, 2.62,
    #  2.27, 2.01, 1.74, 1.56, 1.38, 1.20, 1.02, 0.85,
    # Same as above, but low singletons, 2's, and 3's (rely on errors to fill out those).
        0,    1,    2,    3, 3.50, 3.67, 3.85, 4.02, 4.11, 4.20,
     4.17, 4.10, 4.00, 3.85, 3.69, 3.55, 3.38, 3.15, 2.92, 2.62,
     2.27, 2.01, 1.74, 1.56, 1.38, 1.20, 1.02, 0.85,
)
40 | |
41 | |
def main(argv):
    """Simulate a duplex sequencing experiment as configured on the command line.

    "argv" is the full argument vector (argv[0] is the program name). Final mate 1/mate 2 reads are
    written to the out1/out2 files (or interleaved to stdout with --stdout). Barcodes and the
    introduced PCR/sequencing errors are optionally logged. Returns None."""

    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.set_defaults(**ARG_DEFAULTS)

    parser.add_argument('ref', metavar='ref.fa', nargs='?',
        help='Reference sequence. Omit if giving --frag-file.')
    parser.add_argument('out1', type=argparse.FileType('w'),
        help='Write final mate 1 reads to this file.')
    parser.add_argument('out2', type=argparse.FileType('w'),
        help='Write final mate 2 reads to this file.')
    parser.add_argument('-o', '--out-format', choices=('fastq', 'fasta'))
    parser.add_argument('--stdout', action='store_true',
        help='Print interleaved output reads to stdout.')
    parser.add_argument('-m', '--mutations', type=argparse.FileType('w'),
        help='Write a log of the PCR and sequencing errors introduced to this file. Will overwrite any '
             'existing file at this path.')
    parser.add_argument('-b', '--barcodes', type=argparse.FileType('w'),
        help='Write a log of which barcodes were ligated to which fragments. Will overwrite any '
             'existing file at this path.')
    parser.add_argument('--frag-file',
        help='The path of the FASTQ file of fragments. If --ref is given, these will be generated with '
             'wgsim and kept (normally a temporary file is used, then deleted). Note: the file will be '
             'overwritten! If --ref is not given, then this should be a file of already generated '
             'fragments, and they will be used instead of generating new ones.')
    parser.add_argument('-Q', '--fastq-qual',
        help='The quality score to assign to all bases in FASTQ output. Give a character or PHRED '
             'score (integer). A PHRED score will be converted using the Sanger offset (33). A single '
             'digit is taken as a literal quality character, so give PHRED scores below 10 with a '
             'leading zero (e.g. "05"). Default: "%(default)s"')
    parser.add_argument('-S', '--seed', type=int,
        help='Random number generator seed. By default, a random, 32-bit seed will be generated and '
             'logged to stderr.')
    params = parser.add_argument_group('simulation parameters')
    params.add_argument('-n', '--n-frags', type=int,
        help='The number of original fragment molecules to simulate. The final number of reads will be '
             'this multiplied by the average number of reads per family. If you provide fragments with '
             '--frag-file, the script will still only read in the number specified here. Default: '
             '%(default)s')
    params.add_argument('-r', '--read-len', type=int,
        help='Default: %(default)s')
    params.add_argument('-f', '--frag-len', type=int,
        help='Default: %(default)s')
    params.add_argument('-s', '--seq-error', type=float,
        help='Sequencing error rate per base (0-1 proportion, not percent). Default: %(default)s')
    params.add_argument('-p', '--pcr-error', type=float,
        help='PCR error rate per base (0-1 proportion, not percent). Default: %(default)s')
    params.add_argument('-c', '--cycles', type=int,
        help='Number of PCR cycles to simulate. Default: %(default)s')
    params.add_argument('-i', '--indel-rate', type=float,
        help='Fraction of errors which are indels. Default: %(default)s')
    params.add_argument('-E', '--extension-rate', dest='ext_rate', type=float,
        help='Probability an indel is extended. Default: %(default)s')
    params.add_argument('-B', '--bar-len', type=int,
        help='Length of the barcodes to generate. Default: %(default)s')
    params.add_argument('-I', '--invariant',
        help='The invariant linker sequence between the barcode and sample sequence in each read. '
             'Default: %(default)s')

    # Parse and interpret arguments.
    args = parser.parse_args(argv[1:])
    assert args.ref or args.frag_file, 'You must provide either a reference or fragments file.'
    if args.seed is None:
        seed = random.randint(0, 2**31-1)
        sys.stderr.write('seed: {}\n'.format(seed))
    else:
        seed = args.seed
    random.seed(seed)
    # build_pcr_tree() also draws from numpy's generator, so it must be seeded too or --seed will not
    # make runs reproducible. numpy only accepts seeds in [0, 2**32), hence the modulo.
    numpy.random.seed(seed % 2**32)
    if args.stdout:
        out1 = sys.stdout
        out2 = sys.stdout
    else:
        out1 = args.out1
        out2 = args.out2
    fastq_qual = _parse_fastq_qual(args.fastq_qual)
    qual_line = fastq_qual * args.read_len

    invariant_rc = get_revcomp(args.invariant)

    # Reserve a unique temporary path for the fragments file. NamedTemporaryFile deletes the file on
    # close, so only its name is kept. Work inside a try so the finally clause removes whatever ends
    # up at that path, no matter what exceptions are encountered.
    tmpfile = tempfile.NamedTemporaryFile(prefix='wgdsim.frags.')
    tmpfile.close()
    try:
        # Step 1: Use wgsim to create fragments from the reference.
        if args.frag_file:
            frag_file = args.frag_file
        else:
            frag_file = tmpfile.name
        if args.ref and os.path.isfile(args.ref) and os.path.getsize(args.ref):
            # Set error and mutation rates to 0 to just slice sequences out of the reference without
            # modification.
            exit_status = run_command('wgsim', '-e', '0', '-r', '0', '-d', '0', '-R', args.indel_rate,
                                      '-S', seed, '-N', args.n_frags, '-X', args.ext_rate,
                                      '-1', args.frag_len, args.ref, frag_file, os.devnull)
            # run_command() returns None when wgsim could not be executed at all (e.g. not on the PATH).
            if exit_status != 0:
                fail('Error: wgsim failed (exit status: {}). Is it installed and on your PATH?'
                     .format(exit_status))

        # NOTE: Coordinates here are 0-based (0 is the first base in the sequence).
        extended_dist = extend_dist(RAW_DISTRIBUTION)
        proportional_dist = compile_dist(extended_dist)
        n_frags = 0
        for raw_fragment in fastqreader.FastqReadGenerator(frag_file):
            n_frags += 1
            if n_frags > args.n_frags:
                break
            chrom, id_num, start, stop = parse_read_id(raw_fragment.id)
            barcode1 = get_rand_seq(args.bar_len)
            barcode2 = get_rand_seq(args.bar_len)
            barcode2_rc = get_revcomp(barcode2)
            raw_frag_full = barcode1 + args.invariant + raw_fragment.seq + invariant_rc + barcode2

            # Step 2: Determine how many reads to produce from each fragment.
            # - Use random.random() and divide the range 0-1 into segments of sizes proportional to
            #   the likelihood of each family size.
            # bisect.bisect() finds where an element belongs in a sorted list, returning the index.
            # proportional_dist is just such a sorted list, with values from 0 to 1.
            n_reads = bisect.bisect(proportional_dist, random.random())

            # Step 3: Introduce PCR errors.
            # - Determine the mutations and their frequencies.
            #   - Could get frequency from the cycle of PCR it occurs in.
            #     - Important to have PCR errors shared between reads.
            # - For each read, determine which mutations it contains.
            #   - Use random.random() < mut_freq.
            tree = get_good_pcr_tree(n_reads, args.cycles, 1000, max_diff=1)
            # Add errors to all children of original fragment.
            subtree1 = tree.get('child1')
            subtree2 = tree.get('child2')
            #TODO: Only simulate errors on portions of fragment that will become reads.
            add_pcr_errors(subtree1, '+', len(raw_frag_full), args.pcr_error, args.indel_rate, args.ext_rate)
            add_pcr_errors(subtree2, '-', len(raw_frag_full), args.pcr_error, args.indel_rate, args.ext_rate)
            apply_pcr_errors(tree, raw_frag_full)
            fragments = get_final_fragments(tree)
            add_mutation_lists(tree, fragments, [])

            # Step 4: Introduce sequencing errors.
            # NOTE(review): generate_mutations() is given args.read_len, so these errors only land in
            # the first read_len bases of each fragment (the read1_seq end below). The read taken from
            # the fragment's other end never gets sequencing errors; confirm this is intended.
            for fragment in fragments.values():
                for mutation in generate_mutations(args.read_len, args.seq_error, args.indel_rate,
                                                   args.ext_rate):
                    fragment['mutations'].append(mutation)
                    fragment['seq'] = apply_mutation(mutation, fragment['seq'])

            # Print barcodes to log file.
            if args.barcodes:
                args.barcodes.write('{}-{}\t{}\t{}\n'.format(chrom, id_num, barcode1, barcode2_rc))
            # Print family.
            for frag_id in sorted(fragments.keys()):
                fragment = fragments[frag_id]
                read_id = '{}-{}-{}'.format(chrom, id_num, frag_id)
                # Print mutations to log file.
                if args.mutations:
                    read1_muts = get_mutations_subset(fragment['mutations'], 0, args.read_len)
                    read2_muts = get_mutations_subset(fragment['mutations'], 0, args.read_len, revcomp=True,
                                                      seqlen=len(fragment['seq']))
                    if fragment['strand'] == '-':
                        read1_muts, read2_muts = read2_muts, read1_muts
                    log_mutations(args.mutations, read1_muts, read_id+'/1', chrom, start, stop)
                    log_mutations(args.mutations, read2_muts, read_id+'/2', chrom, start, stop)
                frag_seq = fragment['seq']
                read1_seq = frag_seq[:args.read_len]
                read2_seq = get_revcomp(frag_seq[len(frag_seq)-args.read_len:])
                if fragment['strand'] == '-':
                    read1_seq, read2_seq = read2_seq, read1_seq
                if args.out_format == 'fasta':
                    out1.write('>{}\n{}\n'.format(read_id, read1_seq))
                    out2.write('>{}\n{}\n'.format(read_id, read2_seq))
                elif args.out_format == 'fastq':
                    out1.write('@{}\n{}\n+\n{}\n'.format(read_id, read1_seq, qual_line))
                    out2.write('@{}\n{}\n+\n{}\n'.format(read_id, read2_seq, qual_line))

    finally:
        try:
            os.remove(tmpfile.name)
        except OSError:
            # Nothing was created at the temporary path (e.g. --frag-file was used).
            pass


def _parse_fastq_qual(raw_qual):
    """Convert a --fastq-qual value into the single FASTQ quality character to use.

    Accepts an integer PHRED score, a single quality character, or a string of two or more digits.
    argparse delivers command-line values as strings, so the digit string is how a PHRED score
    arrives. A single digit stays a literal character, as before, because digits are valid quality
    characters themselves. PHRED scores use the Sanger offset (33) and are capped at 93 ('~').
    Raises AssertionError on invalid input."""
    if isinstance(raw_qual, basestring) and len(raw_qual) > 1 and raw_qual.isdigit():
        raw_qual = int(raw_qual)
    if isinstance(raw_qual, numbers.Integral):
        assert raw_qual >= 0, '--fastq-qual cannot be negative.'
        assert raw_qual <= 93, '--fastq-qual PHRED score cannot exceed 93.'
        return chr(raw_qual + 33)
    elif isinstance(raw_qual, basestring):
        assert len(raw_qual) == 1, '--fastq-qual cannot be more than a single character.'
        return raw_qual
    else:
        raise AssertionError('--fastq-qual must be a positive integer or single character.')
224 | |
225 | |
def run_command(*command, **kwargs):
    """Run a command and return its exit status.

    Usage: run_command('echo', 'hello')
    Each argument is converted with str(). The command line is echoed to stderr before running
    unless the keyword argument "silent" is true. The command's own stderr is discarded.
    Returns None if the command could not be executed at all (e.g. it is not on the PATH)."""
    # Build the argument list once; it is used both for the echo and for the actual call.
    command_strs = [str(arg) for arg in command]
    if not kwargs.get('silent'):
        sys.stderr.write('$ '+' '.join(command_strs)+'\n')
    with open(os.devnull, 'w') as devnull:
        try:
            return subprocess.call(command_strs, stderr=devnull)
        except OSError:
            return None
241 | |
242 | |
def extend_dist(raw_dist, exponent=1.25, min_prob=0.00001, max_len_mult=2):
    """Return a copy of "raw_dist" with an exponentially shrinking tail appended.

    Each tail value is the previous value (starting from the last entry of "raw_dist") divided by
    "exponent". Extension stops when the next value would be smaller than "min_prob" times the
    running total (the original values plus the tail so far), or once the list has reached
    "max_len_mult" times the original length."""
    result = list(raw_dist)
    running_total = sum(raw_dist)
    next_value = raw_dist[-1] / exponent
    length_cap = len(raw_dist) * max_len_mult
    while next_value / running_total >= min_prob:
        if len(result) >= length_cap:
            break
        result.append(next_value)
        running_total += next_value
        next_value /= exponent
    return result
258 | |
259 | |
def compile_dist(raw_dist):
    """Convert the human-readable weights defined at the top into a cumulative distribution.

    Each output entry is the running total of the weights up to that point, divided by the sum of
    all weights, so the list rises to 1.0.
    E.g. [10, 5, 5] -> [0.5, 0.75, 1.0]"""
    total = sum(raw_dist)

    def _running_totals():
        accumulated = 0
        for weight in raw_dist:
            accumulated += weight
            yield accumulated

    return [partial / total for partial in _running_totals()]
271 | |
272 | |
def parse_read_id(read_id):
    """Split a wgsim read name into (chrom, id_num, start, stop).

    The fields are the WGSIM_ID_REGEX capture groups, returned as strings. If "read_id" does not
    match, returns (read_id, None, None, None)."""
    match = re.search(WGSIM_ID_REGEX, read_id)
    if not match:
        return read_id, None, None, None
    chrom, start, stop, id_num = match.groups()
    return chrom, id_num, start, stop
283 | |
284 | |
#TODO: Clean up "mutation" vs "error" terminology.
def generate_mutations(seq_len, error_rate, indel_rate, extension_rate):
    """Generate all the mutations that occur over the length of a sequence.

    Yields mutation dicts {'coord', 'type', 'alt'} in the order they are meant to be applied with
    apply_mutation(). Each 'coord' is relative to the sequence as modified by the mutations yielded
    before it, so coordinates never decrease and never go below 0.
    seq_len: number of bases to cover (positions 0 to seq_len-1, plus an insertion point after the
      last base at seq_len).
    error_rate: per-position probability of a mutation.
    indel_rate, extension_rate: passed through to make_mutation()."""
    i = 0
    while i <= seq_len:
        if random.random() < error_rate:
            mtype, alt = make_mutation(indel_rate, extension_rate)
            # Allow mutation after the last base only if it's an insertion.
            if i < seq_len or mtype == 'ins':
                yield {'coord':i, 'type':mtype, 'alt':alt}
                # Compensate for length changes so i keeps pointing into the mutated sequence. Only
                # emitted mutations change the sequence, so only they shift i.
                if mtype == 'ins':
                    # Step over the inserted bases (the increment below then passes the next base).
                    i += len(alt)
                elif mtype == 'del':
                    # The first base after the deleted span has slid into position i, so don't advance
                    # past it. Subtracting the deletion length instead would move i backward, even
                    # below 0, for deletions longer than one base.
                    i -= 1
        i += 1
301 | |
302 | |
def make_mutation(indel_rate, extension_rate):
    """Simulate one random mutation and return it as a (type, alt) pair.

    With probability "indel_rate" it is an indel: a 'del' whose alt is the number of deleted bases,
    or an 'ins' whose alt is the inserted bases. The indel is extended by one base at a time while
    random draws stay below "extension_rate". Otherwise it is an 'snv' whose alt is the new base."""
    roll = random.random()
    if roll < indel_rate:
        # The same roll picks the indel type: the lower half of [0, indel_rate) means deletion.
        is_deletion = roll < indel_rate/2
        alt = 1 if is_deletion else get_rand_base()
        while random.random() < extension_rate:
            alt = alt + 1 if is_deletion else alt + get_rand_base()
        mtype = 'del' if is_deletion else 'ins'
        return mtype, alt
    # Not an indel: substitute a random base.
    return 'snv', get_rand_base()
328 | |
329 | |
def get_rand_base(bases='ACGT'):
    """Return one character of "bases", chosen uniformly at random with random.choice()."""
    picked = random.choice(bases)
    return picked
332 | |
333 | |
def get_rand_seq(seq_len):
    """Return a string of "seq_len" random bases, each drawn with get_rand_base()."""
    return ''.join(get_rand_base() for _ in range(seq_len))
336 | |
337 | |
def get_revcomp(seq):
    """Return the reverse complement of "seq", complementing each character via REVCOMP_TABLE."""
    complement = seq.translate(REVCOMP_TABLE)
    return complement[::-1]
340 | |
341 | |
def apply_mutation(mut, seq):
    """Return "seq" with the mutation dict "mut" applied.

    An 'snv' replaces the base at 'coord' with 'alt'. Indels act *before* the base at 'coord': 'ins'
    inserts the bases in 'alt' there, and any other type deletes 'alt' bases starting there. This
    goes against the VCF convention, but it allows deleting the first and last base, as well as
    inserting before and after the sequence, without as much special-casing.
    Examples: 'ACGTACGT' + ins 'GC' at 4 = 'ACGTGCACGT'; 'ACGTACGT' + del 2 at 4 = 'ACGTGT'."""
    coord = mut['coord']
    kind = mut['type']
    head = seq[:coord]
    if kind == 'snv':
        return head + mut['alt'] + seq[coord+1:]
    if kind == 'ins':
        return head + mut['alt'] + seq[coord:]
    return head + seq[coord+mut['alt']:]
358 | |
359 | |
def get_mutations_subset(mutations_old, start, length, revcomp=False, seqlen=None):
    """Return the mutations from "mutations_old" that fall inside a region, keeping their order.

    "start" is the 0-based start of the region and "length" its length, so the region spans
    coordinates start through start+length-1. An insertion exactly at start+length is also kept.
    If "revcomp" is true, every mutation is first converted with get_mutation_revcomp(), and
    "start" refers to the reverse-complemented sequence's coordinates. "seqlen", the length of the
    sequence the mutations occurred in, is only needed for that conversion. The order is unchanged
    either way. Without "revcomp", the returned dicts are the same objects that were passed in."""
    stop = start + length

    def _in_region(mut):
        coord = mut['coord']
        if start <= coord < stop:
            return True
        # Allow insertions at the last coordinate.
        return coord == stop and mut['type'] == 'ins'

    if revcomp:
        candidates = (get_mutation_revcomp(mut, seqlen) for mut in mutations_old)
    else:
        candidates = iter(mutations_old)
    return [mut for mut in candidates if _in_region(mut)]
384 | |
385 | |
def get_mutation_revcomp(mut, seqlen):
    """Return a new mutation dict describing "mut" on the reverse complement strand.

    "seqlen" is the length of the sequence the mutation is applied to. It is needed to re-express
    the coordinate from the other end of the sequence. SNV and insertion alleles are
    reverse-complemented, and a deletion keeps its length. A mutation of any other type comes back
    with only its 'type' key."""
    mtype = mut['type']
    flipped = {'type': mtype}
    if mtype in ('snv', 'ins'):
        # An SNV occupies a base, so its mirrored coordinate is one less than an insertion point's.
        offset = 1 if mtype == 'snv' else 0
        flipped['coord'] = seqlen - mut['coord'] - offset
        flipped['alt'] = get_revcomp(mut['alt'])
    elif mtype == 'del':
        # Deleting [coord, coord+alt) on one strand is deleting [seqlen-coord-alt, seqlen-coord) on the other.
        flipped['coord'] = seqlen - mut['coord'] - mut['alt']
        flipped['alt'] = mut['alt']
    return flipped
402 | |
403 | |
def log_mutations(mutfile, mutations, read_id, chrom, start, stop):
    """Write one tab-separated line per mutation to the open file "mutfile".

    Columns: read_id, chrom, start, stop, then the mutation's coord, type, and alt."""
    row_template = '{read_id}\t{chrom}\t{start}\t{stop}\t{coord}\t{type}\t{alt}\n'
    for mut in mutations:
        row = row_template.format(read_id=read_id, chrom=chrom, start=start, stop=stop, **mut)
        mutfile.write(row)
408 | |
409 | |
def add_pcr_errors(subtree, strand, read_len, error_rate, indel_rate, extension_rate):
    """Add simulated PCR errors to a node in a tree and all its descendants.

    Every node reached gets 'strand' set to "strand" and 'errors' set to a fresh list from
    generate_mutations(read_len, error_rate, indel_rate, extension_rate).
    Note: The errors are intended as "errors made in creating this molecule", so don't apply this
    to the root node, since that is supposed to be the original, unaltered molecule."""
    pending = [subtree]
    while pending:
        node = pending.pop()
        if not node:
            continue
        node['strand'] = strand
        node['errors'] = list(generate_mutations(read_len, error_rate, indel_rate, extension_rate))
        # Push child1 first so the whole child2 subtree is visited before it. That keeps the order
        # of random draws fixed: node, then its child2 lineage, then its child1 lineage.
        pending.append(node.get('child1'))
        pending.append(node.get('child2'))
423 | |
424 | |
def apply_pcr_errors(subtree, seq):
    """Apply each node's 'errors' down the tree, storing the resulting sequence on the leaves.

    "seq" is the sequence entering "subtree". Every node applies its own errors (if any) on top of
    what it inherited, in order. Nodes without a 'child1' key get the result stored under 'seq'.
    The tree is modified in place; returns None."""
    pending = [(subtree, seq)]
    while pending:
        node, current = pending.pop()
        if not node:
            continue
        for error in node.get('errors', ()):
            current = apply_mutation(error, current)
        if 'child1' not in node:
            node['seq'] = current
        # Both children inherit this node's sequence; child2 is popped (processed) first.
        pending.append((node.get('child1'), current))
        pending.append((node.get('child2'), current))
434 | |
435 | |
def get_final_fragments(tree):
    """Walk to the leaf nodes of the tree and collect the post-PCR sequences of all the fragments.

    A leaf is any node without a truthy 'child1'. Returns a dict mapping each leaf's 'id' to a new
    dict with just two keys: 'seq' (the final sequence) and 'strand' ('+' or '-')."""
    leaves = {}
    stack = [tree]
    while stack:
        current = stack.pop()
        first_child = current.get('child1')
        if not first_child:
            leaves[current['id']] = {'seq':current['seq'], 'strand':current['strand']}
        else:
            stack.append(first_child)
        second_child = current.get('child2')
        if second_child:
            stack.append(second_child)
    return leaves
453 | |
454 | |
def add_mutation_lists(subtree, fragments, mut_list1):
    """Record, for each final fragment, the PCR mutations in its lineage in chronological order.

    To call from the root, pass [] as "mut_list1" and, as "fragments", a dict mapping every leaf's
    id to a dict. Nothing is returned. Instead, each leaf's entry in "fragments" gets a 'mutations'
    key holding the list of errors accumulated along its lineage. "mut_list1" itself is extended in
    place and becomes the list of the leaf at the end of the 'child1' chain. Each 'child2' branch
    works on its own deep copy of the history up to its branch point."""
    lineage = mut_list1
    node = subtree
    while node:
        lineage.extend(node.get('errors', ()))
        is_leaf = 'child1' not in node
        if is_leaf:
            fragments[node['id']]['mutations'] = lineage
        if 'child2' in node:
            snapshot = copy.deepcopy(lineage)
            add_mutation_lists(node.get('child2'), fragments, snapshot)
        node = node.get('child1')
470 | |
471 | |
def check_tree_balance(subtree):
    """Return the largest difference in 'cycle' between sibling nodes anywhere in the tree.

    Siblings are a node's 'child1' and 'child2' when both are present. A falsy "subtree", or a tree
    with no such pairs, gives 0."""
    if not subtree:
        return 0
    first = subtree.get('child1')
    second = subtree.get('child2')
    here = abs(first['cycle'] - second['cycle']) if first and second else 0
    return max(here, check_tree_balance(first), check_tree_balance(second))
488 | |
489 | |
def get_good_pcr_tree(n_reads, n_cycles, max_tries, max_diff=1):
    """Return a single, balanced PCR tree from build_pcr_tree(), or fail if one cannot
    be found in max_tries attempts.

    Compensates for bugs in build_pcr_tree() that sometimes result in multiple trees, or trees
    with siblings from different cycles. A tree is accepted when build_pcr_tree() returns exactly
    one root and check_tree_balance() on it is at most "max_diff". At least one attempt is always
    made, even if "max_tries" is less than 1.
    Raises AssertionError if no acceptable tree was produced."""
    # The old `while tries <= max_tries` loop made one attempt more than it reported.
    attempts = max(max_tries, 1)
    for _ in range(attempts):
        trees = build_pcr_tree(n_reads, n_cycles)
        if len(trees) == 1 and check_tree_balance(trees[0]) <= max_diff:
            return trees[0]
    raise AssertionError('Could not generate a single, balanced tree! (tried {} times)'
                         .format(attempts))
503 | |
504 | |
def build_pcr_tree(n_reads, n_cycles):
    """Create a simulated descent lineage of how all the final PCR fragments are related.
    Each node represents a fragment molecule at one stage of PCR. Each node is a dict containing the
    fragment's children (other nodes) ('child1' and 'child2'), the PCR cycle number ('cycle'), and,
    at the leaves, a unique id number for each final fragment.
    Leaves are created as {'cycle': n_cycles-1, 'id': frag_id} for frag_id in range(n_reads). Each
    earlier cycle wraps every lineage in a new parent, {'cycle': c, 'child1': ...}, optionally
    with a 'child2' when two lineages coalesce.
    Returns a list of root nodes. Usually there will only be one, but about 1-3% of the time it fails
    to unify the subtrees and results in a broken tree.
    NOTE: this draws from both numpy.random (binomial) and random (choice), so both generators must be
    seeded for reproducible output.
    """
    #TODO: Make it always return a single tree.
    # Begin a branch for each of the fragments. These are the leaf nodes. We'll work backward from
    # these, simulating the points at which they share ancestors, eventually coalescing into the
    # single (root) ancestor.
    branches = []
    for frag_id in range(n_reads):
        branches.append({'cycle':n_cycles-1, 'id':frag_id})
    # Build up all the branches in parallel. Start from the second-to-last PCR cycle.
    for cycle in reversed(range(n_cycles-1)):
        # Probability of 2 fragments sharing an ancestor at cycle c is 1/2^c.
        prob = 1/2**cycle
        frag_i = 0
        while frag_i < len(branches):
            current_root = branches[frag_i]
            # Does the current fragment share this ancestor with any of the other fragments?
            # numpy.random.binomial() is a fast way to simulate going through every other fragment and
            # asking if random.random() < prob.
            shared = numpy.random.binomial(len(branches)-1, prob)
            if shared == 0:
                # No branch point here. Just add another level to the lineage.
                branches[frag_i] = {'cycle':cycle, 'child1':current_root}
            else:
                # Pick a random other fragment to share this ancestor with.
                # Make a list of candidates to pick from.
                candidates = []
                for candidate_i, candidate in enumerate(branches):
                    # Don't include ourselves.
                    if candidate is current_root:
                        continue
                    # Skip lineages that were already joined into a branch point on this cycle (their
                    # root is at this cycle and already has a second child).
                    if candidate['cycle'] == cycle and candidate.get('child2'):
                        continue
                    candidates.append(candidate_i)
                # NOTE(review): if shared > 0 but there are no candidates, this lineage is not wrapped
                # in a parent for this cycle, so it ends up one level shallower than the others. This
                # may be the source of the depth mismatch described in the TODO below.
                if candidates:
                    relative_i = random.choice(candidates)
                    relative = branches[relative_i]
                    # Have we already passed this fragment on this cycle?
                    if relative['cycle'] == cycle:
                        # If we've already passed it, we're looking at the fragment's parent. We want the child.
                        relative = relative['child1']
                    # Join the lineages of our current fragment and the relative to a new parent.
                    #TODO: Sometimes, we end up matching up subtrees of different depths. But the discrepancy
                    #      is rarely greater than 1. Figure out why.
                    # assert abs(current_root['cycle'] - relative['cycle']) < 3, ('cycle: {}, current_root: {},'
                    #   ' relative: {}, frag_i: {}, relative_i: {}, branches: {}, candidates: {}, shared: {}'
                    #   .format(cycle, current_root['cycle'], relative['cycle'], frag_i, relative_i,
                    #           len(branches), len(candidates), shared))
                    branches[frag_i] = {'cycle':cycle, 'child1':current_root, 'child2':relative}
                    # Remove the relative from the list of lineages.
                    del(branches[relative_i])
                    # Deleting an earlier entry shifts the current one left by one; keep frag_i on it.
                    if relative_i < frag_i:
                        frag_i -= 1
            frag_i += 1
    return branches
567 | |
568 | |
def get_depth(tree):
    """Return how many nodes lie on the 'child1' chain starting at "tree" (0 if "tree" is falsy)."""
    length = 0
    current = tree
    while current:
        length += 1
        current = current.get('child1')
    return length
576 | |
577 | |
def convert_tree(tree_orig):
    """Flatten a PCR tree into a list of levels (lists of nodes at the same depth), for print_levels().

    Works on a deep copy of "tree_orig", so the caller's tree is left untouched. Every copied node
    gets a 'parent' link and a 'branch' number. The root is branch 1, a 'child1' inherits its
    parent's number, and a 'child2' gets its parent's number plus one. Each level is sorted by
    'branch'. The last level in the returned list is always empty."""
    # Let's operate on a copy only.
    tree = copy.deepcopy(tree_orig)
    # Turn the tree vertical.
    tree['line'] = 1
    tree['children'] = 0
    # Every child's branch number is derived from its parent's, so the root needs one too.
    # (Without it, the first child raised KeyError.)
    tree['branch'] = 1
    levels = [[tree]]
    done = False
    while not done:
        last_level = levels[-1]
        this_level = []
        done = True
        for node in last_level:
            for child_name in ('child1', 'child2'):
                child = node.get(child_name)
                if child:
                    done = False
                    child['parent'] = node
                    child['branch'] = child['parent']['branch']
                    if child_name == 'child2':
                        child['branch'] += 1
                    this_level.append(child)
        this_level.sort(key=lambda level_node: level_node['branch'])
        levels.append(this_level)
    return levels
603 | |
604 | |
def print_levels(levels):
    """Print one row per level (as produced by convert_tree()) to stdout.

    A node that is the 'child2' of some node in the previous level is drawn as '\\ '; every other
    node is drawn as '| '."""
    previous_level = []
    for level in levels:
        # Identity set of the previous level's second children: one lookup per node instead of a
        # scan over every parent (same `is` semantics as comparing objects directly).
        second_children = set(id(parent.get('child2')) for parent in previous_level)
        for node in level:
            if id(node) in second_children:
                # Escaped backslash: '\ ' is an invalid escape sequence (warned about in Python 3).
                sys.stdout.write('\\ ')
            else:
                sys.stdout.write('| ')
        previous_level = level
        print()
619 | |
620 | |
def label_branches(tree):
    """Label each vertical branch (line of 'child1's) with an id number.

    Sets node['branch'] in place while visiting nodes breadth-first. The root is branch 1, a
    'child1' inherits its parent's label, and each 'child2' starts a new branch numbered 2, 3, ...
    in visiting order. Returns None."""
    next_label = 1
    tree['branch'] = next_label
    queue = [tree]
    position = 0
    # Walk the queue by index instead of pop(0), which shifts the whole list each time.
    while position < len(queue):
        node = queue[position]
        position += 1
        first = node.get('child1')
        if first:
            first['branch'] = node['branch']
            queue.append(first)
        second = node.get('child2')
        if second:
            next_label += 1
            second['branch'] = next_label
            queue.append(second)
637 | |
638 | |
def print_tree(tree_orig):
    """Print an ASCII drawing of a PCR tree to stdout.

    Works on a deep copy of "tree_orig", so the caller's tree is not modified. Each branch (a chain
    of 'child1' links, numbered by label_branches()) is drawn on its own line. The line is indented
    by the depth where the branch starts and ends with the branch number. Cells: '*-' is the root,
    '=-' continues a branch, '\\-' starts a branch that hangs off a 'child2' link, and '| ' cells
    connect such a start up to its parent's line."""
    # We "write" strings to an output buffer instead of directly printing, so we can post-process the
    # output. The buffer is a matrix of cells, each holding a string representing one element.
    lines = [[]]
    # Let's operate on a copy only.
    tree = copy.deepcopy(tree_orig)
    # Add some bookkeeping data.
    label_branches(tree)
    tree['level'] = 0
    branches = [tree]
    while branches:
        line = lines[-1]
        branch = branches.pop()
        level = branch['level']
        # Indent the branch to the depth of its first node.
        # NOTE(review): the other cells are two characters wide; confirm this padding cell isn't meant
        # to be two spaces (it must match the ' ' comparisons below either way).
        while level > 0:
            line.append(' ')
            level -= 1
        node = branch
        while node:
            # Is it the root node? (Have we written anything yet?)
            if lines[0]:
                # Are we at the start of the line? (Is it only spaces so far?)
                if line[-1] == ' ':
                    line.append('\\-')
                elif line[-1].endswith('-'):
                    line.append('=-')
            else:
                line.append('*-')
            child2 = node.get('child2')
            if child2:
                child2['level'] = node['level'] + 1
                branches.append(child2)
            parent = node
            node = node.get('child1')
            if node:
                node['level'] = parent['level'] + 1
            else:
                # End of the branch: label it and start a new output line.
                line.append(' {}'.format(parent['branch']))
                lines.append([])
    # Post-process output: Add lines connecting branches to parents.
    x = 0
    done = False
    while not done:
        # Draw vertical lines upward from branch points.
        drawing = False
        # Keep scanning columns until no line reaches column x. This must be reset once per column:
        # resetting it per line made the stop condition depend only on the first line's length.
        done = True
        for line in reversed(lines):
            if x < len(line):
                done = False
                cell = line[x]
                if cell == '\\-':
                    drawing = True
                elif cell == ' ' and drawing:
                    line[x] = '| '
                elif cell == '=-' and drawing:
                    drawing = False
        x += 1
    # Print the final output.
    for line in lines:
        print(''.join(line))
699 | |
700 | |
def fail(message):
    """Write "message" plus a newline to stderr, then exit the process with status 1."""
    line = message + "\n"
    sys.stderr.write(line)
    sys.exit(1)
704 | |
if __name__ == '__main__':
    # main() returns None on success, which sys.exit() turns into exit status 0.
    exit_code = main(sys.argv)
    sys.exit(exit_code)