comparison model_prediction.py @ 3:0a1812986bc3 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 11:10:37 +0000
parents 38c4f8a98038
--- model_prediction.py (revision 2:38c4f8a98038)
+++ model_prediction.py (revision 3:0a1812986bc3)
@@ -1,39 +1,39 @@
 import argparse
 import json
-import numpy as np
-import pandas as pd
-import warnings
-
-from scipy.io import mmread
-from sklearn.pipeline import Pipeline
-
-from galaxy_ml.utils import (load_model, read_columns,
-                             get_module, try_get_attr)
-
-
-N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1))
-
-
-def main(inputs, infile_estimator, outfile_predict,
-         infile_weights=None, infile1=None,
-         fasta_path=None, ref_seq=None,
-         vcf_path=None):
+import warnings
+
+import numpy as np
+import pandas as pd
+from galaxy_ml.model_persist import load_model_from_h5
+from galaxy_ml.utils import (clean_params, get_module, read_columns,
+                             try_get_attr)
+from scipy.io import mmread
+
+N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1))
+
+
+def main(
+    inputs,
+    infile_estimator,
+    outfile_predict,
+    infile1=None,
+    fasta_path=None,
+    ref_seq=None,
+    vcf_path=None,
+):
     """
     Parameter
     ---------
     inputs : str
         File path to galaxy tool parameter
 
-    infile_estimator : strgit
+    infile_estimator : str
         File path to trained estimator input
 
     outfile_predict : str
         File path to save the prediction results, tabular
-
-    infile_weights : str
-        File path to weights input
 
     infile1 : str
         File path to dataset containing features
 
     fasta_path : str
@@ -43,152 +43,148 @@
         File path to dataset containing the reference genome sequence.
 
     vcf_path : str
        File path to dataset containing variants info.
     """
-    warnings.filterwarnings('ignore')
+    warnings.filterwarnings("ignore")
 
-    with open(inputs, 'r') as param_handler:
+    with open(inputs, "r") as param_handler:
         params = json.load(param_handler)
 
     # load model
-    with open(infile_estimator, 'rb') as est_handler:
-        estimator = load_model(est_handler)
-
-    main_est = estimator
-    if isinstance(estimator, Pipeline):
-        main_est = estimator.steps[-1][-1]
-    if hasattr(main_est, 'config') and hasattr(main_est, 'load_weights'):
-        if not infile_weights or infile_weights == 'None':
-            raise ValueError("The selected model skeleton asks for weights, "
-                             "but dataset for weights wan not selected!")
-        main_est.load_weights(infile_weights)
+    estimator = load_model_from_h5(infile_estimator)
+    estimator = clean_params(estimator)
 
     # handle data input
-    input_type = params['input_options']['selected_input']
+    input_type = params["input_options"]["selected_input"]
     # tabular input
-    if input_type == 'tabular':
-        header = 'infer' if params['input_options']['header1'] else None
-        column_option = (params['input_options']
-                         ['column_selector_options_1']
-                         ['selected_column_selector_option'])
-        if column_option in ['by_index_number', 'all_but_by_index_number',
-                             'by_header_name', 'all_but_by_header_name']:
-            c = params['input_options']['column_selector_options_1']['col1']
+    if input_type == "tabular":
+        header = "infer" if params["input_options"]["header1"] else None
+        column_option = params["input_options"]["column_selector_options_1"][
+            "selected_column_selector_option"
+        ]
+        if column_option in [
+            "by_index_number",
+            "all_but_by_index_number",
+            "by_header_name",
+            "all_but_by_header_name",
+        ]:
+            c = params["input_options"]["column_selector_options_1"]["col1"]
         else:
             c = None
 
-        df = pd.read_csv(infile1, sep='\t', header=header, parse_dates=True)
+        df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
 
         X = read_columns(df, c=c, c_option=column_option).astype(float)
 
-        if params['method'] == 'predict':
+        if params["method"] == "predict":
             preds = estimator.predict(X)
         else:
             preds = estimator.predict_proba(X)
 
     # sparse input
-    elif input_type == 'sparse':
-        X = mmread(open(infile1, 'r'))
-        if params['method'] == 'predict':
+    elif input_type == "sparse":
+        X = mmread(open(infile1, "r"))
+        if params["method"] == "predict":
             preds = estimator.predict(X)
         else:
             preds = estimator.predict_proba(X)
 
     # fasta input
-    elif input_type == 'seq_fasta':
-        if not hasattr(estimator, 'data_batch_generator'):
+    elif input_type == "seq_fasta":
+        if not hasattr(estimator, "data_batch_generator"):
             raise ValueError(
                 "To do prediction on sequences in fasta input, "
                 "the estimator must be a `KerasGBatchClassifier`"
-                "equipped with data_batch_generator!")
-        pyfaidx = get_module('pyfaidx')
+                "equipped with data_batch_generator!"
+            )
+        pyfaidx = get_module("pyfaidx")
         sequences = pyfaidx.Fasta(fasta_path)
         n_seqs = len(sequences.keys())
         X = np.arange(n_seqs)[:, np.newaxis]
         seq_length = estimator.data_batch_generator.seq_length
-        batch_size = getattr(estimator, 'batch_size', 32)
+        batch_size = getattr(estimator, "batch_size", 32)
         steps = (n_seqs + batch_size - 1) // batch_size
 
-        seq_type = params['input_options']['seq_type']
-        klass = try_get_attr(
-            'galaxy_ml.preprocessors', seq_type)
+        seq_type = params["input_options"]["seq_type"]
+        klass = try_get_attr("galaxy_ml.preprocessors", seq_type)
 
-        pred_data_generator = klass(
-            fasta_path, seq_length=seq_length)
+        pred_data_generator = klass(fasta_path, seq_length=seq_length)
 
-        if params['method'] == 'predict':
-            preds = estimator.predict(
-                X, data_generator=pred_data_generator, steps=steps)
+        if params["method"] == "predict":
+            preds = estimator.predict(
+                X, data_generator=pred_data_generator, steps=steps
+            )
         else:
-            preds = estimator.predict_proba(
-                X, data_generator=pred_data_generator, steps=steps)
+            preds = estimator.predict_proba(
+                X, data_generator=pred_data_generator, steps=steps
+            )
 
     # vcf input
-    elif input_type == 'variant_effect':
-        klass = try_get_attr('galaxy_ml.preprocessors',
-                             'GenomicVariantBatchGenerator')
+    elif input_type == "variant_effect":
+        klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator")
 
-        options = params['input_options']
-        options.pop('selected_input')
-        if options['blacklist_regions'] == 'none':
-            options['blacklist_regions'] = None
+        options = params["input_options"]
+        options.pop("selected_input")
+        if options["blacklist_regions"] == "none":
+            options["blacklist_regions"] = None
 
         pred_data_generator = klass(
-            ref_genome_path=ref_seq, vcf_path=vcf_path, **options)
+            ref_genome_path=ref_seq, vcf_path=vcf_path, **options
+        )
 
         pred_data_generator.set_processing_attrs()
 
         variants = pred_data_generator.variants
 
         # predict 1600 sample at once then write to file
         gen_flow = pred_data_generator.flow(batch_size=1600)
 
-        file_writer = open(outfile_predict, 'w')
-        header_row = '\t'.join(['chrom', 'pos', 'name', 'ref',
-                                'alt', 'strand'])
+        file_writer = open(outfile_predict, "w")
+        header_row = "\t".join(["chrom", "pos", "name", "ref", "alt", "strand"])
         file_writer.write(header_row)
         header_done = False
 
         steps_done = 0
 
         # TODO: multiple threading
         try:
             while steps_done < len(gen_flow):
                 index_array = next(gen_flow.index_generator)
-                batch_X = gen_flow._get_batches_of_transformed_samples(
-                    index_array)
+                batch_X = gen_flow._get_batches_of_transformed_samples(index_array)
 
-                if params['method'] == 'predict':
+                if params["method"] == "predict":
                     batch_preds = estimator.predict(
                         batch_X,
                         # The presence of `pred_data_generator` below is to
                         # override model carrying data_generator if there
                         # is any.
-                        data_generator=pred_data_generator)
+                        data_generator=pred_data_generator,
+                    )
                 else:
                     batch_preds = estimator.predict_proba(
                         batch_X,
                         # The presence of `pred_data_generator` below is to
                         # override model carrying data_generator if there
                         # is any.
-                        data_generator=pred_data_generator)
+                        data_generator=pred_data_generator,
+                    )
 
                 if batch_preds.ndim == 1:
                     batch_preds = batch_preds[:, np.newaxis]
 
                 batch_meta = variants[index_array]
                 batch_out = np.column_stack([batch_meta, batch_preds])
 
                 if not header_done:
                     heads = np.arange(batch_preds.shape[-1]).astype(str)
-                    heads_str = '\t'.join(heads)
+                    heads_str = "\t".join(heads)
                     file_writer.write("\t%s\n" % heads_str)
                     header_done = True
 
                 for row in batch_out:
-                    row_str = '\t'.join(row)
+                    row_str = "\t".join(row)
                     file_writer.write("%s\n" % row_str)
 
                 steps_done += 1
 
         finally:
@@ -198,28 +194,32 @@
         return 0
     # end input
 
     # output
     if len(preds.shape) == 1:
-        rval = pd.DataFrame(preds, columns=['Predicted'])
+        rval = pd.DataFrame(preds, columns=["Predicted"])
     else:
         rval = pd.DataFrame(preds)
 
-    rval.to_csv(outfile_predict, sep='\t', header=True, index=False)
+    rval.to_csv(outfile_predict, sep="\t", header=True, index=False)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     aparser = argparse.ArgumentParser()
     aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
     aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator")
-    aparser.add_argument("-w", "--infile_weights", dest="infile_weights")
     aparser.add_argument("-X", "--infile1", dest="infile1")
     aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict")
     aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
     aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
     aparser.add_argument("-v", "--vcf_path", dest="vcf_path")
     args = aparser.parse_args()
 
-    main(args.inputs, args.infile_estimator, args.outfile_predict,
-         infile_weights=args.infile_weights, infile1=args.infile1,
-         fasta_path=args.fasta_path, ref_seq=args.ref_seq,
-         vcf_path=args.vcf_path)
+    main(
+        args.inputs,
+        args.infile_estimator,
+        args.outfile_predict,
+        infile1=args.infile1,
+        fasta_path=args.fasta_path,
+        ref_seq=args.ref_seq,
+        vcf_path=args.vcf_path,
+    )