comparison main.py @ 2:abe1c3ac9145 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
author bgruening
date Fri, 17 Jan 2025 22:23:21 +0000
parents 2a0c6d2090f4
children 4a92db686946
comparison
equal deleted inserted replaced
1:bb60469276fd 2:abe1c3ac9145
3 """ 3 """
4 import argparse 4 import argparse
5 import time 5 import time
6 6
7 import matplotlib.pyplot as plt 7 import matplotlib.pyplot as plt
8 import numpy as np
8 import pandas as pd 9 import pandas as pd
9 from sklearn.metrics import accuracy_score, average_precision_score, precision_recall_curve 10 from sklearn.metrics import (
10 from tabpfn import TabPFNClassifier 11 average_precision_score,
12 precision_recall_curve,
13 r2_score,
14 root_mean_squared_error
15 )
16 from tabpfn import TabPFNClassifier, TabPFNRegressor
11 17
12 18
def separate_features_labels(data):
    """
    Read a tab-separated table and split it into features and labels.

    The last column is returned as the labels and every column before it
    as the features.
    """
    table = pd.read_csv(data, sep="\t")
    return table.iloc[:, :-1], table.iloc[:, -1]
18 24
19 25
def classification_plot(xval, yval, leg_label, title, xlabel, ylabel):
    """
    Draw a single labelled line plot and save it as "output_plot.png".

    xval and yval are the curve coordinates, leg_label is the legend entry
    for the curve, and title, xlabel and ylabel annotate the figure.
    """
    _, axes = plt.subplots(figsize=(8, 6))
    axes.plot(xval, yval, label=leg_label)
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_title(title)
    axes.legend(loc="lower left")
    axes.grid(True)
    # The freshly created figure is the current one, so pyplot saves it.
    plt.savefig("output_plot.png")
35
36
def regression_plot(xval, yval, title, xlabel, ylabel):
    """
    Save a scatter plot of yval against xval as "output_plot.png".

    A dashed red "y = x" reference line is drawn across the value range
    covered by both inputs, so points on the line are those where the two
    values agree. title, xlabel and ylabel annotate the figure.
    """
    plt.figure(figsize=(8, 6))
    plt.scatter(xval, yval, alpha=0.8)
    if len(xval) > 0 and len(yval) > 0:
        # The identity line has to span the data values themselves; using
        # sample indices (0..n-1) would put it on an unrelated scale.
        lower = min(np.min(xval), np.min(yval))
        upper = max(np.max(xval), np.max(yval))
        plt.plot(
            [lower, upper],
            [lower, upper],
            color="red",
            linestyle="--",
            label="y = x",
        )
        # The legend is built from the artists that exist when it is
        # called, so it must come after the labelled reference line.
        plt.legend(loc="lower left")
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.savefig("output_plot.png")
48
49
def train_evaluate(args):
    """
    Train TabPFN and predict

    args is a dict with the keys "train_data" and "test_data" (tab-separated
    files), "testhaslabels" ("haslabels" when the last column of the test
    file holds labels) and "selected_task" ("Classification" selects the
    classifier, any other value selects the regressor).

    The test features together with a "predicted_labels" column are written
    to "output_predicted_data". When test labels are available, metrics are
    computed and an evaluation plot is produced via classification_plot or
    regression_plot.
    """
    # prepare train data
    tr_features, tr_labels = separate_features_labels(args["train_data"])
    # prepare test data
    if args["testhaslabels"] == "haslabels":
        te_features, te_labels = separate_features_labels(args["test_data"])
    else:
        # No label column: every column is a feature and evaluation is skipped.
        te_features = pd.read_csv(args["test_data"], sep="\t")
        te_labels = []
    s_time = time.time()
    if args["selected_task"] == "Classification":
        classifier = TabPFNClassifier(device="cpu")
        classifier.fit(tr_features, tr_labels)
        y_eval = classifier.predict(te_features)
        pred_probas_test = classifier.predict_proba(te_features)
        # len() is used because te_labels may be a pandas Series, whose
        # truth value is ambiguous.
        if len(te_labels) > 0:
            # NOTE(review): only the second probability column is scored, so
            # this presumably expects a two-class problem - confirm.
            positive_scores = pred_probas_test[:, 1]
            # Column 1 holds the probability of classes_[1]; naming it as the
            # positive label keeps the metrics aligned with the scores even
            # when the labels are not encoded as 0/1.
            positive_class = classifier.classes_[1]
            precision, recall, _ = precision_recall_curve(
                te_labels, positive_scores, pos_label=positive_class
            )
            average_precision = average_precision_score(
                te_labels, positive_scores, pos_label=positive_class
            )
            classification_plot(
                recall,
                precision,
                f"Precision-Recall Curve (AP={average_precision:.2f})",
                "Precision-Recall Curve",
                "Recall",
                "Precision",
            )
    else:
        regressor = TabPFNRegressor(device="cpu")
        regressor.fit(tr_features, tr_labels)
        y_eval = regressor.predict(te_features)
        if len(te_labels) > 0:
            score = root_mean_squared_error(te_labels, y_eval)
            r2_metric_score = r2_score(te_labels, y_eval)
            regression_plot(
                te_labels,
                y_eval,
                f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}",
                "True values",
                "Predicted values",
            )
    e_time = time.time()
    print(
        "Time taken by TabPFN for training and prediction: {} seconds".format(
            e_time - s_time
        )
    )
    te_features["predicted_labels"] = y_eval
    te_features.to_csv("output_predicted_data", sep="\t", index=None)
46 105
47 106
if __name__ == "__main__":
    # (short flag, long flag, help text) for each mandatory option
    cli_options = (
        ("-trdata", "--train_data", "Train data"),
        ("-tedata", "--test_data", "Test data"),
        ("-testhaslabels", "--testhaslabels", "if test data contain labels"),
        ("-selectedtask", "--selected_task", "Type of machine learning task"),
    )
    arg_parser = argparse.ArgumentParser()
    for short_flag, long_flag, help_text in cli_options:
        arg_parser.add_argument(short_flag, long_flag, required=True, help=help_text)
    # get argument values
    args = vars(arg_parser.parse_args())
    train_evaluate(args)