sanbi-uwc / assess_poliovirus_alignment: assess_alignment.py @ 0:8bc106442b1a (draft)

planemo upload for repository https://github.com/pvanheus/polio_report commit a99e10fec2fac5aae70974c977eb3b362a1a8429-dirty

author:    sanbi-uwc
date:      Tue, 19 Jul 2022 11:54:46 +0000
parents:
children:  31ca16290d4f
#!/usr/bin/env python

import argparse
import json
import sys
from dataclasses import dataclass
from typing import TextIO


@dataclass
class Sample:
    name: str
    reference: str
    mismatches: int
    quality: str


def load_json(json_file: TextIO) -> dict:
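    """Load an alignment JSON file and confirm that the top-level sections this
    script depends on ("msa" and "gappedTraces") are present."""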
    data = json.load(json_file)
    if "msa" not in data:
        raise ValueError("MSA missing from JSON, cannot proceed")
    if "gappedTraces" not in data:
        raise ValueError("gappedTraces missing from JSON, cannot proceed")
    return data


def analyse_mismatches(
    json_file: TextIO,
    offset: int,
    length: int,
    vp1only: bool = True,
    sec_is_conflict: bool = False,
) -> list:
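    """Compare each non-reference trace in the MSA against the reference,
    position by position.

    Each reference position is tracked in base_state with a single-letter code:
    "n" = not yet seen, "G" = traces seen so far match the reference,
    "M" = mismatch against the reference, "C" = conflict (traces disagree with
    each other, or a secondary basecall was found when sec_is_conflict is set).

    Returns [conflicts, matches, mismatches, mismatch_list], where each entry of
    mismatch_list is [pos_in_genome, pos_in_vp1, reference_base, sequenced_base].
    """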
    data = load_json(json_file)
    msas = [al for al in data["msa"] if not al["reference"]]
    reference = None
    for al in data["msa"]:
        if al["reference"]:
            if reference is None:
                reference = al
            else:
                sys.exit(
                    "more than one reference found in JSON MSA list, cannot proceed"
                )
    min_start = min([int(al["leadingGaps"]) for al in msas])
    max_end = max([int(al["leadingGaps"]) + len(al["align"]) for al in msas])
    base_state = ["n"] * len(reference["align"])
    mismatch_bases = {}
    for i, base in enumerate(reference["align"]):
        for k, al in enumerate(msas):
            leading_gaps = int(al["leadingGaps"])
            align_len = len(al["align"])
            if leading_gaps < i and (leading_gaps + align_len) > i:
                vp1pos = i - offset
                if vp1only and (vp1pos < 0 or vp1pos > length):
                    # skip positions outside of vp1 gene region
                    continue
                al_base = al["align"][i - leading_gaps]
                has_secondary_basecall = False
                if sec_is_conflict:
                    gappedTrace = data["gappedTraces"][k]
                    pos = i - int(gappedTrace["leadingGaps"])
                    # print(len(gappedTrace['basecallPos']), pos, k, len(gappedTrace['basecalls']), gappedTrace['basecallPos'][pos])
                    basecall_str = gappedTrace["basecalls"][
                        str(gappedTrace["basecallPos"][pos])
                    ]
                    if "|" in basecall_str:
                        has_secondary_basecall = True
                        # set this position to conflicted
                        base_state[i] = "C"
                if al_base != base:
                    # let's deal with all the cases where the base state doesn't match the reference
                    if base_state[i] == "G":
                        # the base state was G (a trace matches reference) and now we see a mismatch
                        base_state[i] = "C"
                    elif base_state[i] == "C":
                        # already marked as conflicting - a mismatch doesn't change that
                        pass
                    elif base_state[i] == "n" or base_state[i] == "M":
                        # we never saw this before or it's already marked as a mismatch
                        base_state[i] = "M"
                        mismatch_bases[i] = al_base
                    else:
                        sys.exit("unexpected base state: " + base_state[i])
                else:
                    if base_state[i] == "G" or base_state[i] == "n":
                        # we saw this before and got a match or
                        # we never saw this before
                        base_state[i] = "G"
                    elif base_state[i] == "M":
                        # we saw this before but it was a mismatch - mark this as a conflict
                        base_state[i] = "C"
                        if i in mismatch_bases:
                            del mismatch_bases[i]
                    elif base_state[i] == "C":
                        # we have seen a conflict here before
                        pass
                    else:
                        sys.exit("unexpected base_state: " + base_state[i])
    conflicts = base_state.count("C")
    matches = base_state.count("G")
    mismatches = base_state.count("M")
    mismatch_list = []
    for i, state in enumerate(base_state):
        # i is in zero-based genome coordinates
        if state == "M":
            # for mismatch store [pos_in_genome, pos_in_vp1, reference_base, sequenced_base]
            mismatch_list.append(
                [i, i - offset, reference["align"][i], mismatch_bases[i]]
            )
    return [conflicts, matches, mismatches, mismatch_list]


def analyse_trace_quality(json_file: TextIO) -> float:
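    """Compute an overall quality score for the chromatogram traces.

    For each trace, over the region covered by basecalls, the ratio of maximum
    to mean peak height is computed for each channel (A, C, G, T) and averaged;
    the returned score is the mean of these per-trace averages. A higher ratio
    suggests peaks that stand out more strongly from the average channel signal.
    """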
    data = load_json(json_file)

    traces = data["gappedTraces"]
    overall_avg = 0
    for trace in traces:
        start = min(trace["basecallPos"])
        end = max(trace["basecallPos"])
        call_quality = {}
        avg_ratio = 0
        for base in ("A", "C", "G", "T"):
            calls = trace["peak" + base][start : end + 1]
            min_call = min(calls)
            max_call = max(calls)
            avg_call = sum(calls) / len(calls)
            ratio = max_call / avg_call
            call_quality["avg" + base] = avg_call
            call_quality["min" + base] = min_call
            call_quality["max" + base] = max_call
            call_quality["ratio" + base] = ratio
            avg_ratio += ratio
        avg_ratio = avg_ratio / 4
        overall_avg += avg_ratio
    overall_avg = overall_avg / len(traces)
    return overall_avg


def comma_split(args: str) -> list[str]:
    return args.split(",")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_filename", help="Path to output file")
    parser.add_argument("--sample_name", help="Name of sample being analysed")
    parser.add_argument(
        "--dataset_names", type=comma_split, help="Comma separated names for datasets"
    )
    parser.add_argument("--datasets", nargs="+")
    args = parser.parse_args()

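    # Per-reference offsets (in genome coordinates) at which the VP1 region
    # starts, and VP1 region lengths, for the three Sabin references;
    # analyse_mismatches() converts genome positions to VP1 coordinates as
    # vp1pos = i - offset.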
    offsets = {
        "poliovirus1sabin": 2480,
        "poliovirus2sabin": 2482,
        "poliovirus3sabin": 2477,
    }

    lengths = {
        "poliovirus1sabin": 906,
        "poliovirus2sabin": 903,
        "poliovirus3sabin": 900,
    }

    min_mismatches = None
    for file_index, json_filename in enumerate(args.datasets):
        dataset_name = args.dataset_names[file_index].replace(
            ".json", ""
        )  # take the name but remove any json suffix
        offset = offsets[dataset_name]
        length = lengths[dataset_name]
        (conflicts, matches, mismatches, mismatch_list) = analyse_mismatches(
            open(json_filename), offset, length
        )
        # analyse_mismatches(json_filename, True)
        quality = analyse_trace_quality(open(json_filename))
        if min_mismatches is None or mismatches < min_mismatches:
            min_mismatches = mismatches
            best_match_mismatch_list = mismatch_list
            best_match_quality = quality
            best_match_reference = dataset_name

    info = {
        "sample_name": args.sample_name,
        "best_reference": best_match_reference,
        "mismatches": min_mismatches,
        "mismatch_list": best_match_mismatch_list,
        "quality": best_match_quality,
    }
    json.dump(info, open(args.output_filename, "w"))
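
# Example invocation (illustrative only: the sample name and file names below are
# hypothetical and not taken from this repository; each --datasets file is an
# alignment JSON containing "msa" and "gappedTraces" sections, and each entry in
# --dataset_names must match a key in the offsets/lengths tables above, with an
# optional .json suffix):
#
#   python assess_alignment.py \
#       --output_filename sample1_report.json \
#       --sample_name sample1 \
#       --dataset_names poliovirus1sabin.json,poliovirus2sabin.json,poliovirus3sabin.json \
#       --datasets aln_sabin1.json aln_sabin2.json aln_sabin3.json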