Mercurial > repos > bgruening > stacking_ensemble_models
comparison model_prediction.py @ 3:0a1812986bc3 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 11:10:37 +0000 (20 months ago) |
parents | 38c4f8a98038 |
children |
comparison
equal
deleted
inserted
replaced
2:38c4f8a98038 | 3:0a1812986bc3 |
---|---|
1 import argparse | 1 import argparse |
2 import json | 2 import json |
3 import warnings | |
4 | |
3 import numpy as np | 5 import numpy as np |
4 import pandas as pd | 6 import pandas as pd |
5 import warnings | 7 from galaxy_ml.model_persist import load_model_from_h5 |
6 | 8 from galaxy_ml.utils import (clean_params, get_module, read_columns, |
9 try_get_attr) | |
7 from scipy.io import mmread | 10 from scipy.io import mmread |
8 from sklearn.pipeline import Pipeline | 11 |
9 | 12 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) |
10 from galaxy_ml.utils import (load_model, read_columns, | 13 |
11 get_module, try_get_attr) | 14 |
12 | 15 def main( |
13 | 16 inputs, |
14 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) | 17 infile_estimator, |
15 | 18 outfile_predict, |
16 | 19 infile1=None, |
17 def main(inputs, infile_estimator, outfile_predict, | 20 fasta_path=None, |
18 infile_weights=None, infile1=None, | 21 ref_seq=None, |
19 fasta_path=None, ref_seq=None, | 22 vcf_path=None, |
20 vcf_path=None): | 23 ): |
21 """ | 24 """ |
22 Parameter | 25 Parameter |
23 --------- | 26 --------- |
24 inputs : str | 27 inputs : str |
25 File path to galaxy tool parameter | 28 File path to galaxy tool parameter |
26 | 29 |
27 infile_estimator : strgit | 30 infile_estimator : str |
28 File path to trained estimator input | 31 File path to trained estimator input |
29 | 32 |
30 outfile_predict : str | 33 outfile_predict : str |
31 File path to save the prediction results, tabular | 34 File path to save the prediction results, tabular |
32 | |
33 infile_weights : str | |
34 File path to weights input | |
35 | 35 |
36 infile1 : str | 36 infile1 : str |
37 File path to dataset containing features | 37 File path to dataset containing features |
38 | 38 |
39 fasta_path : str | 39 fasta_path : str |
43 File path to dataset containing the reference genome sequence. | 43 File path to dataset containing the reference genome sequence. |
44 | 44 |
45 vcf_path : str | 45 vcf_path : str |
46 File path to dataset containing variants info. | 46 File path to dataset containing variants info. |
47 """ | 47 """ |
48 warnings.filterwarnings('ignore') | 48 warnings.filterwarnings("ignore") |
49 | 49 |
50 with open(inputs, 'r') as param_handler: | 50 with open(inputs, "r") as param_handler: |
51 params = json.load(param_handler) | 51 params = json.load(param_handler) |
52 | 52 |
53 # load model | 53 # load model |
54 with open(infile_estimator, 'rb') as est_handler: | 54 estimator = load_model_from_h5(infile_estimator) |
55 estimator = load_model(est_handler) | 55 estimator = clean_params(estimator) |
56 | |
57 main_est = estimator | |
58 if isinstance(estimator, Pipeline): | |
59 main_est = estimator.steps[-1][-1] | |
60 if hasattr(main_est, 'config') and hasattr(main_est, 'load_weights'): | |
61 if not infile_weights or infile_weights == 'None': | |
62 raise ValueError("The selected model skeleton asks for weights, " | |
63 "but dataset for weights wan not selected!") | |
64 main_est.load_weights(infile_weights) | |
65 | 56 |
66 # handle data input | 57 # handle data input |
67 input_type = params['input_options']['selected_input'] | 58 input_type = params["input_options"]["selected_input"] |
68 # tabular input | 59 # tabular input |
69 if input_type == 'tabular': | 60 if input_type == "tabular": |
70 header = 'infer' if params['input_options']['header1'] else None | 61 header = "infer" if params["input_options"]["header1"] else None |
71 column_option = (params['input_options'] | 62 column_option = params["input_options"]["column_selector_options_1"][ |
72 ['column_selector_options_1'] | 63 "selected_column_selector_option" |
73 ['selected_column_selector_option']) | 64 ] |
74 if column_option in ['by_index_number', 'all_but_by_index_number', | 65 if column_option in [ |
75 'by_header_name', 'all_but_by_header_name']: | 66 "by_index_number", |
76 c = params['input_options']['column_selector_options_1']['col1'] | 67 "all_but_by_index_number", |
68 "by_header_name", | |
69 "all_but_by_header_name", | |
70 ]: | |
71 c = params["input_options"]["column_selector_options_1"]["col1"] | |
77 else: | 72 else: |
78 c = None | 73 c = None |
79 | 74 |
80 df = pd.read_csv(infile1, sep='\t', header=header, parse_dates=True) | 75 df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True) |
81 | 76 |
82 X = read_columns(df, c=c, c_option=column_option).astype(float) | 77 X = read_columns(df, c=c, c_option=column_option).astype(float) |
83 | 78 |
84 if params['method'] == 'predict': | 79 if params["method"] == "predict": |
85 preds = estimator.predict(X) | 80 preds = estimator.predict(X) |
86 else: | 81 else: |
87 preds = estimator.predict_proba(X) | 82 preds = estimator.predict_proba(X) |
88 | 83 |
89 # sparse input | 84 # sparse input |
90 elif input_type == 'sparse': | 85 elif input_type == "sparse": |
91 X = mmread(open(infile1, 'r')) | 86 X = mmread(open(infile1, "r")) |
92 if params['method'] == 'predict': | 87 if params["method"] == "predict": |
93 preds = estimator.predict(X) | 88 preds = estimator.predict(X) |
94 else: | 89 else: |
95 preds = estimator.predict_proba(X) | 90 preds = estimator.predict_proba(X) |
96 | 91 |
97 # fasta input | 92 # fasta input |
98 elif input_type == 'seq_fasta': | 93 elif input_type == "seq_fasta": |
99 if not hasattr(estimator, 'data_batch_generator'): | 94 if not hasattr(estimator, "data_batch_generator"): |
100 raise ValueError( | 95 raise ValueError( |
101 "To do prediction on sequences in fasta input, " | 96 "To do prediction on sequences in fasta input, " |
102 "the estimator must be a `KerasGBatchClassifier`" | 97 "the estimator must be a `KerasGBatchClassifier`" |
103 "equipped with data_batch_generator!") | 98 "equipped with data_batch_generator!" |
104 pyfaidx = get_module('pyfaidx') | 99 ) |
100 pyfaidx = get_module("pyfaidx") | |
105 sequences = pyfaidx.Fasta(fasta_path) | 101 sequences = pyfaidx.Fasta(fasta_path) |
106 n_seqs = len(sequences.keys()) | 102 n_seqs = len(sequences.keys()) |
107 X = np.arange(n_seqs)[:, np.newaxis] | 103 X = np.arange(n_seqs)[:, np.newaxis] |
108 seq_length = estimator.data_batch_generator.seq_length | 104 seq_length = estimator.data_batch_generator.seq_length |
109 batch_size = getattr(estimator, 'batch_size', 32) | 105 batch_size = getattr(estimator, "batch_size", 32) |
110 steps = (n_seqs + batch_size - 1) // batch_size | 106 steps = (n_seqs + batch_size - 1) // batch_size |
111 | 107 |
112 seq_type = params['input_options']['seq_type'] | 108 seq_type = params["input_options"]["seq_type"] |
113 klass = try_get_attr( | 109 klass = try_get_attr("galaxy_ml.preprocessors", seq_type) |
114 'galaxy_ml.preprocessors', seq_type) | 110 |
111 pred_data_generator = klass(fasta_path, seq_length=seq_length) | |
112 | |
113 if params["method"] == "predict": | |
114 preds = estimator.predict( | |
115 X, data_generator=pred_data_generator, steps=steps | |
116 ) | |
117 else: | |
118 preds = estimator.predict_proba( | |
119 X, data_generator=pred_data_generator, steps=steps | |
120 ) | |
121 | |
122 # vcf input | |
123 elif input_type == "variant_effect": | |
124 klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator") | |
125 | |
126 options = params["input_options"] | |
127 options.pop("selected_input") | |
128 if options["blacklist_regions"] == "none": | |
129 options["blacklist_regions"] = None | |
115 | 130 |
116 pred_data_generator = klass( | 131 pred_data_generator = klass( |
117 fasta_path, seq_length=seq_length) | 132 ref_genome_path=ref_seq, vcf_path=vcf_path, **options |
118 | 133 ) |
119 if params['method'] == 'predict': | |
120 preds = estimator.predict( | |
121 X, data_generator=pred_data_generator, steps=steps) | |
122 else: | |
123 preds = estimator.predict_proba( | |
124 X, data_generator=pred_data_generator, steps=steps) | |
125 | |
126 # vcf input | |
127 elif input_type == 'variant_effect': | |
128 klass = try_get_attr('galaxy_ml.preprocessors', | |
129 'GenomicVariantBatchGenerator') | |
130 | |
131 options = params['input_options'] | |
132 options.pop('selected_input') | |
133 if options['blacklist_regions'] == 'none': | |
134 options['blacklist_regions'] = None | |
135 | |
136 pred_data_generator = klass( | |
137 ref_genome_path=ref_seq, vcf_path=vcf_path, **options) | |
138 | 134 |
139 pred_data_generator.set_processing_attrs() | 135 pred_data_generator.set_processing_attrs() |
140 | 136 |
141 variants = pred_data_generator.variants | 137 variants = pred_data_generator.variants |
142 | 138 |
143 # predict 1600 sample at once then write to file | 139 # predict 1600 sample at once then write to file |
144 gen_flow = pred_data_generator.flow(batch_size=1600) | 140 gen_flow = pred_data_generator.flow(batch_size=1600) |
145 | 141 |
146 file_writer = open(outfile_predict, 'w') | 142 file_writer = open(outfile_predict, "w") |
147 header_row = '\t'.join(['chrom', 'pos', 'name', 'ref', | 143 header_row = "\t".join(["chrom", "pos", "name", "ref", "alt", "strand"]) |
148 'alt', 'strand']) | |
149 file_writer.write(header_row) | 144 file_writer.write(header_row) |
150 header_done = False | 145 header_done = False |
151 | 146 |
152 steps_done = 0 | 147 steps_done = 0 |
153 | 148 |
154 # TODO: multiple threading | 149 # TODO: multiple threading |
155 try: | 150 try: |
156 while steps_done < len(gen_flow): | 151 while steps_done < len(gen_flow): |
157 index_array = next(gen_flow.index_generator) | 152 index_array = next(gen_flow.index_generator) |
158 batch_X = gen_flow._get_batches_of_transformed_samples( | 153 batch_X = gen_flow._get_batches_of_transformed_samples(index_array) |
159 index_array) | 154 |
160 | 155 if params["method"] == "predict": |
161 if params['method'] == 'predict': | |
162 batch_preds = estimator.predict( | 156 batch_preds = estimator.predict( |
163 batch_X, | 157 batch_X, |
164 # The presence of `pred_data_generator` below is to | 158 # The presence of `pred_data_generator` below is to |
165 # override model carrying data_generator if there | 159 # override model carrying data_generator if there |
166 # is any. | 160 # is any. |
167 data_generator=pred_data_generator) | 161 data_generator=pred_data_generator, |
162 ) | |
168 else: | 163 else: |
169 batch_preds = estimator.predict_proba( | 164 batch_preds = estimator.predict_proba( |
170 batch_X, | 165 batch_X, |
171 # The presence of `pred_data_generator` below is to | 166 # The presence of `pred_data_generator` below is to |
172 # override model carrying data_generator if there | 167 # override model carrying data_generator if there |
173 # is any. | 168 # is any. |
174 data_generator=pred_data_generator) | 169 data_generator=pred_data_generator, |
170 ) | |
175 | 171 |
176 if batch_preds.ndim == 1: | 172 if batch_preds.ndim == 1: |
177 batch_preds = batch_preds[:, np.newaxis] | 173 batch_preds = batch_preds[:, np.newaxis] |
178 | 174 |
179 batch_meta = variants[index_array] | 175 batch_meta = variants[index_array] |
180 batch_out = np.column_stack([batch_meta, batch_preds]) | 176 batch_out = np.column_stack([batch_meta, batch_preds]) |
181 | 177 |
182 if not header_done: | 178 if not header_done: |
183 heads = np.arange(batch_preds.shape[-1]).astype(str) | 179 heads = np.arange(batch_preds.shape[-1]).astype(str) |
184 heads_str = '\t'.join(heads) | 180 heads_str = "\t".join(heads) |
185 file_writer.write("\t%s\n" % heads_str) | 181 file_writer.write("\t%s\n" % heads_str) |
186 header_done = True | 182 header_done = True |
187 | 183 |
188 for row in batch_out: | 184 for row in batch_out: |
189 row_str = '\t'.join(row) | 185 row_str = "\t".join(row) |
190 file_writer.write("%s\n" % row_str) | 186 file_writer.write("%s\n" % row_str) |
191 | 187 |
192 steps_done += 1 | 188 steps_done += 1 |
193 | 189 |
194 finally: | 190 finally: |
198 return 0 | 194 return 0 |
199 # end input | 195 # end input |
200 | 196 |
201 # output | 197 # output |
202 if len(preds.shape) == 1: | 198 if len(preds.shape) == 1: |
203 rval = pd.DataFrame(preds, columns=['Predicted']) | 199 rval = pd.DataFrame(preds, columns=["Predicted"]) |
204 else: | 200 else: |
205 rval = pd.DataFrame(preds) | 201 rval = pd.DataFrame(preds) |
206 | 202 |
207 rval.to_csv(outfile_predict, sep='\t', header=True, index=False) | 203 rval.to_csv(outfile_predict, sep="\t", header=True, index=False) |
208 | 204 |
209 | 205 |
210 if __name__ == '__main__': | 206 if __name__ == "__main__": |
211 aparser = argparse.ArgumentParser() | 207 aparser = argparse.ArgumentParser() |
212 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 208 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
213 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") | 209 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") |
214 aparser.add_argument("-w", "--infile_weights", dest="infile_weights") | |
215 aparser.add_argument("-X", "--infile1", dest="infile1") | 210 aparser.add_argument("-X", "--infile1", dest="infile1") |
216 aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict") | 211 aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict") |
217 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") | 212 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") |
218 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 213 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") |
219 aparser.add_argument("-v", "--vcf_path", dest="vcf_path") | 214 aparser.add_argument("-v", "--vcf_path", dest="vcf_path") |
220 args = aparser.parse_args() | 215 args = aparser.parse_args() |
221 | 216 |
222 main(args.inputs, args.infile_estimator, args.outfile_predict, | 217 main( |
223 infile_weights=args.infile_weights, infile1=args.infile1, | 218 args.inputs, |
224 fasta_path=args.fasta_path, ref_seq=args.ref_seq, | 219 args.infile_estimator, |
225 vcf_path=args.vcf_path) | 220 args.outfile_predict, |
221 infile1=args.infile1, | |
222 fasta_path=args.fasta_path, | |
223 ref_seq=args.ref_seq, | |
224 vcf_path=args.vcf_path, | |
225 ) |