comparison assess_alignment.py @ 0:8bc106442b1a draft

planemo upload for repository https://github.com/pvanheus/polio_report commit a99e10fec2fac5aae70974c977eb3b362a1a8429-dirty
author sanbi-uwc
date Tue, 19 Jul 2022 11:54:46 +0000
parents
children 31ca16290d4f
comparison
equal deleted inserted replaced
-1:000000000000 0:8bc106442b1a
1 #!/usr/bin/env python
2
3 import argparse
4 import json
5 import sys
6 from dataclasses import dataclass
7 from typing import TextIO
8
9
@dataclass
class Sample:
    """Plain record holding the per-sample summary fields.

    The generated constructor takes the fields positionally in the order
    they are declared below: name, reference, mismatches, quality.
    """

    # Field meanings are not established by any visible caller; the short
    # notes below are inferred from the field names only.
    name: str  # presumably the sample identifier
    reference: str  # presumably the name of the reference it was compared to
    mismatches: int  # presumably the mismatch count against that reference
    # NOTE(review): declared as str, while analyse_trace_quality() in this
    # file yields a float -- confirm the intended type.
    quality: str
16
17
def load_json(json_file: TextIO) -> dict:
    """Parse *json_file* and verify the top-level keys the analyses rely on.

    Args:
        json_file: Open text stream containing a JSON document.

    Returns:
        The parsed JSON document.

    Raises:
        ValueError: If the ``msa`` key or the ``gappedTraces`` key is absent
            (``msa`` is checked first).
    """
    parsed = json.load(json_file)
    # Checked in this order so that a document missing both keys reports
    # the MSA problem.
    required_keys = (
        ("msa", "MSA missing from JSON, cannot proceed"),
        ("gappedTraces", "gappedTraces missing from JSON, cannot proceed"),
    )
    for key, complaint in required_keys:
        if key not in parsed:
            raise ValueError(complaint)
    return parsed
25
26
def analyse_mismatches(
    json_file: TextIO,
    offset: int,
    length: int,
    vp1only: bool = True,
    sec_is_conflict: bool = False,
) -> list:
    """Classify each reference position as match, mismatch or conflict.

    Every position of the reference alignment starts in state ``"n"`` (not
    seen). Each non-reference alignment that covers the position then moves
    it to ``"G"`` (every trace seen so far agrees with the reference),
    ``"M"`` (every trace seen so far differs from the reference) or ``"C"``
    (traces disagree with each other, or a secondary basecall was seen).

    Args:
        json_file: Open JSON document accepted by ``load_json``.
        offset: Subtracted from the zero-based reference index to obtain the
            position within the VP1 region.
        length: Length of the VP1 region; used for the upper bound of the
            region filter.
        vp1only: When true, positions outside the VP1 region are ignored.
            When false, every covered position is analysed.
        sec_is_conflict: When true, a position whose basecall string
            contains ``"|"`` is marked as a conflict.

    Returns:
        ``[conflicts, matches, mismatches, mismatch_list]`` where the first
        three are counts of positions in state C, G and M, and
        ``mismatch_list`` holds one
        ``[pos_in_genome, pos_in_vp1, reference_base, sequenced_base]``
        entry per mismatching position.

    Raises:
        ValueError: If ``load_json`` rejects the document or the MSA list
            contains no non-reference alignment.
        SystemExit: If the MSA list contains no reference or more than one.
    """
    data = load_json(json_file)

    msas = []
    reference = None
    for al in data["msa"]:
        if not al["reference"]:
            msas.append(al)
        elif reference is None:
            reference = al
        else:
            sys.exit("more than one reference found in JSON MSA list, cannot proceed")
    if reference is None:
        sys.exit("no reference found in JSON MSA list, cannot proceed")
    if not msas:
        raise ValueError("no sample alignments found in JSON MSA list, cannot proceed")

    # (leading gap count, aligned length) per sample alignment, computed once
    # instead of once per reference position.
    spans = [(int(al["leadingGaps"]), len(al["align"])) for al in msas]

    ref_seq = reference["align"]
    base_state = ["n"] * len(ref_seq)
    mismatch_bases = {}
    for i, base in enumerate(ref_seq):
        vp1pos = i - offset
        # Parenthesised on purpose: without the parentheses the upper-bound
        # test also applied when vp1only was False.
        # NOTE(review): vp1pos == length passes this filter -- confirm whether
        # the upper bound should be exclusive.
        if vp1only and (vp1pos < 0 or vp1pos > length):
            # skip positions outside of vp1 gene region
            continue
        for k, (al, (leading_gaps, align_len)) in enumerate(zip(msas, spans)):
            # NOTE(review): the strict "<" leaves out the first aligned base
            # (i == leading_gaps) -- confirm this is intended.
            if not (leading_gaps < i < leading_gaps + align_len):
                continue
            al_base = al["align"][i - leading_gaps]
            if sec_is_conflict:
                # assumes gappedTraces[k] belongs to the k-th non-reference
                # alignment -- TODO confirm against the JSON producer
                gapped_trace = data["gappedTraces"][k]
                pos = i - int(gapped_trace["leadingGaps"])
                basecall_str = gapped_trace["basecalls"][
                    str(gapped_trace["basecallPos"][pos])
                ]
                if "|" in basecall_str:
                    # a secondary basecall marks this position as conflicted
                    base_state[i] = "C"
            state = base_state[i]
            if al_base != base:
                if state == "G":
                    # an earlier trace matched the reference, this one does not
                    base_state[i] = "C"
                elif state == "C":
                    # already conflicting - a mismatch doesn't change that
                    pass
                elif state in ("n", "M"):
                    # first observation, or another mismatch: remember the
                    # most recently seen sequenced base
                    base_state[i] = "M"
                    mismatch_bases[i] = al_base
                else:
                    sys.exit("unexpected base state: " + state)
            else:
                if state in ("G", "n"):
                    # first observation, or another match
                    base_state[i] = "G"
                elif state == "M":
                    # an earlier trace mismatched, this one matches: conflict
                    base_state[i] = "C"
                    mismatch_bases.pop(i, None)
                elif state == "C":
                    # we have seen a conflict here before
                    pass
                else:
                    sys.exit("unexpected base_state: " + state)

    conflicts = base_state.count("C")
    matches = base_state.count("G")
    mismatches = base_state.count("M")
    # i is in zero-based genome coordinates
    mismatch_list = [
        [i, i - offset, ref_seq[i], mismatch_bases[i]]
        for i, state in enumerate(base_state)
        if state == "M"
    ]
    return [conflicts, matches, mismatches, mismatch_list]
112
113
def analyse_trace_quality(json_file: TextIO) -> float:
    """Return the mean peak-to-average signal ratio over all gapped traces.

    For each trace, the ``peakA``/``peakC``/``peakG``/``peakT`` arrays are
    sliced from the smallest to the largest value found in ``basecallPos``
    (inclusive). Each channel contributes ``max / mean`` of its slice; the
    four channel ratios are averaged per trace and the per-trace averages
    are averaged again over all traces.

    Args:
        json_file: Open JSON document accepted by ``load_json``.

    Returns:
        The average ratio across all traces.

    Raises:
        ValueError: If ``load_json`` rejects the document, ``gappedTraces``
            is empty, a trace has no ``basecallPos`` entries, or a channel
            has no values in the basecalled window.
        ZeroDivisionError: If the values of a channel average to zero.
    """
    data = load_json(json_file)

    traces = data["gappedTraces"]
    if not traces:
        # Previously surfaced as a bare ZeroDivisionError at the final division.
        raise ValueError("no gappedTraces in JSON, cannot compute trace quality")

    overall_avg = 0
    for trace in traces:
        start = min(trace["basecallPos"])
        end = max(trace["basecallPos"])
        avg_ratio = 0
        for base in ("A", "C", "G", "T"):
            calls = trace["peak" + base][start : end + 1]
            if not calls:
                raise ValueError(
                    "peak" + base + " has no values in the basecalled region"
                )
            avg_call = sum(calls) / len(calls)
            # Plain running sums keep the floating-point result identical to
            # the original accumulation order.
            avg_ratio += max(calls) / avg_call
        overall_avg += avg_ratio / 4
    return overall_avg / len(traces)
139
140
def comma_split(args: str) -> list[str]:
    """Break a comma separated string into its pieces.

    Used as an argparse ``type=`` converter. Pieces are returned exactly as
    they appear (no whitespace stripping), and an empty string yields a
    list containing one empty string.
    """
    separator = ","
    pieces = args.split(separator)
    return pieces
143
144
if __name__ == "__main__":
    # Command-line entry point: analyse every dataset, keep the one with the
    # fewest mismatches and write a JSON summary of it to the output file.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_filename", required=True, help="Path to output file"
    )
    parser.add_argument("--sample_name", help="Name of sample being analysed")
    parser.add_argument(
        "--dataset_names",
        type=comma_split,
        required=True,
        help="Comma separated names for datasets",
    )
    parser.add_argument(
        "--datasets",
        nargs="+",
        required=True,
        help="JSON files to analyse, in the same order as --dataset_names",
    )
    args = parser.parse_args()

    # Value subtracted from a genome index to get the VP1 coordinate, per
    # reference name.
    offsets = {
        "poliovirus1sabin": 2480,
        "poliovirus2sabin": 2482,
        "poliovirus3sabin": 2477,
    }

    # Length of the VP1 region, per reference name.
    lengths = {
        "poliovirus1sabin": 906,
        "poliovirus2sabin": 903,
        "poliovirus3sabin": 900,
    }

    if len(args.dataset_names) < len(args.datasets):
        parser.error(
            "--dataset_names must provide a name for every file in --datasets"
        )

    min_mismatches = None
    for file_index, json_filename in enumerate(args.datasets):
        dataset_name = args.dataset_names[file_index].replace(
            ".json", ""
        )  # take the name but remove any json suffix
        if dataset_name not in offsets or dataset_name not in lengths:
            parser.error(
                "unknown dataset name '"
                + dataset_name
                + "', expected one of: "
                + ", ".join(sorted(offsets))
            )
        offset = offsets[dataset_name]
        length = lengths[dataset_name]
        # One handle serves both analyses and is closed even if one of them
        # raises.
        with open(json_filename, encoding="utf-8") as json_file:
            (_conflicts, _matches, mismatches, mismatch_list) = analyse_mismatches(
                json_file, offset, length
            )
            json_file.seek(0)
            quality = analyse_trace_quality(json_file)
        # Strict "<" keeps the earliest dataset when mismatch counts tie.
        if min_mismatches is None or mismatches < min_mismatches:
            min_mismatches = mismatches
            best_match_mismatch_list = mismatch_list
            best_match_quality = quality
            best_match_reference = dataset_name

    info = {
        "sample_name": args.sample_name,
        "best_reference": best_match_reference,
        "mismatches": min_mismatches,
        "mismatch_list": best_match_mismatch_list,
        "quality": best_match_quality,
    }
    with open(args.output_filename, "w", encoding="utf-8") as output_file:
        json.dump(info, output_file)