Mercurial > repos > nick > duplex
comparison utils/sim.py @ 18:e4d75f9efb90 draft
planemo upload commit b'4303231da9e48b2719b4429a29b72421d24310f4\n'-dirty
| author | nick |
|---|---|
| date | Thu, 02 Feb 2017 18:44:31 -0500 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| 17:836fa4fe9494 | 18:e4d75f9efb90 |
|---|---|
| 1 #!/usr/bin/env python | |
| 2 from __future__ import division | |
| 3 from __future__ import print_function | |
| 4 import re | |
| 5 import os | |
| 6 import sys | |
| 7 import copy | |
| 8 import numpy | |
| 9 import bisect | |
| 10 import random | |
| 11 import string | |
| 12 import numbers | |
| 13 import tempfile | |
| 14 import argparse | |
| 15 import subprocess | |
| 16 import fastqreader | |
| 17 | |
# Translation table for complementing DNA, including IUPAC ambiguity codes (R<->Y, M<->K, B<->V,
# D<->H). Characters not listed (e.g. N) are left unchanged. str.maketrans exists on Python 3;
# Python 2 only has string.maketrans.
try:
    _MAKETRANS = str.maketrans
except AttributeError:
    _MAKETRANS = string.maketrans
REVCOMP_TABLE = _MAKETRANS('acgtrymkbdhvACGTRYMKBDHV', 'tgcayrkmvhdbTGCAYRKMVHDB')
# Read names produced by wgsim: captures (chrom, start, stop, id_num) and requires a /1 or /2 suffix.
WGSIM_ID_REGEX = r'^(.+)_(\d+)_(\d+)_\d+:\d+:\d+_\d+:\d+:\d+_([0-9a-f]+)/[12]$'
ARG_DEFAULTS = {'read_len':100, 'frag_len':400, 'n_frags':1000, 'out_format':'fasta',
                'seq_error':0.001, 'pcr_error':0.001, 'cycles':25, 'indel_rate':0.15,
                'ext_rate':0.3, 'seed':None, 'invariant':'TGACT', 'bar_len':12, 'fastq_qual':'I'}
USAGE = "%(prog)s [options]"
DESCRIPTION = """Simulate a duplex sequencing experiment."""

# Relative weights of each family size (number of reads produced per original fragment). The list
# index is the family size; values need not sum to anything in particular. extend_dist() adds a
# decaying tail and compile_dist() turns the weights into cumulative probabilities.
RAW_DISTRIBUTION = (
    #  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
    # Low singletons, but then a constant drop-off. From pML113 (see 2015-09-28 report).
    #  0, 100,  36,  31,  27,  22,  17,  12,   7, 4.3,
    #2.4, 1.2, 0.6, 0.3, 0.2, 0.15, 0.1, 0.07, 0.05, 0.03,
    # High singletons, but then a second peak around 10. From the Christine plasmid (2015-10-06 report).
    #  0, 100, 5.24, 3.67, 3.50, 3.67, 3.85, 4.02, 4.11, 4.20,
    # 4.17, 4.10, 4.00, 3.85, 3.69, 3.55, 3.38, 3.15, 2.92, 2.62,
    # 2.27, 2.01, 1.74, 1.56, 1.38, 1.20, 1.02, 0.85,
    # In use: same as above, but with few families of size 1, 2, and 3 (rely on errors to fill
    # those out).
    0, 1, 2, 3, 3.50, 3.67, 3.85, 4.02, 4.11, 4.20,
    4.17, 4.10, 4.00, 3.85, 3.69, 3.55, 3.38, 3.15, 2.92, 2.62,
    2.27, 2.01, 1.74, 1.56, 1.38, 1.20, 1.02, 0.85,
)
| 40 | |
| 41 | |
def _parse_fastq_qual(value):
    """Convert a --fastq-qual value into the single quality character used for every base.

    value may be:
      - an integer: treated as a PHRED score;
      - a one-character string: used as-is (so "5" means the character '5', not PHRED 5);
      - a string of two or more digits (e.g. "40" or "05"): treated as a PHRED score.
    PHRED scores must be in 0-93. They are converted with the Sanger offset (33), giving a
    printable character from '!' to '~'.
    Raises ValueError for anything else."""
    message = '--fastq-qual must be a PHRED score (integer) or a single character.'
    if isinstance(value, numbers.Integral):
        phred = value
    elif not hasattr(value, 'isdigit'):
        raise ValueError(message)
    elif len(value) == 1:
        return value
    elif value.isdigit():
        # argparse always hands us a string, so integer scores arrive as digit strings.
        phred = int(value)
    else:
        raise ValueError(message)
    if not 0 <= phred <= 93:
        raise ValueError('--fastq-qual PHRED score must be between 0 and 93.')
    return chr(phred + 33)


def main(argv):
    """Run the duplex sequencing simulation from the command line.

    argv is the full argument vector (argv[0] is the program name).

    Fragments come from wgsim (run on the reference) or from --frag-file. For each fragment:
      - random barcodes plus the invariant linker are added to both ends;
      - a family size is drawn from RAW_DISTRIBUTION;
      - a PCR lineage tree is simulated, with PCR errors;
      - sequencing errors are added;
      - each final molecule is written as a read pair to out1/out2 (or interleaved to stdout).

    Argument errors exit through argparse. If wgsim cannot be run or fails, the program exits
    with status 1 via fail()."""

    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.set_defaults(**ARG_DEFAULTS)

    parser.add_argument('ref', metavar='ref.fa', nargs='?',
        help='Reference sequence. Omit if giving --frag-file.')
    parser.add_argument('out1', type=argparse.FileType('w'),
        help='Write final mate 1 reads to this file.')
    parser.add_argument('out2', type=argparse.FileType('w'),
        help='Write final mate 2 reads to this file.')
    parser.add_argument('-o', '--out-format', choices=('fastq', 'fasta'))
    parser.add_argument('--stdout', action='store_true',
        help='Print interleaved output reads to stdout.')
    parser.add_argument('-m', '--mutations', type=argparse.FileType('w'),
        help='Write a log of the PCR and sequencing errors introduced to this file. Will overwrite any '
             'existing file at this path.')
    parser.add_argument('-b', '--barcodes', type=argparse.FileType('w'),
        help='Write a log of which barcodes were ligated to which fragments. Will overwrite any '
             'existing file at this path.')
    parser.add_argument('--frag-file',
        help='The path of the FASTQ file of fragments. If a reference is given, these will be '
             'generated with wgsim and kept (normally a temporary file is used, then deleted). Note: '
             'the file will be overwritten! If no reference is given, then this should be a file of '
             'already generated fragments, and they will be used instead of generating new ones.')
    parser.add_argument('-Q', '--fastq-qual',
        help='The quality score to assign to all bases in FASTQ output. Give a single character, or '
             'a PHRED score of two or more digits (0-93; zero-pad single-digit scores, e.g. "05"). '
             'A PHRED score will be converted using the Sanger offset (33). Default: "%(default)s"')
    parser.add_argument('-S', '--seed', type=int,
        help='Random number generator seed. By default, a random, 32-bit seed will be generated and '
             'logged to stdout.')
    params = parser.add_argument_group('simulation parameters')
    params.add_argument('-n', '--n-frags', type=int,
        help='The number of original fragment molecules to simulate. The final number of reads will be '
             'this multiplied by the average number of reads per family. If you provide fragments with '
             '--frag-file, the script will still only read in the number specified here. Default: '
             '%(default)s')
    params.add_argument('-r', '--read-len', type=int,
        help='Default: %(default)s')
    params.add_argument('-f', '--frag-len', type=int,
        help='Default: %(default)s')
    params.add_argument('-s', '--seq-error', type=float,
        help='Sequencing error rate per base (0-1 proportion, not percent). Default: %(default)s')
    params.add_argument('-p', '--pcr-error', type=float,
        help='PCR error rate per base (0-1 proportion, not percent). Default: %(default)s')
    params.add_argument('-c', '--cycles', type=int,
        help='Number of PCR cycles to simulate. Default: %(default)s')
    params.add_argument('-i', '--indel-rate', type=float,
        help='Fraction of errors which are indels. Default: %(default)s')
    params.add_argument('-E', '--extension-rate', dest='ext_rate', type=float,
        help='Probability an indel is extended. Default: %(default)s')
    params.add_argument('-B', '--bar-len', type=int,
        help='Length of the barcodes to generate. Default: %(default)s')
    params.add_argument('-I', '--invariant',
        help='The invariant linker sequence between the barcode and sample sequence in each read. '
             'Default: %(default)s')

    # Parse and interpret arguments.
    args = parser.parse_args(argv[1:])
    if not (args.ref or args.frag_file):
        parser.error('You must provide either a reference or fragments file.')
    if args.cycles < 2:
        # With fewer than 2 cycles the simulated PCR tree has no amplified children, so no
        # fragment would get a strand assigned and get_final_fragments() would fail.
        parser.error('--cycles must be at least 2.')
    if args.seed is None:
        seed = random.randint(0, 2**31-1)
        sys.stderr.write('seed: {}\n'.format(seed))
    else:
        seed = args.seed
    random.seed(seed)
    # build_pcr_tree() also draws from numpy's generator, so seed it too, or --seed would not make
    # runs reproducible. numpy only accepts seeds in [0, 2**32).
    numpy.random.seed(seed % 2**32)
    if args.stdout:
        out1 = sys.stdout
        out2 = sys.stdout
    else:
        out1 = args.out1
        out2 = args.out2
    try:
        fastq_qual = _parse_fastq_qual(args.fastq_qual)
    except ValueError as error:
        parser.error(str(error))

    invariant_rc = get_revcomp(args.invariant)

    # Reserve a temporary file for wgsim's output. Work inside a try so the file is removed no
    # matter what exceptions are encountered.
    tmp_fd, tmp_path = tempfile.mkstemp(prefix='wgdsim.frags.')
    os.close(tmp_fd)
    try:
        # Step 1: Use wgsim to create fragments from the reference.
        if args.frag_file:
            frag_file = args.frag_file
        else:
            frag_file = tmp_path
        if args.ref and os.path.isfile(args.ref) and os.path.getsize(args.ref):
            # Set error and mutation rates to 0 to just slice sequences out of the reference without
            # modification.
            exit_status = run_command('wgsim', '-e', '0', '-r', '0', '-d', '0', '-R', args.indel_rate,
                                      '-S', seed, '-N', args.n_frags, '-X', args.ext_rate,
                                      '-1', args.frag_len, args.ref, frag_file, os.devnull)
            # run_command() returns None when the executable could not be started at all.
            if exit_status is None:
                fail('Error: Could not run wgsim. Make sure it is installed and on your PATH.')
            elif exit_status != 0:
                fail('Error: wgsim exited with status {}.'.format(exit_status))
        elif not args.frag_file:
            # No usable reference and no fragments file: there is nothing to read.
            fail('Error: Reference file "{}" is missing or empty.'.format(args.ref))

        # NOTE: Coordinates here are 0-based (0 is the first base in the sequence).
        extended_dist = extend_dist(RAW_DISTRIBUTION)
        proportional_dist = compile_dist(extended_dist)
        n_frags = 0
        for raw_fragment in fastqreader.FastqReadGenerator(frag_file):
            n_frags += 1
            if n_frags > args.n_frags:
                break
            chrom, id_num, start, stop = parse_read_id(raw_fragment.id)
            barcode1 = get_rand_seq(args.bar_len)
            barcode2 = get_rand_seq(args.bar_len)
            # barcode2 appears reverse-complemented at the start of read 2.
            barcode2_rc = get_revcomp(barcode2)
            raw_frag_full = barcode1 + args.invariant + raw_fragment.seq + invariant_rc + barcode2

            # Step 2: Determine how many reads to produce from each fragment.
            # - Use random.random() and divide the range 0-1 into segments of sizes proportional to
            #   the likelihood of each family size.
            # bisect.bisect() finds where an element belongs in a sorted list, returning the index.
            # proportional_dist is just such a sorted list, with values from 0 to 1.
            n_reads = bisect.bisect(proportional_dist, random.random())

            # Step 3: Introduce PCR errors.
            # - Determine the mutations and their frequencies.
            #   - Could get frequency from the cycle of PCR it occurs in.
            #     - Important to have PCR errors shared between reads.
            # - For each read, determine which mutations it contains.
            #   - Use random.random() < mut_freq.
            tree = get_good_pcr_tree(n_reads, args.cycles, 1000, max_diff=1)
            # Add errors to all children of original fragment.
            subtree1 = tree.get('child1')
            subtree2 = tree.get('child2')
            #TODO: Only simulate errors on portions of fragment that will become reads.
            add_pcr_errors(subtree1, '+', len(raw_frag_full), args.pcr_error, args.indel_rate, args.ext_rate)
            add_pcr_errors(subtree2, '-', len(raw_frag_full), args.pcr_error, args.indel_rate, args.ext_rate)
            apply_pcr_errors(tree, raw_frag_full)
            fragments = get_final_fragments(tree)
            add_mutation_lists(tree, fragments, [])

            # Step 4: Introduce sequencing errors.
            # Errors are simulated along the whole final molecule, so that both read 1 (taken from
            # its start) and read 2 (taken from its end) can receive them. Fragments are visited in
            # id order so the sequence of random draws does not depend on dict ordering.
            for frag_id in sorted(fragments):
                fragment = fragments[frag_id]
                for mutation in generate_mutations(len(fragment['seq']), args.seq_error,
                                                   args.indel_rate, args.ext_rate):
                    fragment['mutations'].append(mutation)
                    fragment['seq'] = apply_mutation(mutation, fragment['seq'])

            # Print barcodes to log file.
            if args.barcodes:
                args.barcodes.write('{}-{}\t{}\t{}\n'.format(chrom, id_num, barcode1, barcode2_rc))
            # Print family.
            for frag_id in sorted(fragments.keys()):
                fragment = fragments[frag_id]
                read_id = '{}-{}-{}'.format(chrom, id_num, frag_id)
                # Print mutations to log file.
                if args.mutations:
                    read1_muts = get_mutations_subset(fragment['mutations'], 0, args.read_len)
                    read2_muts = get_mutations_subset(fragment['mutations'], 0, args.read_len, revcomp=True,
                                                      seqlen=len(fragment['seq']))
                    if fragment['strand'] == '-':
                        read1_muts, read2_muts = read2_muts, read1_muts
                    log_mutations(args.mutations, read1_muts, read_id+'/1', chrom, start, stop)
                    log_mutations(args.mutations, read2_muts, read_id+'/2', chrom, start, stop)
                frag_seq = fragment['seq']
                read1_seq = frag_seq[:args.read_len]
                # Clamp at 0: when the molecule is shorter than read_len, read 2 covers all of it
                # (a negative start index would instead keep only its tail).
                read2_start = max(len(frag_seq) - args.read_len, 0)
                read2_seq = get_revcomp(frag_seq[read2_start:])
                if fragment['strand'] == '-':
                    read1_seq, read2_seq = read2_seq, read1_seq
                if args.out_format == 'fasta':
                    out1.write('>{}\n{}\n'.format(read_id, read1_seq))
                    out2.write('>{}\n{}\n'.format(read_id, read2_seq))
                elif args.out_format == 'fastq':
                    # FASTQ requires the quality string to be as long as the sequence.
                    out1.write('@{}\n{}\n+\n{}\n'.format(read_id, read1_seq, fastq_qual * len(read1_seq)))
                    out2.write('@{}\n{}\n+\n{}\n'.format(read_id, read2_seq, fastq_qual * len(read2_seq)))

    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
| 225 | |
def run_command(*command, **kwargs):
    """Run an external command and return its exit status.

    Example: run_command('echo', 'hello')
    Every argument is converted with str() before execution. Unless the keyword argument "silent"
    is true, the command line is echoed to stderr (prefixed with "$ ") first. The command's own
    stderr is discarded. Returns None if the command could not be started (OSError)."""
    argv = [str(arg) for arg in command]
    if not kwargs.get('silent'):
        sys.stderr.write('$ ' + ' '.join(argv) + '\n')
    with open(os.devnull, 'w') as devnull:
        try:
            return subprocess.call(argv, stderr=devnull)
        except OSError:
            return None
| 241 | |
| 242 | |
def extend_dist(raw_dist, exponent=1.25, min_prob=0.00001, max_len_mult=2):
    """Return a copy of raw_dist with an exponentially decaying tail appended.

    Starting from the last value, each new tail value is the previous one divided by "exponent".
    Appending stops once a value would be smaller than "min_prob" times the running total, or once
    the list reaches "max_len_mult" times the original length."""
    result = list(raw_dist)
    total = sum(raw_dist)
    limit = len(raw_dist) * max_len_mult
    tail = raw_dist[-1] / exponent
    while tail / total >= min_prob and len(result) < limit:
        result.append(tail)
        total += tail
        tail /= exponent
    return result
| 258 | |
| 259 | |
def compile_dist(raw_dist):
    """Turn the human-readable list of probabilities defined at the top into
    proportional (cumulative) probabilities, suitable for bisect.bisect().
    E.g. [10, 5, 5] -> [0.5, 0.75, 1.0]
    The last value is always exactly 1.0. An empty input gives an empty list."""
    cumulative = []
    running = 0
    for magnitude in raw_dist:
        running += magnitude
        cumulative.append(running)
    if not cumulative:
        return []
    # Normalize by the final running total rather than sum(): since Python 3.12, sum() of floats
    # uses compensated summation, which can differ from the running total in the last bit and
    # leave the final cumulative value just below 1.0.
    total = cumulative[-1]
    return [value / total for value in cumulative]
| 271 | |
| 272 | |
def parse_read_id(read_id):
    """Split a wgsim read name into (chrom, id_num, start, stop).

    The parts are the strings captured by WGSIM_ID_REGEX. If the name does not match, the whole
    name is returned as chrom and the other three values are None."""
    match = re.search(WGSIM_ID_REGEX, read_id)
    if not match:
        return read_id, None, None, None
    chrom, start, stop, id_num = match.groups()
    return chrom, id_num, start, stop
| 283 | |
| 284 | |
#TODO: Clean up "mutation" vs "error" terminology.
def generate_mutations(seq_len, error_rate, indel_rate, extension_rate):
    """Generate all the mutations that occur over the length of a sequence.

    Visits each original base position 0..seq_len-1, plus one position past the end where only an
    insertion is kept. At each position, a mutation occurs with probability error_rate, with its
    type and alt chosen by make_mutation().

    Yields dicts {'coord', 'type', 'alt'} in order of position. Each 'coord' is expressed in the
    coordinates of the sequence as already modified by every previously yielded mutation, so the
    mutations must be applied in the order they are yielded (e.g. with apply_mutation()).

    - An insertion goes before the current base, and that base is then passed over.
    - A deletion removes 'alt' bases starting at the current one; those bases are skipped, so a
      deleted base cannot mutate again.
    - Coordinates are never negative and never decrease."""
    base = 0    # Position in the original (unmutated) sequence.
    offset = 0  # Net length change caused by the mutations yielded so far.
    while base <= seq_len:
        step = 1
        if random.random() < error_rate:
            mtype, alt = make_mutation(indel_rate, extension_rate)
            # Allow mutation after the last base only if it's an insertion.
            if base < seq_len or mtype == 'ins':
                yield {'coord':base+offset, 'type':mtype, 'alt':alt}
                if mtype == 'ins':
                    offset += len(alt)
                elif mtype == 'del':
                    offset -= alt
                    step = alt
        base += step
| 301 | |
| 302 | |
def make_mutation(indel_rate, extension_rate):
    """Simulate a random mutation and return (type, alt).

    With probability indel_rate the mutation is an indel, split evenly between deletions
    ('del', alt = number of bases deleted) and insertions ('ins', alt = inserted bases). It is
    then lengthened by one base for as long as random draws stay below extension_rate. Otherwise
    it is an SNV ('snv', alt = the new base)."""
    draw = random.random()
    is_indel = draw < indel_rate
    if not is_indel:
        return 'snv', get_rand_base()
    # Reuse the same draw to choose the indel type: lower half of [0, indel_rate) -> deletion.
    if draw < indel_rate/2:
        mtype, alt = 'del', 1
    else:
        mtype, alt = 'ins', get_rand_base()
    while random.random() < extension_rate:
        if mtype == 'ins':
            alt += get_rand_base()
        else:
            alt += 1
    return mtype, alt
| 328 | |
| 329 | |
def get_rand_base(bases='ACGT'):
    """Return one character chosen uniformly at random from "bases"."""
    pick = random.choice(bases)
    return pick
| 332 | |
| 333 | |
def get_rand_seq(seq_len):
    """Return a random sequence of seq_len bases, each drawn with get_rand_base()."""
    return ''.join(get_rand_base() for _ in range(seq_len))
| 336 | |
| 337 | |
def get_revcomp(seq):
    """Return the reverse complement of seq, complementing via REVCOMP_TABLE."""
    complemented = seq.translate(REVCOMP_TABLE)
    return complemented[::-1]
| 340 | |
| 341 | |
def apply_mutation(mut, seq):
    """Return seq with the mutation dict "mut" applied.

    - 'snv' replaces the base at mut['coord'] with mut['alt'].
    - 'ins' inserts mut['alt'] *before* the base at mut['coord'].
    - Any other type is treated as a deletion of mut['alt'] bases starting at mut['coord'].

    Inserting/deleting before "coord" departs from the VCF convention, but it allows deleting the
    first and last base, and inserting before and after the sequence, without special cases.
    Examples: 'ACGTACGT' + ins 'GC' at 4 = 'ACGTGCACGT'; 'ACGTACGT' + del 2 at 4 = 'ACGTGT'."""
    coord = mut['coord']
    kind = mut['type']
    head = seq[:coord]
    if kind == 'snv':
        return head + mut['alt'] + seq[coord+1:]
    if kind == 'ins':
        return head + mut['alt'] + seq[coord:]
    return head + seq[coord+mut['alt']:]
| 358 | |
| 359 | |
def get_mutations_subset(mutations_old, start, length, revcomp=False, seqlen=None):
    """Return the mutations that fall within a region, keeping their input order.

    "start" is the 0-based start of the region and "length" its length. A mutation is kept when
    start <= coord < start+length. An insertion exactly at start+length (i.e. right after the last
    base) is also kept.
    When "revcomp" is true, each mutation is first converted with get_mutation_revcomp(), and
    "start" refers to coordinates on the reverse-complemented sequence. "seqlen" (the length of
    the sequence the mutations occurred in) is then required for that conversion."""
    stop = start + length
    if revcomp:
        candidates = (get_mutation_revcomp(mut, seqlen) for mut in mutations_old)
    else:
        candidates = mutations_old
    return [mut for mut in candidates
            if start <= mut['coord'] < stop or (mut['coord'] == stop and mut['type'] == 'ins')]
| 384 | |
| 385 | |
def get_mutation_revcomp(mut, seqlen):
    """Return a new mutation dict describing "mut" on the reverse complement strand.

    "seqlen" is the length of the sequence the mutation applies to; coordinates are mirrored so
    they count from the other end.
    - 'snv': coordinate mirrored around the base; alt reverse-complemented.
    - 'ins': coordinate mirrored at the gap before the base; alt reverse-complemented.
    - 'del': coordinate moved to the other end of the deleted span; length kept.
    A mutation of any other type yields only {'type': ...}."""
    mtype = mut['type']
    mut_rc = {'type':mtype}
    if mtype in ('snv', 'ins'):
        # An SNV sits on a base (hence the -1); an insertion sits in the gap before a base.
        shift = 1 if mtype == 'snv' else 0
        mut_rc['coord'] = seqlen - mut['coord'] - shift
        mut_rc['alt'] = get_revcomp(mut['alt'])
    elif mtype == 'del':
        mut_rc['coord'] = seqlen - mut['coord'] - mut['alt']
        mut_rc['alt'] = mut['alt']
    return mut_rc
| 402 | |
| 403 | |
def log_mutations(mutfile, mutations, read_id, chrom, start, stop):
    """Write one tab-separated line per mutation to the open file "mutfile".

    Columns: read_id, chrom, start, stop, then the mutation's coord, type and alt."""
    template = '{read_id}\t{chrom}\t{start}\t{stop}\t{coord}\t{type}\t{alt}\n'
    for mutation in mutations:
        line = template.format(read_id=read_id, chrom=chrom, start=start, stop=stop, **mutation)
        mutfile.write(line)
| 408 | |
| 409 | |
def add_pcr_errors(subtree, strand, read_len, error_rate, indel_rate, extension_rate):
    """Add simulated PCR errors to a node in a tree and all its descendants.

    Every node in the subtree gets 'strand' set to "strand", and 'errors' set to a list of
    mutations from generate_mutations(read_len, ...). These are the errors made in creating that
    molecule, so don't call this on the root node, which stands for the original, unaltered
    molecule.
    Nodes are visited in pre-order, a node's 'child2' subtree before its 'child1' subtree, so the
    random draws happen in a fixed order."""
    pending = [subtree]
    while pending:
        node = pending.pop()
        if not node:
            continue
        node['strand'] = strand
        node['errors'] = list(generate_mutations(read_len, error_rate, indel_rate, extension_rate))
        # Pushed last, so child2 (and everything under it) is handled before child1.
        pending.append(node.get('child1'))
        pending.append(node.get('child2'))
| 423 | |
| 424 | |
def apply_pcr_errors(subtree, seq):
    """Apply each node's 'errors' cumulatively down the tree, starting from sequence "seq".

    A node's errors are applied on top of all its ancestors' errors. Every leaf (a node without a
    'child1' key) gets the resulting sequence stored under 'seq'. Nodes without an 'errors' key
    pass the sequence through unchanged."""
    pending = [(subtree, seq)]
    while pending:
        node, inherited = pending.pop()
        if not node:
            continue
        current = inherited
        for error in node.get('errors', ()):
            current = apply_mutation(error, current)
        if 'child1' not in node:
            node['seq'] = current
        pending.append((node.get('child1'), current))
        pending.append((node.get('child2'), current))
| 434 | |
| 435 | |
def get_final_fragments(tree):
    """Collect the post-PCR sequences of all leaf fragments in the tree.

    A leaf is a node without a (truthy) 'child1'. Returns a dict mapping each leaf's 'id' to a new
    dict with exactly two keys: 'seq' (the final sequence) and 'strand' ('+' or '-')."""
    leaves = {}
    pending = [tree]
    while pending:
        current = pending.pop()
        for key in ('child1', 'child2'):
            child = current.get(key)
            if child:
                pending.append(child)
        if not current.get('child1'):
            leaves[current['id']] = {'seq':current['seq'], 'strand':current['strand']}
    return leaves
| 453 | |
| 454 | |
def add_mutation_lists(subtree, fragments, mut_list1):
    """Compile the list of PCR mutations each fragment has undergone.

    To call from the root, give [] as "mut_list1" and, as "fragments", a dict mapping every leaf
    id to a dict. Nothing is returned. Instead, each leaf's dict gets a 'mutations' key: the
    'errors' of all nodes along its lineage, in chronological order.
    "mut_list1" is extended in place and becomes the list of the leaf at the end of the 'child1'
    chain. Each 'child2' branch works on a deep copy taken at the branch point."""
    pending = [(subtree, mut_list1)]
    while pending:
        node, lineage = pending.pop()
        if not node:
            continue
        lineage.extend(node.get('errors', ()))
        if 'child1' not in node:
            fragments[node['id']]['mutations'] = lineage
        if 'child2' in node:
            pending.append((node.get('child2'), copy.deepcopy(lineage)))
        pending.append((node.get('child1'), lineage))
| 470 | |
| 471 | |
def check_tree_balance(subtree):
    """Return the largest difference in 'cycle' between two siblings anywhere in the tree.

    Only nodes with both a 'child1' and a 'child2' contribute. An empty or missing tree gives 0."""
    worst = 0
    pending = [subtree]
    while pending:
        node = pending.pop()
        if not node:
            continue
        first = node.get('child1')
        second = node.get('child2')
        if first and second:
            worst = max(worst, abs(first['cycle'] - second['cycle']))
        pending.extend((first, second))
    return worst
| 488 | |
| 489 | |
def get_good_pcr_tree(n_reads, n_cycles, max_tries, max_diff=1):
    """Return a single, balanced PCR tree from build_pcr_tree(), or fail if one cannot
    be found in max_tries attempts.

    Compensates for bugs in build_pcr_tree() that sometimes result in multiple trees,
    or in trees with siblings from different cycles. A tree is accepted when it is the only root
    returned and check_tree_balance() reports a sibling cycle difference of at most max_diff.
    Raises AssertionError after exactly max_tries unsuccessful attempts."""
    for _ in range(max_tries):
        trees = build_pcr_tree(n_reads, n_cycles)
        if len(trees) == 1 and check_tree_balance(trees[0]) <= max_diff:
            return trees[0]
    raise AssertionError('Could not generate a single, balanced tree! (tried {} times)'
                         .format(max_tries))
| 503 | |
| 504 | |
def build_pcr_tree(n_reads, n_cycles):
    """Create a simulated descent lineage of how all the final PCR fragments are related.
    Each node represents a fragment molecule at one stage of PCR. Each node is a dict containing the
    fragment's children (other nodes) ('child1' and 'child2'), the PCR cycle number ('cycle'), and,
    at the leaves, a unique id number for each final fragment ('id', 0 to n_reads-1).
    The lineage is built backward: it starts with n_reads leaves at cycle n_cycles-1. For each
    earlier cycle, down to 0, each lineage either gets a single-child parent or is joined with
    another lineage under a shared two-child parent.
    Returns a list of root nodes. Usually there will only be one, but about 1-3% of the time it fails
    to unify the subtrees and results in a broken tree (get_good_pcr_tree() retries until it gets a
    single, balanced tree).
    Draws from both the random module and numpy.random.
    """
    #TODO: Make it always return a single tree.
    # Begin a branch for each of the fragments. These are the leaf nodes. We'll work backward from
    # these, simulating the points at which they share ancestors, eventually coalescing into the
    # single (root) ancestor.
    branches = []
    for frag_id in range(n_reads):
        branches.append({'cycle':n_cycles-1, 'id':frag_id})
    # Build up all the branches in parallel. Start from the second-to-last PCR cycle and go down to
    # cycle 0.
    for cycle in reversed(range(n_cycles-1)):
        # Probability of 2 fragments sharing an ancestor at cycle c is 1/2^c.
        # (This relies on true division; the file has "from __future__ import division".)
        prob = 1/2**cycle
        frag_i = 0
        # "branches" shrinks as lineages are merged, so walk it with an explicit index.
        while frag_i < len(branches):
            current_root = branches[frag_i]
            # Does the current fragment share this ancestor with any of the other fragments?
            # numpy.random.binomial() is a fast way to simulate going through every other fragment and
            # asking if random.random() < prob. Only whether the result is 0 matters below.
            shared = numpy.random.binomial(len(branches)-1, prob)
            if shared == 0:
                # No branch point here. Just add another level to the lineage.
                branches[frag_i] = {'cycle':cycle, 'child1':current_root}
            else:
                # Pick a random other fragment to share this ancestor with.
                # Make a list of candidates to pick from.
                candidates = []
                for candidate_i, candidate in enumerate(branches):
                    # Don't include ourselves.
                    if candidate is current_root:
                        continue
                    # Skip lineages that already received a two-child parent at this cycle
                    # (they were already merged during this cycle).
                    if candidate['cycle'] == cycle and candidate.get('child2'):
                        continue
                    candidates.append(candidate_i)
                # NOTE(review): when no candidates remain, this lineage is not given a parent at
                # "cycle" at all, leaving it one level shallower than the others. That may be one
                # source of the depth mismatches mentioned in the TODO below.
                if candidates:
                    relative_i = random.choice(candidates)
                    relative = branches[relative_i]
                    # Have we already passed this fragment on this cycle?
                    if relative['cycle'] == cycle:
                        # If we've already passed it, we're looking at the fragment's parent. We want the child.
                        relative = relative['child1']
                    # Join the lineages of our current fragment and the relative to a new parent.
                    #TODO: Sometimes, we end up matching up subtrees of different depths. But the discrepancy
                    #      is rarely greater than 1. Figure out why.
                    # Debugging check, kept for investigating the TODO above:
                    # assert abs(current_root['cycle'] - relative['cycle']) < 3, ('cycle: {}, current_root: {},'
                    #   ' relative: {}, frag_i: {}, relative_i: {}, branches: {}, candidates: {}, shared: {}'
                    #   .format(cycle, current_root['cycle'], relative['cycle'], frag_i, relative_i,
                    #           len(branches), len(candidates), shared))
                    branches[frag_i] = {'cycle':cycle, 'child1':current_root, 'child2':relative}
                    # Remove the relative from the list of lineages.
                    del(branches[relative_i])
                    # Removing an earlier entry shifts the current lineage down by one position.
                    if relative_i < frag_i:
                        frag_i -= 1
            frag_i += 1
    return branches
| 567 | |
| 568 | |
def get_depth(tree):
    """Return the number of nodes along the 'child1' chain starting at "tree" (0 if it is empty)."""
    depth = 0
    while tree:
        depth, tree = depth + 1, tree.get('child1')
    return depth
| 576 | |
| 577 | |
def convert_tree(tree_orig):
    """Group the nodes of a PCR tree into levels (generations) for vertical display.

    Works on a deep copy, so tree_orig is not modified. Returns a list of levels:
    - The first level is [root].
    - Each later level holds the children of the previous level's nodes, sorted by their
      'branch' number.
    - The final level is always an empty list (the children of the leaves).

    Each copied node gets a 'parent' reference and a 'branch' number. A child1 inherits its
    parent's branch; a child2 gets its parent's branch plus one. The root keeps any existing
    'branch' (e.g. from label_branches()) and otherwise gets branch 1."""
    # Let's operate on a copy only.
    tree = copy.deepcopy(tree_orig)
    # Turn the tree vertical.
    tree['line'] = 1
    tree['children'] = 0
    # The root needs a 'branch' for its children to inherit.
    tree.setdefault('branch', 1)
    levels = [[tree]]
    while True:
        this_level = []
        for node in levels[-1]:
            for child_name in ('child1', 'child2'):
                child = node.get(child_name)
                if child:
                    child['parent'] = node
                    child['branch'] = node['branch'] + (1 if child_name == 'child2' else 0)
                    this_level.append(child)
        this_level.sort(key=lambda member: member['branch'])
        levels.append(this_level)
        if not this_level:
            break
    return levels
| 603 | |
| 604 | |
def print_levels(levels):
    """Print one line per level (as returned by convert_tree()) to stdout.

    Each node is drawn as a two-character cell:
    - a backslash and a space if it is the 'child2' of some node in the previous level;
    - '| ' otherwise."""
    previous = []
    for level in levels:
        for node in level:
            is_child2 = any(parent.get('child2') is node for parent in previous)
            # '\\ ' is a backslash followed by a space; the escape is spelled out explicitly
            # because '\ ' is an invalid escape sequence (a SyntaxWarning since Python 3.12).
            sys.stdout.write('\\ ' if is_child2 else '| ')
        previous = level
        print()
| 619 | |
| 620 | |
def label_branches(tree):
    """Label each vertical branch (line of 'child1's) with an id number, in place.

    The root's branch is 1. Each child1 shares its parent's 'branch'. Each child2 starts a new
    branch, numbered consecutively in breadth-first order."""
    next_label = 1
    tree['branch'] = next_label
    queue = [tree]
    head = 0
    while head < len(queue):
        node = queue[head]
        head += 1
        first = node.get('child1')
        if first:
            first['branch'] = node['branch']
            queue.append(first)
        second = node.get('child2')
        if second:
            next_label += 1
            second['branch'] = next_label
            queue.append(second)
| 637 | |
| 638 | |
def print_tree(tree_orig):
    """Print an ASCII drawing of a PCR tree to stdout, one text line per branch.

    Works on a deep copy labeled with label_branches(), so tree_orig is not modified. Each branch
    (a chain of 'child1' links) is drawn left to right:
    - '*-' marks the root;
    - '=-' marks a node continuing a line;
    - a backslash plus '-' marks the first node of a 'child2' branch, indented to its column;
    - ' N' ends each line, where N is the branch number.
    Vertical '| ' cells then connect each child2 branch up to its parent's line. The output ends
    with one empty line."""
    # We "write" strings to an output buffer instead of directly printing, so we can post-process the
    # output. The buffer is a matrix of cells, each holding a string representing one element.
    lines = [[]]
    # Let's operate on a copy only.
    tree = copy.deepcopy(tree_orig)
    # Add some bookkeeping data. 'level' is the column a node is drawn in.
    label_branches(tree)
    tree['level'] = 0
    branches = [tree]
    while branches:
        line = lines[-1]
        branch = branches.pop()
        # Indent the branch to its starting column.
        line.extend([' '] * branch['level'])
        node = branch
        while node:
            if not lines[0]:
                # Nothing written yet, so this is the root node.
                line.append('*-')
            elif line[-1] == ' ':
                # First node after the indentation: the start of a child2 branch.
                line.append('\\-')
            elif line[-1].endswith('-'):
                line.append('=-')
            child2 = node.get('child2')
            if child2:
                child2['level'] = node['level'] + 1
                branches.append(child2)
            parent = node
            node = node.get('child1')
            if node:
                node['level'] = parent['level'] + 1
            else:
                line.append(' {}'.format(parent['branch']))
        lines.append([])
    # Post-process output: Add lines connecting branches to parents.
    x = 0
    done = False
    while not done:
        # Keep going until no line has a cell in column x.
        done = True
        # Draw vertical lines upward from branch points.
        drawing = False
        for line in reversed(lines):
            if x >= len(line):
                continue
            done = False
            cell = line[x]
            if cell == '\\-':
                drawing = True
            elif cell == ' ' and drawing:
                line[x] = '| '
            elif cell == '=-' and drawing:
                drawing = False
        x += 1
    # Print the final output.
    for line in lines:
        print(''.join(line))
| 699 | |
| 700 | |
def fail(message):
    """Write "message" plus a newline to stderr, then exit the process with status 1."""
    sys.stderr.write(message + '\n')
    raise SystemExit(1)
| 704 | |
if __name__ == '__main__':
    # main() returns None on normal completion, which sys.exit() treats as status 0.
    exit_code = main(sys.argv)
    sys.exit(exit_code)
