annotate main.py @ 4:3957cd124013 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
author bgruening
date Tue, 11 Feb 2025 10:14:02 +0000
parents 4a92db686946
children 7808193b5626
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
1 """
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
2 Tabular data prediction using TabPFN
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
3 """
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
4 import argparse
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
5 import time
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
6
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
7 import matplotlib.pyplot as plt
2
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
8 import numpy as np
0
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
9 import pandas as pd
2
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
10 from sklearn.metrics import (
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
11 average_precision_score,
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
12 precision_recall_curve,
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
13 r2_score,
4
3957cd124013 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
14 root_mean_squared_error,
2
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
15 )
4
3957cd124013 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
16 from sklearn.preprocessing import label_binarize
2
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
17 from tabpfn import TabPFNClassifier, TabPFNRegressor
0
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
18
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
19
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
def separate_features_labels(data):
    """Split a tab-separated table into feature columns and a label column.

    The right-most column is treated as the label; all columns before it are
    features.

    :param data: path or file-like object accepted by ``pandas.read_csv``.
    :returns: tuple ``(features, labels)`` -- a DataFrame of every column but
        the last, and the last column.
    """
    table = pd.read_csv(data, sep="\t")
    last = table.shape[1] - 1
    label_column = table.iloc[:, last]
    feature_columns = table.iloc[:, :last]
    return feature_columns, label_column
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
25
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
26
4
3957cd124013 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
def classification_plot(y_true, y_scores, classes=None):
    """Plot precision-recall curve(s) for predicted class probabilities.

    Binary problems get a single curve for the positive class ``classes[1]``;
    problems with more classes get one one-vs-rest curve per class plus a
    micro-averaged curve. The figure is saved to ``output_plot.png`` in the
    current working directory.

    :param y_true: true class labels of the evaluated samples.
    :param y_scores: 2-D array of predicted probabilities with one column per
        class, in the same order as ``classes`` (e.g. ``predict_proba`` output).
    :param classes: labels of the ``y_scores`` columns, in column order (e.g. a
        fitted classifier's ``classes_``). Defaults to the sorted unique values
        of ``y_true``, which only lines up with the columns when every class
        seen during training also occurs in ``y_true``.
    :raises ValueError: if ``y_scores`` is not 2-D or its number of columns
        does not match the number of classes.
    """
    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores)
    classes = np.unique(y_true) if classes is None else np.asarray(classes)
    if y_scores.ndim != 2 or y_scores.shape[1] != len(classes):
        # A mismatch means column i is not the score of classes[i]; plotting
        # anyway would silently pair labels with the wrong probabilities.
        raise ValueError(
            f"y_scores has shape {y_scores.shape}, expected one column per "
            f"class ({len(classes)} classes: {list(classes)})"
        )

    plt.figure(figsize=(8, 6))
    if len(classes) == 2:
        # Name the positive class explicitly: without pos_label scikit-learn
        # rejects labels other than {0, 1} / {-1, 1} (e.g. string labels).
        pos_label = classes[1]
        precision, recall, _ = precision_recall_curve(
            y_true, y_scores[:, 1], pos_label=pos_label
        )
        average_precision = average_precision_score(
            y_true, y_scores[:, 1], pos_label=pos_label
        )
        plt.plot(
            recall,
            precision,
            label=f"Precision-Recall Curve (AP={average_precision:.2f})",
        )
        plt.title("Precision-Recall Curve (binary classification)")
    else:
        # One-vs-rest indicator matrix whose columns follow `classes`, and
        # therefore line up with the columns of y_scores.
        y_true_bin = label_binarize(y_true, classes=classes)
        # Plot PR curve for each class, labelled with its real class name
        for i, class_label in enumerate(classes):
            precision, recall, _ = precision_recall_curve(
                y_true_bin[:, i], y_scores[:, i]
            )
            ap_score = average_precision_score(y_true_bin[:, i], y_scores[:, i])
            plt.plot(
                recall, precision, label=f"Class {class_label} (AP = {ap_score:.2f})"
            )
        # Compute micro-average PR curve over all (sample, class) pairs
        precision, recall, _ = precision_recall_curve(
            y_true_bin.ravel(), y_scores.ravel()
        )
        plt.plot(
            recall, precision, linestyle="--", color="black", label="Micro-average"
        )
        plt.title("Precision-Recall Curve (Multiclass Classification)")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="lower left")
    plt.grid(True)
    plt.savefig("output_plot.png")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
66
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
67
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
def regression_plot(xval, yval, title, xlabel, ylabel):
    """Scatter-plot ``yval`` against ``xval`` with a dashed y = x reference line.

    The figure is saved to ``output_plot.png`` in the current working directory.

    :param xval: values for the x axis (the caller passes true targets).
    :param yval: values for the y axis (the caller passes predictions).
    :param title: plot title.
    :param xlabel: x-axis label.
    :param ylabel: y-axis label.
    """
    plt.figure(figsize=(8, 6))
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.scatter(xval, yval, alpha=0.8)
    values = np.concatenate([np.ravel(xval), np.ravel(yval)])
    if values.size:
        # Span the identity line over the range of the plotted values, not
        # over sample indices, so it actually passes through the data.
        lo, hi = np.nanmin(values), np.nanmax(values)
        plt.plot([lo, hi], [lo, hi], color="red", linestyle="--", label="y = x")
    # Build the legend only after the labelled line exists; a legend created
    # earlier has no entries.
    plt.legend(loc="lower left")
    plt.savefig("output_plot.png")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
79
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
80
0
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
def train_evaluate(args):
    """
    Train TabPFN on the training table and predict the test table.

    :param args: dict of command-line values with keys ``train_data``,
        ``test_data`` (tab-separated tables whose last column is the label),
        ``testhaslabels`` and ``selected_task``.

    ``selected_task == "Classification"`` uses ``TabPFNClassifier``; any other
    value uses ``TabPFNRegressor``. When ``testhaslabels == "haslabels"`` the
    last test column is treated as the label and an evaluation plot is drawn
    (``classification_plot`` / ``regression_plot``). The test features plus a
    ``predicted_labels`` column are written to ``output_predicted_data``.
    """
    # prepare train data
    tr_features, tr_labels = separate_features_labels(args["train_data"])
    # prepare test data
    if args["testhaslabels"] == "haslabels":
        te_features, te_labels = separate_features_labels(args["test_data"])
    else:
        te_features = pd.read_csv(args["test_data"], sep="\t")
        te_labels = []
    has_labels = len(te_labels) > 0

    s_time = time.time()
    if args["selected_task"] == "Classification":
        classifier = TabPFNClassifier()
        classifier.fit(tr_features, tr_labels)
        y_eval = classifier.predict(te_features)
        pred_probas_test = classifier.predict_proba(te_features)
        # Stop the clock before evaluation so the reported time covers only
        # training and prediction, as the message says.
        e_time = time.time()
        if has_labels:
            classification_plot(te_labels, pred_probas_test)
    else:
        regressor = TabPFNRegressor()
        regressor.fit(tr_features, tr_labels)
        y_eval = regressor.predict(te_features)
        e_time = time.time()
        if has_labels:
            score = root_mean_squared_error(te_labels, y_eval)
            r2_metric_score = r2_score(te_labels, y_eval)
            regression_plot(
                te_labels,
                y_eval,
                f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}",
                "True values",
                "Predicted values",
            )
    print(
        "Time taken by TabPFN for training and prediction: {} seconds".format(
            e_time - s_time
        )
    )
    te_features["predicted_labels"] = y_eval
    te_features.to_csv("output_predicted_data", sep="\t", index=False)
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
123
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
124
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
if __name__ == "__main__":
    # Command-line entry point: every option is mandatory, and the parsed
    # values are handed to train_evaluate as a plain dict.
    parser = argparse.ArgumentParser()
    cli_options = (
        ("-trdata", "--train_data", "Train data"),
        ("-tedata", "--test_data", "Test data"),
        ("-testhaslabels", "--testhaslabels", "if test data contain labels"),
        ("-selectedtask", "--selected_task", "Type of machine learning task"),
    )
    for short_flag, long_flag, help_text in cli_options:
        parser.add_argument(short_flag, long_flag, required=True, help=help_text)
    train_evaluate(vars(parser.parse_args()))