comparison plot_ml_performance.py @ 3:e73eb091612b draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
author bgruening
date Wed, 07 Aug 2024 10:20:05 +0000
parents 2cfa4aabda3e
children
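
The substance of this changeset: the trained model is now read with galaxy-ml's HDF5 persistence helpers instead of pickle, and the confusion matrix is rendered as a matplotlib PNG instead of a Plotly heatmap. Below is a minimal sketch of the new loading path, using only the two galaxy-ml calls that appear in the diff; "model.h5mlm" is a hypothetical placeholder path, not part of the changeset:

    from galaxy_ml.model_persist import load_model_from_h5
    from galaxy_ml.utils import clean_params

    # "model.h5mlm" is a placeholder; the Galaxy tool passes the
    # trained-model dataset path here (see main() in the diff below)
    classifier_object = load_model_from_h5("model.h5mlm")
    # clean_params sanitizes the loaded estimator's parameters before use
    model = clean_params(classifier_object)
    print(model)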
--- plot_ml_performance.py (2:2cfa4aabda3e)
+++ plot_ml_performance.py (3:e73eb091612b)
 import argparse
+
+import matplotlib.pyplot as plt
 import pandas as pd
 import plotly
-import pickle
 import plotly.graph_objs as go
-from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_curve, auc
+from galaxy_ml.model_persist import load_model_from_h5
+from galaxy_ml.utils import clean_params
+from sklearn.metrics import (
+    auc,
+    confusion_matrix,
+    precision_recall_fscore_support,
+    roc_curve,
+)
 from sklearn.preprocessing import label_binarize


 def main(infile_input, infile_output, infile_trained_model):
     """
     Produce an interactive confusion matrix (heatmap), precision, recall, fscore and auc plots
     Args:
         infile_input: str, input tabular file with true labels
         infile_output: str, input tabular file with predicted labels
-        infile_trained_model: str, input trained model file (zip)
+        infile_trained_model: str, input trained model file (h5mlm)
     """

-    df_input = pd.read_csv(infile_input, sep='\t', parse_dates=True)
-    df_output = pd.read_csv(infile_output, sep='\t', parse_dates=True)
+    df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True)
+    df_output = pd.read_csv(infile_output, sep="\t", parse_dates=True)
     true_labels = df_input.iloc[:, -1].copy()
     predicted_labels = df_output.iloc[:, -1].copy()
     axis_labels = list(set(true_labels))
     c_matrix = confusion_matrix(true_labels, predicted_labels)
-    data = [
-        go.Heatmap(
-            z=c_matrix,
-            x=axis_labels,
-            y=axis_labels,
-            colorscale='Portland',
-        )
-    ]
-
-    layout = go.Layout(
-        title='Confusion Matrix between true and predicted class labels',
-        xaxis=dict(title='Predicted class labels'),
-        yaxis=dict(title='True class labels')
-    )
-
-    fig = go.Figure(data=data, layout=layout)
-    plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)
+    fig, ax = plt.subplots(figsize=(7, 7))
+    im = plt.imshow(c_matrix, cmap="viridis")
+    # add number of samples to each cell of confusion matrix plot
+    for i in range(len(c_matrix)):
+        for j in range(len(c_matrix)):
+            ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
+    ax.set_ylabel("True class labels")
+    ax.set_xlabel("Predicted class labels")
+    ax.set_title("Confusion Matrix between true and predicted class labels")
+    ax.set_xticks(axis_labels)
+    ax.set_yticks(axis_labels)
+    fig.colorbar(im, ax=ax)
+    fig.tight_layout()
+    plt.savefig("output_confusion.png", dpi=120)

     # plot precision, recall and f_score for each class label
-    precision, recall, f_score, _ = precision_recall_fscore_support(true_labels, predicted_labels)
+    precision, recall, f_score, _ = precision_recall_fscore_support(
+        true_labels, predicted_labels
+    )

     trace_precision = go.Scatter(
-        x=axis_labels,
-        y=precision,
-        mode='lines+markers',
-        name='Precision'
+        x=axis_labels, y=precision, mode="lines+markers", name="Precision"
     )

     trace_recall = go.Scatter(
-        x=axis_labels,
-        y=recall,
-        mode='lines+markers',
-        name='Recall'
+        x=axis_labels, y=recall, mode="lines+markers", name="Recall"
     )

     trace_fscore = go.Scatter(
-        x=axis_labels,
-        y=f_score,
-        mode='lines+markers',
-        name='F-score'
+        x=axis_labels, y=f_score, mode="lines+markers", name="F-score"
     )

     layout_prf = go.Layout(
-        title='Precision, recall and f-score of true and predicted class labels',
-        xaxis=dict(title='Class labels'),
-        yaxis=dict(title='Precision, recall and f-score')
+        title="Precision, recall and f-score of true and predicted class labels",
+        xaxis=dict(title="Class labels"),
+        yaxis=dict(title="Precision, recall and f-score"),
     )

     data_prf = [trace_precision, trace_recall, trace_fscore]
     fig_prf = go.Figure(data=data_prf, layout=layout_prf)
     plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False)

     # plot roc and auc curves for different classes
-    with open(infile_trained_model, 'rb') as model_file:
-        model = pickle.load(model_file)
+    classifier_object = load_model_from_h5(infile_trained_model)
+    model = clean_params(classifier_object)

     # remove the last column (label column)
     test_data = df_input.iloc[:, :-1]
     model_items = dir(model)

     try:
         # find the probability estimating method
-        if 'predict_proba' in model_items:
+        if "predict_proba" in model_items:
             y_score = model.predict_proba(test_data)
-        elif 'decision_function' in model_items:
+        elif "decision_function" in model_items:
             y_score = model.decision_function(test_data)

         true_labels_list = true_labels.tolist()
         one_hot_labels = label_binarize(true_labels_list, classes=axis_labels)
         data_roc = list()

[... unchanged lines elided by the comparison view ...]
                 roc_auc[i] = auc(fpr[i], tpr[i])
             for i in range(len(axis_labels)):
                 trace = go.Scatter(
                     x=fpr[i],
                     y=tpr[i],
-                    mode='lines+markers',
-                    name='ROC curve of class {0} (AUC = {1:0.2f})'.format(i, roc_auc[i])
+                    mode="lines+markers",
+                    name="ROC curve of class {0} (AUC = {1:0.2f})".format(
+                        i, roc_auc[i]
+                    ),
                 )
                 data_roc.append(trace)
         else:
             try:
                 y_score_binary = y_score[:, 1]
-            except:
+            except Exception:
                 y_score_binary = y_score
             fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1)
             roc_auc = auc(fpr, tpr)
             trace = go.Scatter(
                 x=fpr,
                 y=tpr,
-                mode='lines+markers',
-                name='ROC curve (AUC = {0:0.2f})'.format(roc_auc)
+                mode="lines+markers",
+                name="ROC curve (AUC = {0:0.2f})".format(roc_auc),
             )
             data_roc.append(trace)

-        trace_diag = go.Scatter(
-            x=[0, 1],
-            y=[0, 1],
-            mode='lines',
-            name='Chance'
-        )
+        trace_diag = go.Scatter(x=[0, 1], y=[0, 1], mode="lines", name="Chance")
         data_roc.append(trace_diag)
         layout_roc = go.Layout(
-            title='Receiver operating characteristics (ROC) and area under curve (AUC)',
-            xaxis=dict(title='False positive rate'),
-            yaxis=dict(title='True positive rate')
+            title="Receiver operating characteristics (ROC) and area under curve (AUC)",
+            xaxis=dict(title="False positive rate"),
+            yaxis=dict(title="True positive rate"),
         )

         fig_roc = go.Figure(data=data_roc, layout=layout_roc)
         plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False)

     except Exception as exp:
-        print("Plotting the ROC-AUC graph failed. This exception was raised: {}".format(exp))
+        print(
+            "Plotting the ROC-AUC graph failed. This exception was raised: {}".format(
+                exp
+            )
+        )


 if __name__ == "__main__":
     aparser = argparse.ArgumentParser()
     aparser.add_argument("-i", "--input", dest="infile_input", required=True)