Mercurial > repos > sanbi-uwc > assess_poliovirus_alignment
view assess_alignment.py @ 5:0e556a3f85d6 draft
planemo upload for repository https://github.com/pvanheus/polio_report commit a99e10fec2fac5aae70974c977eb3b362a1a8429-dirty
author | sanbi-uwc |
---|---|
date | Wed, 14 Sep 2022 09:21:44 +0000 |
parents | 1897677e107c |
children | 9fd6dde72d2e |
line wrap: on
line source
#!/usr/bin/env python
"""Assess trace alignments of a poliovirus sample against Sabin reference sequences.

For each JSON alignment file (containing "msa" and "gappedTraces" keys) given on
the command line, classify the reference columns inside the VP1 region as matched,
mismatched or conflicted by the aligned traces. Then report the reference with the
fewest mismatches, together with its mismatch list and an average trace-quality
score.
"""
import argparse
import json
import sys
from dataclasses import dataclass
from typing import TextIO


@dataclass
class Sample:
    # Record type for a per-sample result; not referenced elsewhere in this file.
    name: str
    reference: str
    mismatches: int
    quality: str


def load_json(json_file: TextIO) -> dict:
    """Parse an alignment JSON document and check it has the keys we need.

    Raises:
        ValueError: if the "msa" or "gappedTraces" key is missing.
    """
    data = json.load(json_file)
    if "msa" not in data:
        raise ValueError("MSA missing from JSON, cannot proceed")
    if "gappedTraces" not in data:
        raise ValueError("gappedTraces missing from JSON, cannot proceed")
    return data


def _has_secondary_basecall(gapped_trace: dict, i: int) -> bool:
    """Return True if the trace's basecall at alignment column i contains "|"."""
    pos = i - int(gapped_trace["leadingGaps"])
    basecall_pos = gapped_trace["basecallPos"]
    # Guard the index: a negative pos would otherwise silently read from the end.
    if not 0 <= pos < len(basecall_pos):
        return False
    basecall_str = gapped_trace["basecalls"][str(basecall_pos[pos])]
    return "|" in basecall_str


def analyse_mismatches(
    json_file: TextIO,
    offset: int,
    length: int,
    vp1only: bool = True,
    sec_is_conflict: bool = False,
) -> list:
    """Classify reference alignment columns by how the aligned traces agree with them.

    Each reference column ends up in one state:
      "n" - not assessed (no trace covers it, or it lies outside the VP1 window),
      "G" - every covering trace matches the reference base,
      "M" - every covering trace differs from the reference base,
      "C" - some covering traces match and some do not, or (with
            sec_is_conflict) a covering trace had a "|" basecall there.

    Args:
        json_file: open JSON alignment document (validated by load_json).
        offset: 0-based column at which the VP1 region starts.
        length: number of columns in the VP1 region.
        vp1only: if True, only columns in [offset, offset + length) are assessed.
        sec_is_conflict: if True, a "|" in a trace's basecall at a column marks
            that column as conflicted.

    Returns:
        [conflicts, matches, mismatches, mismatch_list]. mismatch_list has one
        [genome_pos, vp1_pos, reference_base, sequenced_base] entry (both
        positions 1-based) per "M" column.

    Raises:
        ValueError: if required JSON keys are missing or there is no reference.
        SystemExit: if the MSA contains more than one reference entry.
    """
    data = load_json(json_file)
    msas = [al for al in data["msa"] if not al["reference"]]
    reference = None
    for al in data["msa"]:
        if al["reference"]:
            if reference is None:
                reference = al
            else:
                sys.exit(
                    "more than one reference found in JSON MSA list, cannot proceed"
                )
    if reference is None:
        raise ValueError("no reference found in JSON MSA list, cannot proceed")

    base_state = ["n"] * len(reference["align"])
    mismatch_bases = {}
    for i, base in enumerate(reference["align"]):
        vp1pos = i - offset
        if vp1only and not 0 <= vp1pos < length:
            # skip positions outside of the vp1 gene region
            continue
        for k, al in enumerate(msas):
            leading_gaps = int(al["leadingGaps"])
            # align[0] sits at column leading_gaps, so that column is covered too
            if not leading_gaps <= i < leading_gaps + len(al["align"]):
                continue
            al_base = al["align"][i - leading_gaps]
            # NOTE(review): assumes gappedTraces is in the same order as the
            # non-reference msa entries - confirm against the JSON producer.
            if sec_is_conflict and _has_secondary_basecall(data["gappedTraces"][k], i):
                base_state[i] = "C"
            if al_base != base:
                if base_state[i] == "G":
                    # an earlier trace matched the reference, this one does not
                    base_state[i] = "C"
                elif base_state[i] == "C":
                    # already conflicting - a mismatch doesn't change that
                    pass
                elif base_state[i] in ("n", "M"):
                    # first sighting, or still consistently a mismatch.
                    # NOTE(review): traces disagreeing on the *alternative* base
                    # stay "M" and the last one seen is reported.
                    base_state[i] = "M"
                    mismatch_bases[i] = al_base
                else:
                    sys.exit("unexpected base state: " + base_state[i])
            else:
                if base_state[i] in ("G", "n"):
                    base_state[i] = "G"
                elif base_state[i] == "M":
                    # an earlier trace mismatched, this one matches
                    base_state[i] = "C"
                    mismatch_bases.pop(i, None)
                elif base_state[i] == "C":
                    pass
                else:
                    sys.exit("unexpected base_state: " + base_state[i])

    conflicts = base_state.count("C")
    matches = base_state.count("G")
    mismatches = base_state.count("M")
    # i is a 0-based reference alignment column; it equals the genome position
    # only if the reference row has no gap columns - TODO confirm.
    mismatch_list = [
        [i + 1, i - offset + 1, reference["align"][i], mismatch_bases[i]]
        for i, state in enumerate(base_state)
        if state == "M"
    ]
    return [conflicts, matches, mismatches, mismatch_list]


def analyse_trace_quality(json_file: TextIO) -> float:
    """Return the mean max/mean peak ratio across all traces in the document.

    For each trace, and for each of the peakA/peakC/peakG/peakT channels, take
    the ratio of the maximum to the mean value over the span from the first to
    the last basecall position. Average the four ratios per trace, then average
    the per-trace values over all traces.

    Raises:
        ValueError: if required keys are missing or gappedTraces is empty.
        ZeroDivisionError: if a channel's mean over the span is zero.
    """
    data = load_json(json_file)
    traces = data["gappedTraces"]
    if not traces:
        raise ValueError("gappedTraces is empty, cannot assess trace quality")
    overall_avg = 0
    for trace in traces:
        start = min(trace["basecallPos"])
        end = max(trace["basecallPos"])
        avg_ratio = 0
        for base in ("A", "C", "G", "T"):
            calls = trace["peak" + base][start : end + 1]
            avg_call = sum(calls) / len(calls)
            avg_ratio += max(calls) / avg_call
        overall_avg += avg_ratio / 4
    return overall_avg / len(traces)


def comma_split(args: str) -> list[str]:
    """Split a comma-separated argparse value into a list of strings."""
    return args.split(",")


def main(argv=None) -> None:
    """Command-line entry point; argv defaults to sys.argv[1:]."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_filename", help="Path to output file")
    parser.add_argument(
        "--consensus_output_filename", help="Path to output file for best consensus"
    )
    parser.add_argument("--sample_name", help="Name of sample being analysed")
    parser.add_argument(
        "--dataset_names", type=comma_split, help="Comma separated names for datasets"
    )
    parser.add_argument("--datasets", nargs="+")
    parser.add_argument("--consensi", nargs="+")
    args = parser.parse_args(argv)

    offsets = {
        # these are in 0-based coordinates, so off-by-one from NCBI 1-based coordinates
        "poliovirus1sabin": 2479,  # V01150
        "poliovirus2sabin": 2481,  # AY184220
        "poliovirus3sabin": 2478,  # X00925
    }
    lengths = {
        "poliovirus1sabin": 906,
        "poliovirus2sabin": 903,
        "poliovirus3sabin": 900,
    }

    min_mismatches = None
    for file_index, json_filename in enumerate(args.datasets or []):
        # take the name but remove any .json suffix
        dataset_name = args.dataset_names[file_index].removesuffix(".json")
        offset = offsets[dataset_name]
        length = lengths[dataset_name]
        with open(json_filename) as json_file:
            _conflicts, _matches, mismatches, mismatch_list = analyse_mismatches(
                json_file, offset, length
            )
        with open(json_filename) as json_file:
            quality = analyse_trace_quality(json_file)
        # strict "<" keeps the earliest dataset when mismatch counts tie
        if min_mismatches is None or mismatches < min_mismatches:
            min_mismatches = mismatches
            best_match_mismatch_list = mismatch_list
            best_match_quality = quality
            best_match_reference = dataset_name
            # NOTE(review): this is the --consensi argument string itself; if those
            # arguments are file paths, their contents were probably intended.
            best_consensus = args.consensi[file_index]

    if min_mismatches is None:
        parser.error("at least one --datasets file is required")

    percent_mismatches = round(min_mismatches / lengths[best_match_reference] * 100, 2)
    info = {
        "sample_name": args.sample_name,
        "best_reference": best_match_reference,
        "mismatches": min_mismatches,
        "mismatch_list": best_match_mismatch_list,
        "quality": best_match_quality,
        "perc_mismatches": percent_mismatches,
    }
    with open(args.output_filename, "w") as output_file:
        json.dump(info, output_file)
    with open(args.consensus_output_filename, "w") as consensus_file:
        consensus_file.write(best_consensus)


if __name__ == "__main__":
    main()