plotly_ml_performance_plots: comparison of plot_ml_performance.py @ 3:e73eb091612b (draft, default, tip)
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
author:   bgruening
date:     Wed, 07 Aug 2024 10:20:05 +0000
parents:  2cfa4aabda3e
children: (none)
--- plot_ml_performance.py (2:2cfa4aabda3e)
+++ plot_ml_performance.py (3:e73eb091612b)
@@ -1,94 +1,92 @@
 import argparse
+
+import matplotlib.pyplot as plt
 import pandas as pd
 import plotly
-import pickle
 import plotly.graph_objs as go
-from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_curve, auc
+from galaxy_ml.model_persist import load_model_from_h5
+from galaxy_ml.utils import clean_params
+from sklearn.metrics import (
+    auc,
+    confusion_matrix,
+    precision_recall_fscore_support,
+    roc_curve,
+)
 from sklearn.preprocessing import label_binarize
 
 
 def main(infile_input, infile_output, infile_trained_model):
     """
     Produce an interactive confusion matrix (heatmap), precision, recall, fscore and auc plots
     Args:
         infile_input: str, input tabular file with true labels
         infile_output: str, input tabular file with predicted labels
-        infile_trained_model: str, input trained model file (zip)
+        infile_trained_model: str, input trained model file (h5mlm)
     """
 
-    df_input = pd.read_csv(infile_input, sep='\t', parse_dates=True)
-    df_output = pd.read_csv(infile_output, sep='\t', parse_dates=True)
+    df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True)
+    df_output = pd.read_csv(infile_output, sep="\t", parse_dates=True)
     true_labels = df_input.iloc[:, -1].copy()
     predicted_labels = df_output.iloc[:, -1].copy()
     axis_labels = list(set(true_labels))
     c_matrix = confusion_matrix(true_labels, predicted_labels)
-    data = [
-        go.Heatmap(
-            z=c_matrix,
-            x=axis_labels,
-            y=axis_labels,
-            colorscale='Portland',
-        )
-    ]
+    fig, ax = plt.subplots(figsize=(7, 7))
+    im = plt.imshow(c_matrix, cmap="viridis")
+    # add number of samples to each cell of confusion matrix plot
+    for i in range(len(c_matrix)):
+        for j in range(len(c_matrix)):
+            ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
+    ax.set_ylabel("True class labels")
+    ax.set_xlabel("Predicted class labels")
+    ax.set_title("Confusion Matrix between true and predicted class labels")
+    ax.set_xticks(axis_labels)
+    ax.set_yticks(axis_labels)
+    fig.colorbar(im, ax=ax)
+    fig.tight_layout()
+    plt.savefig("output_confusion.png", dpi=120)
 
-    layout = go.Layout(
-        title='Confusion Matrix between true and predicted class labels',
-        xaxis=dict(title='Predicted class labels'),
-        yaxis=dict(title='True class labels')
-    )
-
-    fig = go.Figure(data=data, layout=layout)
-    plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)
-
-    # plot precision, recall and f_score for each class label
-    precision, recall, f_score, _ = precision_recall_fscore_support(true_labels, predicted_labels)
+    # plot precision, recall and f_score for each class label
+    precision, recall, f_score, _ = precision_recall_fscore_support(
+        true_labels, predicted_labels
+    )
 
     trace_precision = go.Scatter(
-        x=axis_labels,
-        y=precision,
-        mode='lines+markers',
-        name='Precision'
+        x=axis_labels, y=precision, mode="lines+markers", name="Precision"
     )
 
     trace_recall = go.Scatter(
-        x=axis_labels,
-        y=recall,
-        mode='lines+markers',
-        name='Recall'
+        x=axis_labels, y=recall, mode="lines+markers", name="Recall"
     )
 
     trace_fscore = go.Scatter(
-        x=axis_labels,
-        y=f_score,
-        mode='lines+markers',
-        name='F-score'
+        x=axis_labels, y=f_score, mode="lines+markers", name="F-score"
     )
 
     layout_prf = go.Layout(
-        title='Precision, recall and f-score of true and predicted class labels',
-        xaxis=dict(title='Class labels'),
-        yaxis=dict(title='Precision, recall and f-score')
+        title="Precision, recall and f-score of true and predicted class labels",
+        xaxis=dict(title="Class labels"),
+        yaxis=dict(title="Precision, recall and f-score"),
    )
 
     data_prf = [trace_precision, trace_recall, trace_fscore]
     fig_prf = go.Figure(data=data_prf, layout=layout_prf)
     plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False)
 
     # plot roc and auc curves for different classes
-    with open(infile_trained_model, 'rb') as model_file:
-        model = pickle.load(model_file)
+    classifier_object = load_model_from_h5(infile_trained_model)
+    model = clean_params(classifier_object)
 
     # remove the last column (label column)
     test_data = df_input.iloc[:, :-1]
     model_items = dir(model)
 
     try:
         # find the probability estimating method
-        if 'predict_proba' in model_items:
+        if "predict_proba" in model_items:
             y_score = model.predict_proba(test_data)
-        elif 'decision_function' in model_items:
+        elif "decision_function" in model_items:
             y_score = model.decision_function(test_data)
 
         true_labels_list = true_labels.tolist()
         one_hot_labels = label_binarize(true_labels_list, classes=axis_labels)
         data_roc = list()
@@ -102,47 +100,48 @@
                 roc_auc[i] = auc(fpr[i], tpr[i])
             for i in range(len(axis_labels)):
                 trace = go.Scatter(
                     x=fpr[i],
                     y=tpr[i],
-                    mode='lines+markers',
-                    name='ROC curve of class {0} (AUC = {1:0.2f})'.format(i, roc_auc[i])
+                    mode="lines+markers",
+                    name="ROC curve of class {0} (AUC = {1:0.2f})".format(
+                        i, roc_auc[i]
+                    ),
                 )
                 data_roc.append(trace)
         else:
             try:
                 y_score_binary = y_score[:, 1]
-            except:
+            except Exception:
                 y_score_binary = y_score
             fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1)
             roc_auc = auc(fpr, tpr)
             trace = go.Scatter(
                 x=fpr,
                 y=tpr,
-                mode='lines+markers',
-                name='ROC curve (AUC = {0:0.2f})'.format(roc_auc)
+                mode="lines+markers",
+                name="ROC curve (AUC = {0:0.2f})".format(roc_auc),
             )
             data_roc.append(trace)
 
-        trace_diag = go.Scatter(
-            x=[0, 1],
-            y=[0, 1],
-            mode='lines',
-            name='Chance'
-        )
+        trace_diag = go.Scatter(x=[0, 1], y=[0, 1], mode="lines", name="Chance")
         data_roc.append(trace_diag)
         layout_roc = go.Layout(
-            title='Receiver operating characteristics (ROC) and area under curve (AUC)',
-            xaxis=dict(title='False positive rate'),
-            yaxis=dict(title='True positive rate')
+            title="Receiver operating characteristics (ROC) and area under curve (AUC)",
+            xaxis=dict(title="False positive rate"),
+            yaxis=dict(title="True positive rate"),
        )
 
         fig_roc = go.Figure(data=data_roc, layout=layout_roc)
         plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False)
 
     except Exception as exp:
-        print("Plotting the ROC-AUC graph failed. This exception was raised: {}".format(exp))
+        print(
+            "Plotting the ROC-AUC graph failed. This exception was raised: {}".format(
+                exp
+            )
+        )
 
 
 if __name__ == "__main__":
     aparser = argparse.ArgumentParser()
     aparser.add_argument("-i", "--input", dest="infile_input", required=True)
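
For orientation, a minimal sketch of how the rewritten entry point could be exercised directly from Python. Only main()'s signature, the expected file formats, and the output file names are taken from the diff above; the module import path and the input file names are hypothetical placeholders.

# Hypothetical usage sketch (not part of the changeset): call main() directly.
# "train.tabular", "predictions.tabular" and "model.h5mlm" are placeholder paths.
from plot_ml_performance import main

main(
    "train.tabular",        # tabular file whose last column holds the true labels
    "predictions.tabular",  # tabular file whose last column holds the predicted labels
    "model.h5mlm",          # trained model saved in Galaxy-ML h5mlm format
)
# Writes output_confusion.png, output_prf.html and output_roc.html
# to the current working directory.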