Mercurial > repos > bgruening > tabpfn
comparison main.py @ 5:7808193b5626 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
| author | bgruening |
|---|---|
| date | Wed, 26 Mar 2025 16:32:35 +0000 |
| parents | 3957cd124013 |
| children | |
comparison
equal
deleted
inserted
replaced
| 4:3957cd124013 | 5:7808193b5626 |
|---|---|
| 55 y_true_bin.ravel(), y_scores.ravel() | 55 y_true_bin.ravel(), y_scores.ravel() |
| 56 ) | 56 ) |
| 57 plt.plot( | 57 plt.plot( |
| 58 recall, precision, linestyle="--", color="black", label="Micro-average" | 58 recall, precision, linestyle="--", color="black", label="Micro-average" |
| 59 ) | 59 ) |
| 60 plt.title("Precision-Recall Curve (Multiclass Classification)") | 60 plt.title( |
| | 61 "Precision-Recall Curve (Multiclass Classification)" |
| | 62 ) |
| 61 plt.xlabel("Recall") | 63 plt.xlabel("Recall") |
| 62 plt.ylabel("Precision") | 64 plt.ylabel("Precision") |
| 63 plt.legend(loc="lower left") | 65 plt.legend(loc="lower left") |
| 64 plt.grid(True) | 66 plt.grid(True) |
| 65 plt.savefig("output_plot.png") | 67 plt.savefig("output_plot.png") |
| 83 Train TabPFN and predict | 85 Train TabPFN and predict |
| 84 """ | 86 """ |
| 85 # prepare train data | 87 # prepare train data |
| 86 tr_features, tr_labels = separate_features_labels(args["train_data"]) | 88 tr_features, tr_labels = separate_features_labels(args["train_data"]) |
| 87 # prepare test data | 89 # prepare test data |
| 88 if args["testhaslabels"] == "haslabels": | 90 if args["testhaslabels"] == "true": |
| 89 te_features, te_labels = separate_features_labels(args["test_data"]) | 91 te_features, te_labels = separate_features_labels(args["test_data"]) |
| 90 else: | 92 else: |
| 91 te_features = pd.read_csv(args["test_data"], sep="\t") | 93 te_features = pd.read_csv(args["test_data"], sep="\t") |
| 92 te_labels = [] | 94 te_labels = [] |
| 93 s_time = time.time() | 95 s_time = time.time() |
| 94 if args["selected_task"] == "Classification": | 96 if args["selected_task"] == "Classification": |
| 95 classifier = TabPFNClassifier() | 97 classifier = TabPFNClassifier(random_state=42) |
| 96 classifier.fit(tr_features, tr_labels) | 98 classifier.fit(tr_features, tr_labels) |
| 97 y_eval = classifier.predict(te_features) | 99 y_eval = classifier.predict(te_features) |
| 98 pred_probas_test = classifier.predict_proba(te_features) | 100 pred_probas_test = classifier.predict_proba(te_features) |
| 99 if len(te_labels) > 0: | 101 if len(te_labels) > 0: |
| 100 classification_plot(te_labels, pred_probas_test) | 102 classification_plot(te_labels, pred_probas_test) |
| | 103 te_features["predicted_labels"] = y_eval |
| | 104 te_features.to_csv( |
| | 105 "output_predicted_data", sep="\t", index=None |
| | 106 ) |
| 101 else: | 107 else: |
| 102 regressor = TabPFNRegressor() | 108 regressor = TabPFNRegressor(random_state=42) |
| 103 regressor.fit(tr_features, tr_labels) | 109 regressor.fit(tr_features, tr_labels) |
| 104 y_eval = regressor.predict(te_features) | 110 y_eval = regressor.predict(te_features) |
| 105 if len(te_labels) > 0: | 111 if len(te_labels) > 0: |
| 106 score = root_mean_squared_error(te_labels, y_eval) | 112 score = root_mean_squared_error(te_labels, y_eval) |
| 107 r2_metric_score = r2_score(te_labels, y_eval) | 113 r2_metric_score = r2_score(te_labels, y_eval) |
| 110 y_eval, | 116 y_eval, |
| 111 f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}", | 117 f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}", |
| 112 "True values", | 118 "True values", |
| 113 "Predicted values", | 119 "Predicted values", |
| 114 ) | 120 ) |
| | 121 te_features["predicted_labels"] = y_eval |
| | 122 te_features.to_csv( |
| | 123 "output_predicted_data", sep="\t", index=None |
| | 124 ) |
| 115 e_time = time.time() | 125 e_time = time.time() |
| 116 print( | 126 print( |
| 117 "Time taken by TabPFN for training and prediction: {} seconds".format( | 127 f"Time taken by TabPFN for training and prediction: {e_time - s_time} seconds" |
| 118 e_time - s_time | |
| 119 ) | |
| 120 ) | 128 ) |
| 121 te_features["predicted_labels"] = y_eval | |
| 122 te_features.to_csv("output_predicted_data", sep="\t", index=None) | |
| 123 | 129 |
| 124 | 130 |
| 125 if __name__ == "__main__": | 131 if __name__ == "__main__": |
| 126 arg_parser = argparse.ArgumentParser() | 132 arg_parser = argparse.ArgumentParser() |
| 127 arg_parser.add_argument("-trdata", "--train_data", required=True, help="Train data") | 133 arg_parser.add_argument("-trdata", "--train_data", required=True, help="Train data") |
