Mercurial > repos > bgruening > tabpfn
annotate main.py @ 4:3957cd124013 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
author | bgruening |
---|---|
date | Tue, 11 Feb 2025 10:14:02 +0000 |
parents | 4a92db686946 |
children | 7808193b5626 |
rev | line source |
---|---|
0
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
1 """ |
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
2 Tabular data prediction using TabPFN |
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
3 """ |
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
4 import argparse |
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
5 import time |
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
6 |
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
7 import matplotlib.pyplot as plt |
2
abe1c3ac9145
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
8 import numpy as np |
0
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
9 import pandas as pd |
2
abe1c3ac9145
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
10 from sklearn.metrics import ( |
abe1c3ac9145
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
11 average_precision_score, |
abe1c3ac9145
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
12 precision_recall_curve, |
abe1c3ac9145
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
13 r2_score, |
4
3957cd124013
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents:
3
diff
changeset
|
14 root_mean_squared_error, |
2
abe1c3ac9145
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
15 ) |
4
3957cd124013
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents:
3
diff
changeset
|
16 from sklearn.preprocessing import label_binarize |
2
abe1c3ac9145
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
17 from tabpfn import TabPFNClassifier, TabPFNRegressor |
0
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
18 |
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
19 |
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
def separate_features_labels(data):
    """Load a tab-separated table and split it into features and target.

    The last column is the target; every column before it is a feature.

    Parameters
    ----------
    data : str or file-like
        Any source ``pandas.read_csv`` accepts (the script passes a path).

    Returns
    -------
    tuple
        ``(features, labels)``: a DataFrame of all but the last column and
        a Series holding the last column.
    """
    table = pd.read_csv(data, sep="\t")
    last = table.shape[1] - 1
    return table.iloc[:, :last], table.iloc[:, last]
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
25 |
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
26 |
4
3957cd124013
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents:
3
diff
changeset
|
def classification_plot(y_true, y_scores, classes=None):
    """Plot precision-recall curves for classifier scores and save them.

    The figure is written to ``output_plot.png`` in the working directory.
    With two classes a single curve is drawn, scoring the second class
    (``classes[1]``) as positive. With any other number of classes there is
    one curve per class plus a micro-averaged curve.

    Parameters
    ----------
    y_true : array-like
        True labels of the evaluated samples.
    y_scores : array-like of shape (n_samples, n_classes)
        Per-class scores, e.g. ``predict_proba`` output. Column ``j`` must
        hold the score for ``classes[j]``.
    classes : array-like, optional
        Class labels in the column order of ``y_scores``; for a fitted
        scikit-learn style classifier this is its ``classes_``. Defaults to
        the sorted unique values of ``y_true``. That default only lines up
        with the score columns when every trained class occurs in ``y_true``.
    """
    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores)
    classes = np.unique(y_true) if classes is None else np.asarray(classes)

    plt.figure(figsize=(8, 6))
    if len(classes) == 2:
        # Name the positive class explicitly: without pos_label, scikit-learn
        # rejects binary labels outside {0, 1} / {-1, 1} (e.g. strings).
        pos_label = classes[1]
        pos_scores = y_scores[:, 1]
        precision, recall, _ = precision_recall_curve(
            y_true, pos_scores, pos_label=pos_label
        )
        average_precision = average_precision_score(
            y_true, pos_scores, pos_label=pos_label
        )
        plt.plot(
            recall,
            precision,
            label=f"Precision-Recall Curve (AP={average_precision:.2f})",
        )
        plt.title("Precision-Recall Curve (binary classification)")
    else:
        # One-vs-rest indicator matrix whose columns follow ``classes``.
        y_true_bin = label_binarize(y_true, classes=classes)
        # Plot a PR curve per class, labelled with the actual class value.
        for i, cls in enumerate(classes):
            precision, recall, _ = precision_recall_curve(
                y_true_bin[:, i], y_scores[:, i]
            )
            ap_score = average_precision_score(y_true_bin[:, i], y_scores[:, i])
            plt.plot(recall, precision, label=f"Class {cls} (AP = {ap_score:.2f})")
        # Micro-average: pool every (sample, class) decision into one curve.
        precision, recall, _ = precision_recall_curve(
            y_true_bin.ravel(), y_scores.ravel()
        )
        plt.plot(
            recall, precision, linestyle="--", color="black", label="Micro-average"
        )
        plt.title("Precision-Recall Curve (Multiclass Classification)")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="lower left")
    plt.grid(True)
    plt.savefig("output_plot.png")
    # Release the figure; pyplot otherwise keeps every open figure alive.
    plt.close()
abe1c3ac9145
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
66 |
abe1c3ac9145
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
67 |
abe1c3ac9145
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
def regression_plot(xval, yval, title, xlabel, ylabel):
    """Save a scatter plot of ``yval`` against ``xval`` with a y = x line.

    The figure is written to ``output_plot.png`` in the working directory.

    Parameters
    ----------
    xval, yval : array-like of numbers
        Values for the x axis (true values in this script) and the y axis
        (predicted values). Both must have the same length.
    title, xlabel, ylabel : str
        Plot title and axis labels.
    """
    xval = np.asarray(xval, dtype=float)
    yval = np.asarray(yval, dtype=float)

    plt.figure(figsize=(8, 6))
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.scatter(xval, yval, alpha=0.8)
    if xval.size and yval.size:
        # Span the reference line over the data's value range, not over the
        # sample indices, so points on it are exactly the perfect predictions.
        lo = min(xval.min(), yval.min())
        hi = max(xval.max(), yval.max())
        plt.plot([lo, hi], [lo, hi], color="red", linestyle="--", label="y = x")
        # Build the legend only after the labelled line exists; calling it
        # earlier yields an empty legend and a "no artists" warning.
        plt.legend(loc="lower left")
    plt.savefig("output_plot.png")
    # Release the figure; pyplot otherwise keeps every open figure alive.
    plt.close()
abe1c3ac9145
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
79 |
abe1c3ac9145
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
80 |
0
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
def train_evaluate(args):
    """Fit TabPFN on the training table and predict the test table.

    ``args`` is the dictionary of command-line values:

    - ``train_data`` / ``test_data``: tab-separated tables.
    - ``testhaslabels``: equals ``"haslabels"`` when the test table's last
      column holds true targets.
    - ``selected_task``: ``"Classification"`` selects a classifier; any
      other value selects a regressor.

    When true targets are available, an evaluation plot is saved via
    ``classification_plot`` or ``regression_plot``. The test features plus
    a ``predicted_labels`` column are written to ``output_predicted_data``.
    The wall-clock time of fitting, predicting and plotting is printed.
    """
    train_x, train_y = separate_features_labels(args["train_data"])

    # Without true targets the whole test table is treated as features.
    if args["testhaslabels"] == "haslabels":
        test_x, test_y = separate_features_labels(args["test_data"])
    else:
        test_x, test_y = pd.read_csv(args["test_data"], sep="\t"), []

    started = time.time()
    if args["selected_task"] != "Classification":
        model = TabPFNRegressor()
        model.fit(train_x, train_y)
        predictions = model.predict(test_x)
        if len(test_y) > 0:
            rmse = root_mean_squared_error(test_y, predictions)
            r2 = r2_score(test_y, predictions)
            regression_plot(
                test_y,
                predictions,
                f"Scatter plot: True vs predicted values. RMSE={rmse:.2f}, R2={r2:.2f}",
                "True values",
                "Predicted values",
            )
    else:
        model = TabPFNClassifier()
        model.fit(train_x, train_y)
        predictions = model.predict(test_x)
        probabilities = model.predict_proba(test_x)
        if len(test_y) > 0:
            classification_plot(test_y, probabilities)
    elapsed = time.time() - started
    print(f"Time taken by TabPFN for training and prediction: {elapsed} seconds")

    test_x["predicted_labels"] = predictions
    test_x.to_csv("output_predicted_data", sep="\t", index=None)
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
123 |
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
124 |
2a0c6d2090f4
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
if __name__ == "__main__":
    # Every option is mandatory; the parsed values are handed to
    # train_evaluate() as a plain dict keyed by the long option names.
    cli_options = [
        ("-trdata", "--train_data", "Train data"),
        ("-tedata", "--test_data", "Test data"),
        ("-testhaslabels", "--testhaslabels", "if test data contain labels"),
        ("-selectedtask", "--selected_task", "Type of machine learning task"),
    ]
    arg_parser = argparse.ArgumentParser()
    for short_flag, long_flag, help_text in cli_options:
        arg_parser.add_argument(short_flag, long_flag, required=True, help=help_text)
    # get argument values
    train_evaluate(vars(arg_parser.parse_args()))