Mercurial > repos > bgruening > stacking_ensemble_models
comparison model_prediction.py @ 2:38c4f8a98038 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5b2ac730ec6d3b762faa9034eddd19ad1b347476"
author | bgruening |
---|---|
date | Mon, 16 Dec 2019 10:07:37 +0000 |
parents | c1b0c8232816 |
children | 0a1812986bc3 |
comparison
equal
deleted
inserted
replaced
1:c1b0c8232816 | 2:38c4f8a98038 |
---|---|
134 options['blacklist_regions'] = None | 134 options['blacklist_regions'] = None |
135 | 135 |
136 pred_data_generator = klass( | 136 pred_data_generator = klass( |
137 ref_genome_path=ref_seq, vcf_path=vcf_path, **options) | 137 ref_genome_path=ref_seq, vcf_path=vcf_path, **options) |
138 | 138 |
139 pred_data_generator.fit() | 139 pred_data_generator.set_processing_attrs() |
140 | 140 |
141 preds = estimator.model_.predict_generator( | 141 variants = pred_data_generator.variants |
142 pred_data_generator.flow(batch_size=32), | 142 |
143 workers=N_JOBS, | 143 # predict 1600 sample at once then write to file |
144 use_multiprocessing=True) | 144 gen_flow = pred_data_generator.flow(batch_size=1600) |
145 | 145 |
146 if preds.min() < 0. or preds.max() > 1.: | 146 file_writer = open(outfile_predict, 'w') |
147 warnings.warn('Network returning invalid probability values. ' | 147 header_row = '\t'.join(['chrom', 'pos', 'name', 'ref', |
148 'The last layer might not normalize predictions ' | 148 'alt', 'strand']) |
149 'into probabilities ' | 149 file_writer.write(header_row) |
150 '(like softmax or sigmoid would).') | 150 header_done = False |
151 | 151 |
152 if params['method'] == 'predict_proba' and preds.shape[1] == 1: | 152 steps_done = 0 |
153 # first column is probability of class 0 and second is of class 1 | 153 |
154 preds = np.hstack([1 - preds, preds]) | 154 # TODO: multiple threading |
155 | 155 try: |
156 elif params['method'] == 'predict': | 156 while steps_done < len(gen_flow): |
157 if preds.shape[-1] > 1: | 157 index_array = next(gen_flow.index_generator) |
158 # if the last activation is `softmax`, the sum of all | 158 batch_X = gen_flow._get_batches_of_transformed_samples( |
159 # probibilities will 1, the classification is considered as | 159 index_array) |
160 # multi-class problem, otherwise, we take it as multi-label. | 160 |
161 act = getattr(estimator.model_.layers[-1], 'activation', None) | 161 if params['method'] == 'predict': |
162 if act and act.__name__ == 'softmax': | 162 batch_preds = estimator.predict( |
163 classes = preds.argmax(axis=-1) | 163 batch_X, |
164 # The presence of `pred_data_generator` below is to | |
165 # override model carrying data_generator if there | |
166 # is any. | |
167 data_generator=pred_data_generator) | |
164 else: | 168 else: |
165 preds = (preds > 0.5).astype('int32') | 169 batch_preds = estimator.predict_proba( |
166 else: | 170 batch_X, |
167 classes = (preds > 0.5).astype('int32') | 171 # The presence of `pred_data_generator` below is to |
168 | 172 # override model carrying data_generator if there |
169 preds = estimator.classes_[classes] | 173 # is any. |
174 data_generator=pred_data_generator) | |
175 | |
176 if batch_preds.ndim == 1: | |
177 batch_preds = batch_preds[:, np.newaxis] | |
178 | |
179 batch_meta = variants[index_array] | |
180 batch_out = np.column_stack([batch_meta, batch_preds]) | |
181 | |
182 if not header_done: | |
183 heads = np.arange(batch_preds.shape[-1]).astype(str) | |
184 heads_str = '\t'.join(heads) | |
185 file_writer.write("\t%s\n" % heads_str) | |
186 header_done = True | |
187 | |
188 for row in batch_out: | |
189 row_str = '\t'.join(row) | |
190 file_writer.write("%s\n" % row_str) | |
191 | |
192 steps_done += 1 | |
193 | |
194 finally: | |
195 file_writer.close() | |
196 # TODO: make api `pred_data_generator.close()` | |
197 pred_data_generator.close() | |
198 return 0 | |
170 # end input | 199 # end input |
171 | 200 |
172 # output | 201 # output |
173 if input_type == 'variant_effect': # TODO: save in batchs | 202 if len(preds.shape) == 1: |
174 rval = pd.DataFrame(preds) | |
175 meta = pd.DataFrame( | |
176 pred_data_generator.variants, | |
177 columns=['chrom', 'pos', 'name', 'ref', 'alt', 'strand']) | |
178 | |
179 rval = pd.concat([meta, rval], axis=1) | |
180 | |
181 elif len(preds.shape) == 1: | |
182 rval = pd.DataFrame(preds, columns=['Predicted']) | 203 rval = pd.DataFrame(preds, columns=['Predicted']) |
183 else: | 204 else: |
184 rval = pd.DataFrame(preds) | 205 rval = pd.DataFrame(preds) |
185 | 206 |
186 rval.to_csv(outfile_predict, sep='\t', | 207 rval.to_csv(outfile_predict, sep='\t', header=True, index=False) |
187 header=True, index=False) | |
188 | 208 |
189 | 209 |
190 if __name__ == '__main__': | 210 if __name__ == '__main__': |
191 aparser = argparse.ArgumentParser() | 211 aparser = argparse.ArgumentParser() |
192 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 212 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |