comparison ml_visualization_ex.py @ 7:c9b521fcc3ac draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9e28f4466084464d38d3f8db2aff07974be4ba69"
author bgruening
date Wed, 11 Mar 2020 16:55:36 +0000
parents e08eceb9b333
children f2c240cce242
comparison
equal deleted inserted replaced
6:415088b9e6e2 7:c9b521fcc3ac
11 11
12 from keras.models import model_from_json 12 from keras.models import model_from_json
13 from keras.utils import plot_model 13 from keras.utils import plot_model
14 from sklearn.feature_selection.base import SelectorMixin 14 from sklearn.feature_selection.base import SelectorMixin
15 from sklearn.metrics import precision_recall_curve, average_precision_score 15 from sklearn.metrics import precision_recall_curve, average_precision_score
16 from sklearn.metrics import roc_curve, auc 16 from sklearn.metrics import roc_curve, auc, confusion_matrix
17 from sklearn.pipeline import Pipeline 17 from sklearn.pipeline import Pipeline
18 from galaxy_ml.utils import load_model, read_columns, SafeEval 18 from galaxy_ml.utils import load_model, read_columns, SafeEval
19 19
20 20
21 safe_eval = SafeEval() 21 safe_eval = SafeEval()
264 plt.savefig(os.path.join(folder, "output.svg"), format="svg") 264 plt.savefig(os.path.join(folder, "output.svg"), format="svg")
265 os.rename(os.path.join(folder, "output.svg"), 265 os.rename(os.path.join(folder, "output.svg"),
266 os.path.join(folder, "output")) 266 os.path.join(folder, "output"))
267 267
268 268
def get_dataframe(file_path, plot_selection, header_name, column_name):
    """Read a tab-separated input file according to a plotting selection.

    Parameters
    ----------
    file_path : str
        Path to the tab-separated dataset to read.
    plot_selection : dict
        Plot-specific parameter mapping (the caller passes
        ``params["plotting_selection"]``).
    header_name : str
        Key in ``plot_selection`` whose truthiness decides whether a header
        row is present (``header='infer'``) or not (``header=None``).
    column_name : str
        Key in ``plot_selection`` holding the column-selector sub-mapping.
        It must contain ``selected_column_selector_option``. For the
        index/header-name based options it must also contain ``col1``.

    Returns
    -------
    input_df
        The second item of the pair returned by
        ``read_columns(..., return_df=True)``. The caller indexes it with
        ``.iloc``, so it is presumably a pandas DataFrame. TODO confirm.
    """
    # Header handling forwarded to read_columns: 'infer' when the flag is
    # truthy, otherwise None.
    header = 'infer' if plot_selection[header_name] else None

    selector = plot_selection[column_name]
    column_option = selector["selected_column_selector_option"]

    # Only these options carry an explicit column spec in "col1".
    # For any other option no spec is passed (c=None).
    explicit_column_options = (
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    )
    if column_option in explicit_column_options:
        col = selector["col1"]
    else:
        col = None

    _, input_df = read_columns(
        file_path,
        c=col,
        c_option=column_option,
        return_df=True,
        sep='\t',
        header=header,
        parse_dates=True,
    )
    return input_df
282
283
269 def main(inputs, infile_estimator=None, infile1=None, 284 def main(inputs, infile_estimator=None, infile1=None,
270 infile2=None, outfile_result=None, 285 infile2=None, outfile_result=None,
271 outfile_object=None, groups=None, 286 outfile_object=None, groups=None,
272 ref_seq=None, intervals=None, 287 ref_seq=None, intervals=None,
273 targets=None, fasta_path=None, 288 targets=None, fasta_path=None,
274 model_config=None): 289 model_config=None, true_labels=None,
290 predicted_labels=None, plot_color=None,
291 title=None):
275 """ 292 """
276 Parameter 293 Parameter
277 --------- 294 ---------
278 inputs : str 295 inputs : str
279 File path to galaxy tool parameter 296 File path to galaxy tool parameter
309 fasta_path : str, default is None 326 fasta_path : str, default is None
310 File path to dataset containing fasta file 327 File path to dataset containing fasta file
311 328
312 model_config : str, default is None 329 model_config : str, default is None
313 File path to dataset containing JSON config for neural networks 330 File path to dataset containing JSON config for neural networks
331
332 true_labels : str, default is None
333 File path to dataset containing true labels
334
335 predicted_labels : str, default is None
336 File path to dataset containing true predicted labels
337
338 plot_color : str, default is None
339 Color of the confusion matrix heatmap
340
341 title : str, default is None
342 Title of the confusion matrix heatmap
314 """ 343 """
315 warnings.simplefilter('ignore') 344 warnings.simplefilter('ignore')
316 345
317 with open(inputs, 'r') as param_handler: 346 with open(inputs, 'r') as param_handler:
318 params = json.load(param_handler) 347 params = json.load(param_handler)
541 plot_model(model, to_file="output.png") 570 plot_model(model, to_file="output.png")
542 os.rename('output.png', 'output') 571 os.rename('output.png', 'output')
543 572
544 return 0 573 return 0
545 574
575 elif plot_type == 'classification_confusion_matrix':
576 plot_selection = params["plotting_selection"]
577 input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true")
578 header_predicted = 'infer' if plot_selection["header_predicted"] else None
579 input_predicted = pd.read_csv(predicted_labels, sep='\t', parse_dates=True, header=header_predicted)
580 true_classes = input_true.iloc[:, -1].copy()
581 predicted_classes = input_predicted.iloc[:, -1].copy()
582 axis_labels = list(set(true_classes))
583 c_matrix = confusion_matrix(true_classes, predicted_classes)
584 fig, ax = plt.subplots(figsize=(7, 7))
585 im = plt.imshow(c_matrix, cmap=plot_color)
586 for i in range(len(c_matrix)):
587 for j in range(len(c_matrix)):
588 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
589 ax.set_ylabel('True class labels')
590 ax.set_xlabel('Predicted class labels')
591 ax.set_title(title)
592 ax.set_xticks(axis_labels)
593 ax.set_yticks(axis_labels)
594 fig.colorbar(im, ax=ax)
595 fig.tight_layout()
596 plt.savefig("output.png", dpi=125)
597 os.rename('output.png', 'output')
598
599 return 0
600
546 # save pdf file to disk 601 # save pdf file to disk
547 # fig.write_image("image.pdf", format='pdf') 602 # fig.write_image("image.pdf", format='pdf')
548 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) 603 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
549 604
550 605
560 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") 615 aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
561 aparser.add_argument("-b", "--intervals", dest="intervals") 616 aparser.add_argument("-b", "--intervals", dest="intervals")
562 aparser.add_argument("-t", "--targets", dest="targets") 617 aparser.add_argument("-t", "--targets", dest="targets")
563 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") 618 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
564 aparser.add_argument("-c", "--model_config", dest="model_config") 619 aparser.add_argument("-c", "--model_config", dest="model_config")
620 aparser.add_argument("-tl", "--true_labels", dest="true_labels")
621 aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels")
622 aparser.add_argument("-pc", "--plot_color", dest="plot_color")
623 aparser.add_argument("-pt", "--title", dest="title")
565 args = aparser.parse_args() 624 args = aparser.parse_args()
566 625
567 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, 626 main(args.inputs, args.infile_estimator, args.infile1, args.infile2,
568 args.outfile_result, outfile_object=args.outfile_object, 627 args.outfile_result, outfile_object=args.outfile_object,
569 groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals, 628 groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals,
570 targets=args.targets, fasta_path=args.fasta_path, 629 targets=args.targets, fasta_path=args.fasta_path,
571 model_config=args.model_config) 630 model_config=args.model_config, true_labels=args.true_labels,
631 predicted_labels=args.predicted_labels,
632 plot_color=args.plot_color,
633 title=args.title)