Mercurial > repos > bgruening > sklearn_searchcv
diff train_test_split.py @ 23:bc3b489825b2 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 80417bf0158a9b596e485dd66408f738f405145a
author | bgruening |
---|---|
date | Mon, 02 Oct 2023 07:59:32 +0000 |
parents | 006db575e1f3 |
children |
line wrap: on
line diff
--- a/train_test_split.py Thu Aug 11 07:41:31 2022 +0000 +++ b/train_test_split.py Mon Oct 02 07:59:32 2023 +0000 @@ -1,8 +1,10 @@ import argparse import json import warnings +from distutils.version import LooseVersion as Version import pandas as pd +from galaxy_ml import __version__ as galaxy_ml_version from galaxy_ml.model_validations import train_test_split from galaxy_ml.utils import get_cv, read_columns @@ -69,7 +71,10 @@ y = df.iloc[:, col_index].values # construct the cv splitter object - splitter, groups = get_cv(params["mode_selection"]["cv_selector"]) + cv_selector = params["mode_selection"]["cv_selector"] + if Version(galaxy_ml_version) < Version("0.8.3"): + cv_selector.pop("n_stratification_bins", None) + splitter, groups = get_cv(cv_selector) total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) if nth_split > total_n_splits: