comparison model_prediction.py @ 2:38c4f8a98038 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5b2ac730ec6d3b762faa9034eddd19ad1b347476"
author bgruening
date Mon, 16 Dec 2019 10:07:37 +0000
parents c1b0c8232816
children 0a1812986bc3
comparison
equal deleted inserted replaced
1:c1b0c8232816 2:38c4f8a98038
134 options['blacklist_regions'] = None 134 options['blacklist_regions'] = None
135 135
136 pred_data_generator = klass( 136 pred_data_generator = klass(
137 ref_genome_path=ref_seq, vcf_path=vcf_path, **options) 137 ref_genome_path=ref_seq, vcf_path=vcf_path, **options)
138 138
139 pred_data_generator.fit() 139 pred_data_generator.set_processing_attrs()
140 140
141 preds = estimator.model_.predict_generator( 141 variants = pred_data_generator.variants
142 pred_data_generator.flow(batch_size=32), 142
143 workers=N_JOBS, 143 # predict 1600 samples at once, then write to file
144 use_multiprocessing=True) 144 gen_flow = pred_data_generator.flow(batch_size=1600)
145 145
146 if preds.min() < 0. or preds.max() > 1.: 146 file_writer = open(outfile_predict, 'w')
147 warnings.warn('Network returning invalid probability values. ' 147 header_row = '\t'.join(['chrom', 'pos', 'name', 'ref',
148 'The last layer might not normalize predictions ' 148 'alt', 'strand'])
149 'into probabilities ' 149 file_writer.write(header_row)
150 '(like softmax or sigmoid would).') 150 header_done = False
151 151
152 if params['method'] == 'predict_proba' and preds.shape[1] == 1: 152 steps_done = 0
153 # first column is probability of class 0 and second is of class 1 153
154 preds = np.hstack([1 - preds, preds]) 154 # TODO: multiple threading
155 155 try:
156 elif params['method'] == 'predict': 156 while steps_done < len(gen_flow):
157 if preds.shape[-1] > 1: 157 index_array = next(gen_flow.index_generator)
158 # if the last activation is `softmax`, the sum of all 158 batch_X = gen_flow._get_batches_of_transformed_samples(
159 # probabilities will be 1, the classification is considered as 159 index_array)
160 # multi-class problem, otherwise, we take it as multi-label. 160
161 act = getattr(estimator.model_.layers[-1], 'activation', None) 161 if params['method'] == 'predict':
162 if act and act.__name__ == 'softmax': 162 batch_preds = estimator.predict(
163 classes = preds.argmax(axis=-1) 163 batch_X,
164 # The presence of `pred_data_generator` below is to
165 # override model carrying data_generator if there
166 # is any.
167 data_generator=pred_data_generator)
164 else: 168 else:
165 preds = (preds > 0.5).astype('int32') 169 batch_preds = estimator.predict_proba(
166 else: 170 batch_X,
167 classes = (preds > 0.5).astype('int32') 171 # The presence of `pred_data_generator` below is to
168 172 # override model carrying data_generator if there
169 preds = estimator.classes_[classes] 173 # is any.
174 data_generator=pred_data_generator)
175
176 if batch_preds.ndim == 1:
177 batch_preds = batch_preds[:, np.newaxis]
178
179 batch_meta = variants[index_array]
180 batch_out = np.column_stack([batch_meta, batch_preds])
181
182 if not header_done:
183 heads = np.arange(batch_preds.shape[-1]).astype(str)
184 heads_str = '\t'.join(heads)
185 file_writer.write("\t%s\n" % heads_str)
186 header_done = True
187
188 for row in batch_out:
189 row_str = '\t'.join(row)
190 file_writer.write("%s\n" % row_str)
191
192 steps_done += 1
193
194 finally:
195 file_writer.close()
196 # TODO: make api `pred_data_generator.close()`
197 pred_data_generator.close()
198 return 0
170 # end input 199 # end input
171 200
172 # output 201 # output
173 if input_type == 'variant_effect': # TODO: save in batches 202 if len(preds.shape) == 1:
174 rval = pd.DataFrame(preds)
175 meta = pd.DataFrame(
176 pred_data_generator.variants,
177 columns=['chrom', 'pos', 'name', 'ref', 'alt', 'strand'])
178
179 rval = pd.concat([meta, rval], axis=1)
180
181 elif len(preds.shape) == 1:
182 rval = pd.DataFrame(preds, columns=['Predicted']) 203 rval = pd.DataFrame(preds, columns=['Predicted'])
183 else: 204 else:
184 rval = pd.DataFrame(preds) 205 rval = pd.DataFrame(preds)
185 206
186 rval.to_csv(outfile_predict, sep='\t', 207 rval.to_csv(outfile_predict, sep='\t', header=True, index=False)
187 header=True, index=False)
188 208
189 209
190 if __name__ == '__main__': 210 if __name__ == '__main__':
191 aparser = argparse.ArgumentParser() 211 aparser = argparse.ArgumentParser()
192 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 212 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)