#!/usr/bin/env python
"""Generate plots using flexynesis
This script generates dimensionality reduction plots, Kaplan-Meier survival curves,
and Cox proportional hazards models from data processed by flexynesis."""
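
# Example invocations (file and column names are hypothetical):
#   python flexynesis_plot.py --plot_type dimred --embeddings embeddings.tsv \
#       --labels labels.tsv --target_variables SUBTYPE --method umap
#   python flexynesis_plot.py --plot_type kaplan_meier --survival_data clin.tsv \
#       --labels labels.tsv --surv_time_var OS_MONTHS --surv_event_var OS_STATUS \
#       --event_value '1:DECEASED'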

import argparse
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from flexynesis import (
    build_cox_model,
    get_important_features,
    plot_dim_reduced,
    plot_hazard_ratios,
    plot_kaplan_meier_curves,
    plot_pr_curves,
    plot_roc_curves,
    plot_scatter
)
from scipy.stats import kruskal, mannwhitneyu


def load_embeddings(embeddings_path):
    """Load embeddings from a file"""
    try:
        # Determine file extension
        file_ext = Path(embeddings_path).suffix.lower()

        if file_ext == '.csv':
            df = pd.read_csv(embeddings_path, index_col=0)
        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
            df = pd.read_csv(embeddings_path, sep='\t', index_col=0)
        else:
            raise ValueError(f"Unsupported file extension: {file_ext}")

        return df, df.index.tolist()

    except Exception as e:
        raise ValueError(f"Error loading embeddings from {embeddings_path}: {e}") from e


def load_labels(labels_input):
    """Load predicted labels from flexynesis"""
    try:
        # Determine file extension
        file_ext = Path(labels_input).suffix.lower()

        if file_ext == '.csv':
            df = pd.read_csv(labels_input)
        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
            df = pd.read_csv(labels_input, sep='\t')
        else:
            raise ValueError(f"Unsupported file extension: {file_ext}")

        # Verify the expected flexynesis long-format label schema
        required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
        if all(col in df.columns for col in required_cols):
            return df
        else:
            raise ValueError(f"Labels file {labels_input} does not contain required columns: {required_cols}")

    except Exception as e:
        raise ValueError(f"Error loading labels from {labels_input}: {e}") from e


def load_survival_data(survival_path):
    """Load survival data from a file. First column should be sample_id"""
    try:
        # Determine file extension
        file_ext = Path(survival_path).suffix.lower()

        if file_ext == '.csv':
            df = pd.read_csv(survival_path, index_col=0)
        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
            df = pd.read_csv(survival_path, sep='\t', index_col=0)
        else:
            raise ValueError(f"Unsupported file extension: {file_ext}")
        return df

    except Exception as e:
        raise ValueError(f"Error loading survival data from {survival_path}: {e}") from e


def load_omics(omics_path):
    """Load omics data from a file. First column should be features"""
    try:
        # Determine file extension
        file_ext = Path(omics_path).suffix.lower()

        if file_ext == '.csv':
            df = pd.read_csv(omics_path, index_col=0)
        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
            df = pd.read_csv(omics_path, sep='\t', index_col=0)
        else:
            raise ValueError(f"Unsupported file extension: {file_ext}")
        return df

    except Exception as e:
        raise ValueError(f"Error loading omics data from {omics_path}: {e}") from e


def load_model(model_path):
    """Load flexynesis model from pickle file"""
    try:
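        # weights_only=False unpickles the full model object rather than just
        # tensors, so only load checkpoints from trusted sources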
        with open(model_path, 'rb') as f:
            model = torch.load(f, weights_only=False)
        return model
    except Exception as e:
        raise ValueError(f"Error loading model from {model_path}: {e}") from e


def match_samples_to_embeddings(sample_names, label_data):
    """Filter label data to match sample names in the embeddings"""
    df_matched = label_data[label_data['sample_id'].isin(sample_names)]
    return df_matched


def detect_color_type(labels_series):
    """Auto-detect whether target variables should be treated as categorical or numerical"""
    # Remove NaN
    clean_labels = labels_series.dropna()

    if clean_labels.empty:
        return 'categorical'  # default output if no labels

    # Check if all values can be converted to numbers
    try:
        numeric_labels = pd.to_numeric(clean_labels, errors='coerce')

        # If conversion failed -> categorical
        if numeric_labels.isna().any():
            return 'categorical'

        # Check number of unique values
        unique_count = len(clean_labels.unique())
        total_count = len(clean_labels)

        # If few unique values relative to total -> categorical
        # Threshold: if unique values < 10 OR unique/total < 0.1
        if unique_count < 10 or (unique_count / total_count) < 0.1:
            return 'categorical'
        else:
            return 'numerical'

    except Exception:
        return 'categorical'


def plot_label_concordance_heatmap(labels1, labels2, figsize=(12, 10)):
    """
    Plot a heatmap reflecting the concordance between two sets of labels using pandas crosstab.

    Parameters:
    - labels1: The first set of labels.
    - labels2: The second set of labels.
    """
    # Compute the cross-tabulation
    ct = pd.crosstab(pd.Series(labels1, name='Labels Set 1'), pd.Series(labels2, name='Labels Set 2'))
    # Normalize the cross-tabulation matrix row-wise so each row sums to 1
    ct_normalized = ct.div(ct.sum(axis=1), axis=0)

    # Plot the heatmap
    plt.figure(figsize=figsize)
    sns.heatmap(ct_normalized, annot=True, cmap='viridis', linewidths=.5)
    plt.title('Concordance between label groups')

    return plt.gcf()


def plot_boxplot(categorical_x, numerical_y, title_x='Categories', title_y='Values', figsize=(10, 6), jittersize=4):
    """
    Create a boxplot with to visualize the distribution of predicted probabilities across different categories.
    the x axis represents the true labels, and the y axis represents the predicted probabilities for specific categories.
    """
    df = pd.DataFrame({title_x: categorical_x, title_y: numerical_y})

    # Compute p-value
    groups = df[title_x].unique()
    if len(groups) == 2:
        group1 = df[df[title_x] == groups[0]][title_y]
        group2 = df[df[title_x] == groups[1]][title_y]
        stat, p = mannwhitneyu(group1, group2, alternative='two-sided')
        test_name = "Mann-Whitney U"
    else:
        group_data = [df[df[title_x] == group][title_y] for group in groups]
        stat, p = kruskal(*group_data)
        test_name = "Kruskal-Wallis"

    # Create a boxplot with jittered points
    plt.figure(figsize=figsize)
    sns.boxplot(x=title_x, y=title_y, hue=title_x, data=df, palette='Set2', legend=False, fill=False)
    sns.stripplot(x=title_x, y=title_y, data=df, color='black', size=jittersize, jitter=True, dodge=True, alpha=0.4)

    # Labels and p-value annotation
    plt.xlabel(title_x)
    plt.ylabel(title_y)
    plt.text(
        x=-0.4,
        y=plt.ylim()[1],
        s=f'{test_name} p = {p:.3e}',
        verticalalignment='top',
        horizontalalignment='left',
        fontsize=12,
        bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='gray')
    )

    plt.tight_layout()
    return plt.gcf()


def generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base):
    """Generate dimensionality reduction plots"""

    # Parse target variables
    target_vars = [var.strip() for var in args.target_variables.split(',')]

    print(f"Generating {args.method.upper()} plots for {len(target_vars)} target variable(s): {', '.join(target_vars)}")

    # Check variables
    available_vars = matched_labels['variable'].unique()
    missing_vars = [var for var in target_vars if var not in available_vars]

    if missing_vars:
        print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}")
        print(f"Available variables: {', '.join(available_vars)}")

    # Filter to only process available variables
    valid_vars = [var for var in target_vars if var in available_vars]

    if not valid_vars:
        raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}")

    # Generate plots for each valid target variable
    for var in valid_vars:
        print(f"\nPlotting variable: {var}")

        # Filter matched labels for current variable
        var_labels = matched_labels[matched_labels['variable'] == var].copy()
        var_labels = var_labels.drop_duplicates(subset='sample_id')

        if var_labels.empty:
            print(f"Warning: No data found for variable '{var}', skipping...")
            continue

        # Auto-detect color type
        known_color_type = detect_color_type(var_labels['known_label'])
        predicted_color_type = detect_color_type(var_labels['predicted_label'])

        print(f"  Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}")

        try:
            # Plot 1: Known labels
            print(f"  Creating known labels plot for {var}...")
            fig_known = plot_dim_reduced(
                matrix=embeddings,
                labels=var_labels['known_label'],
                method=args.method,
                color_type=known_color_type
            )

            output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}"
            print(f"  Saving known labels plot to: {output_path_known.name}")
            fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')

            # Plot 2: Predicted labels
            print(f"  Creating predicted labels plot for {var}...")
            fig_predicted = plot_dim_reduced(
                matrix=embeddings,
                labels=var_labels['predicted_label'],
                method=args.method,
                color_type=predicted_color_type
            )

            output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}"
            print(f"  Saving predicted labels plot to: {output_path_predicted.name}")
            fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight')

            print(f"  ✓ Successfully created plots for variable '{var}'")

        except Exception as e:
            print(f"  ✗ Error creating plots for variable '{var}': {e}")
            continue

    print(f"\nDimensionality reduction plots completed for {len(valid_vars)} variable(s)!")


def generate_km_plots(survival_data, label_data, args, output_dir, output_name_base):
    """Generate Kaplan-Meier plots"""
    print("Generating Kaplan-Meier curves of risk subtypes...")

    # Reset index and rename the index column to sample_id
    survival_data = survival_data.reset_index()
    if survival_data.columns[0] != 'sample_id':
        survival_data = survival_data.rename(columns={survival_data.columns[0]: 'sample_id'})

    # Convert survival event column to binary (0/1) based on event_value
    # Check if the event column exists
    if args.surv_event_var not in survival_data.columns:
        raise ValueError(f"Column '{args.surv_event_var}' not found in survival data")

    # Convert to string for comparison to handle mixed types
    survival_data[args.surv_event_var] = survival_data[args.surv_event_var].astype(str)
    event_value_str = str(args.event_value)

    # Create binary event column (1 if matches event_value, 0 otherwise)
    survival_data[f'{args.surv_event_var}_binary'] = (
        survival_data[args.surv_event_var] == event_value_str
    ).astype(int)
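    # e.g. with event_value='1:DECEASED', ['1:DECEASED', '0:LIVING'] maps to [1, 0]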

    # Keep rows for the survival variable whose class_label matches the event value
    label_data['class_label'] = label_data['class_label'].astype(str)

    label_data = label_data[(label_data['variable'] == args.surv_event_var) & (label_data['class_label'] == event_value_str)]

    # check survival data
    for col in [args.surv_time_var, args.surv_event_var]:
        if col not in survival_data.columns:
            raise ValueError(f"Column '{col}' not found in survival data")

    # Merge survival data with labels
    df_deceased = pd.merge(survival_data, label_data, on='sample_id', how='inner')

    if df_deceased.empty:
        raise ValueError("No matching samples found after merging survival and label data.")

    # Get risk scores
    risk_scores = df_deceased['probability'].values

    # Compute groups (e.g., median split)
    quantiles = np.quantile(risk_scores, [0.5])
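    # np.digitize assigns 0 to scores below the median and 1 at or above it,
    # e.g. np.digitize([0.2, 0.8], [0.5]) -> array([0, 1])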
    groups = np.digitize(risk_scores, quantiles)
    group_labels = ['low_risk' if g == 0 else 'high_risk' for g in groups]

    fig_known = plot_kaplan_meier_curves(
        durations=df_deceased[args.surv_time_var],
        events=df_deceased[f'{args.surv_event_var}_binary'],
        categorical_variable=group_labels
    )

    output_path_known = output_dir / f"{output_name_base}_km_risk_subtypes.{args.format}"
    print(f"Saving Kaplan-Meier plot to: {output_path_known.absolute()}")
    fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')

    print("Kaplan-Meier plot saved successfully!")


def generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base):
    """Generate Cox proportional hazards plots"""
    print("Generating Cox proportional hazards analysis...")

    # Parse clinical variables
    clinical_vars = [var.strip() for var in args.clinical_variables.split(',')]

    # Validate that survival variables are included
    required_vars = [args.surv_time_var, args.surv_event_var]
    for var in required_vars:
        if var not in clinical_vars:
            clinical_vars.append(var)

    print(f"Using clinical variables: {', '.join(clinical_vars)}")

    # filter datasets for clinical variables
    if all(var in clinical_train.columns and var in clinical_test.columns for var in clinical_vars):
        df_clin_train = clinical_train[clinical_vars]
        df_clin_test = clinical_test[clinical_vars]
        # Drop rows with NaN in clinical variables
        df_clin_train = df_clin_train.dropna(subset=clinical_vars)
        df_clin_test = df_clin_test.dropna(subset=clinical_vars)
    else:
        raise ValueError(f"Not all clinical variables found in datasets. Available in train dataset: {clinical_train.columns.tolist()}, Available in test dataset: {clinical_test.columns.tolist()}")

    # Combine
    df_clin = pd.concat([df_clin_train, df_clin_test], axis=0)

    # Get top survival markers
    print(f"Extracting top {args.top_features} important features for {args.surv_event_var}...")
    try:
        imp = get_important_features(model,
                                     var=args.surv_event_var,
                                     top=args.top_features
                                     )['name'].unique().tolist()
        print(f"Top features: {', '.join(imp)}")
    except Exception as e:
        raise ValueError(f"Error getting important features: {e}")

    # Extract feature data from omics datasets
    try:
        omics_test = omics_test.loc[omics_test.index.isin(imp)]
        omics_train = omics_train.loc[omics_train.index.isin(imp)]
        # Drop rows with NaN in omics datasets
        omics_test = omics_test.dropna(subset=omics_test.columns)
        omics_train = omics_train.dropna(subset=omics_train.columns)

        df_imp = pd.concat([omics_train, omics_test], axis=1)
        df_imp = df_imp.T  # Transpose to have samples as rows

        print(f"Feature data shape: {df_imp.shape}")
    except Exception as e:
        raise ValueError(f"Error extracting feature subset: {e}")

    # Combine markers with clinical variables
    df = pd.merge(df_imp, df_clin, left_index=True, right_index=True)
    print(f"Combined data shape: {df.shape}")

    # Remove samples without survival endpoints
    initial_samples = len(df)
    df = df[df[args.surv_event_var].notna()]
    final_samples = len(df)
    print(f"Removed {initial_samples - final_samples} samples without survival data")

    if df.empty:
        raise ValueError("No samples remain after filtering for survival data")

    # Convert survival event column to binary (0/1) based on event_value
    # Convert to string for comparison to handle mixed types
    df[args.surv_event_var] = df[args.surv_event_var].astype(str)
    event_value_str = str(args.event_value)

    df[args.surv_event_var] = (
        df[args.surv_event_var] == event_value_str
    ).astype(int)
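    # e.g. 'DECEASED' -> 1, anything else -> 0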

    # Build Cox model
    print(f"Building Cox model with time variable: {args.surv_time_var}, event variable: {args.surv_event_var}")
    try:
        coxm = build_cox_model(df,
                               duration_col=args.surv_time_var,
                               event_col=args.surv_event_var,
                               crossval=args.crossval,
                               n_splits=args.n_splits,
                               random_state=args.random_state)
        print("Cox model built successfully")
    except Exception as e:
        raise ValueError(f"Error building Cox model: {e}")

    # Generate hazard ratios plot
    try:
        print("Generating hazard ratios plot...")
        fig = plot_hazard_ratios(coxm)

        output_path = output_dir / f"{output_name_base}_hazard_ratios.{args.format}"
        print(f"Saving hazard ratios plot to: {output_path.absolute()}")
        fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

        print("Cox proportional hazards analysis completed successfully!")

    except Exception as e:
        raise ValueError(f"Error generating hazard ratios plot: {e}")


def generate_plot_scatter(labels, args, output_dir, output_name_base):
    """Generate scatter plot of known vs predicted labels"""
    print("Generating scatter plots of known vs predicted labels...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    successful_plots = 0
    skipped_plots = 0

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        if target_labels.empty:
            print(f"  Warning: No data found for target value '{target_value}' - skipping")
            skipped_plots += 1
            continue

        # Check if labels are numeric and convert
        true_values = pd.to_numeric(target_labels['known_label'], errors='coerce')
        predicted_values = pd.to_numeric(target_labels['predicted_label'], errors='coerce')

        if true_values.isna().all() or predicted_values.isna().all():
            print(f"No valid numeric values found for known or predicted labels in '{target_value}'")
            skipped_plots += 1
            continue

        try:
            print(f"  Generating scatter plot for '{target_value}'...")
            fig = plot_scatter(true_values, predicted_values)

            # Create output filename with target value
            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f"  Saving scatter plot to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

            successful_plots += 1
            print(f"  Scatter plot for '{target_value}' generated successfully!")

        except Exception as e:
            print(f"  Error generating plot for '{target_value}': {str(e)}")
            skipped_plots += 1

    # Summary
    print("  Summary:")
    print(f"  Successfully generated: {successful_plots} plots")
    print(f"  Skipped: {skipped_plots} plots")

    if successful_plots == 0:
        raise ValueError("No scatter plots could be generated. Check your data and target values.")

    print("Scatter plot generation completed!")


def generate_label_concordance_heatmap(labels, args, output_dir, output_name_base):
    """Generate label concordance heatmap"""
    print("Generating label concordance heatmaps...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        if target_labels.empty:
            print(f"  Warning: No data found for target value '{target_value}' - skipping")
            continue

        true_values = target_labels['known_label'].tolist()
        predicted_values = target_labels['predicted_label'].tolist()

        try:
            print(f"  Generating heatmap for '{target_value}'...")
            fig = plot_label_concordance_heatmap(true_values, predicted_values)
            plt.close(fig)

            # Create output filename with target value
            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f"  Saving heatmap to: {output_path.absolute()}")
            fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')

        except Exception as e:
            print(f"  Error generating heatmap for '{target_value}': {str(e)}")
            continue

    print("Label concordance heatmap generated successfully!")


def generate_pr_curves(labels, args, output_dir, output_name_base):
    """Generate precision-recall curves"""
    print("Generating precision-recall curves...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        # Check if this is a regression problem (no class probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f"  Class labels found: {list(prob_columns)}")
        print(f"  Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
            print("  Detected regression problem - precision-recall curves not applicable")
            print(f"  Skipping '{target_value}' (use regression evaluation metrics instead)")
            continue

        # Debug: Check data quality
        total_rows = len(target_labels)
        missing_labels = target_labels['known_label'].isna().sum()
        missing_probs = target_labels['probability'].isna().sum()
        unique_samples = target_labels['sample_id'].nunique()
        unique_classes = target_labels['class_label'].nunique()

        print(f"  Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
        print(f"  Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")

        if missing_labels > 0:
            print(f"  Warning: Found {missing_labels} missing known_label values")
            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
            print(f"  Sample IDs with missing known_label: {list(missing_samples)}")

            # Remove rows with missing known_label
            target_labels = target_labels.dropna(subset=['known_label'])
            if target_labels.empty:
                print(f"  Error: No valid known_label data remaining for '{target_value}' - skipping")
                continue

        # 1. Pivot to wide format
        prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability')

        print(f"  After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes")
        print(f"  Class columns: {list(prob_df.columns)}")

        # Check for NaN values in probability data
        nan_counts = prob_df.isna().sum()
        if nan_counts.any():
            print(f"  NaN counts per class: {dict(nan_counts)}")
            print(f"  Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}")

            # Drop only rows where ALL probabilities are NaN
            all_nan_rows = prob_df.isna().all(axis=1)
            if all_nan_rows.any():
                print(f"  Dropping {all_nan_rows.sum()} samples with all NaN probabilities")
                prob_df = prob_df[~all_nan_rows]

            remaining_nans = prob_df.isna().sum().sum()
            if remaining_nans > 0:
                print(f"  Warning: {remaining_nans} individual NaN values remain - filling with 0")
                prob_df = prob_df.fillna(0)

            if prob_df.empty:
                print(f"  Error: No valid probability data remaining for '{target_value}' - skipping")
                continue

        # 2. Get true labels
        true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id')

        # 3. Align indices - only keep samples that exist in both datasets
        common_indices = prob_df.index.intersection(true_labels_df.index)
        if len(common_indices) == 0:
            print(f"  Error: No common sample_ids between probability and true label data for '{target_value}' - skipping")
            continue

        print(f"  Found {len(common_indices)} samples with both probability and true label data")

        # Filter both datasets to common indices
        prob_df_aligned = prob_df.loc[common_indices]
        y_true = true_labels_df.loc[common_indices]['known_label']

        # 4. Final check for NaN values
        if y_true.isna().any():
            print(f"  Error: True labels still contain NaN after alignment for '{target_value}' - skipping")
            continue

        if prob_df_aligned.isna().any().any():
            print(f"  Error: Probability data still contains NaN after alignment for '{target_value}' - skipping")
            continue

        # 5. Convert categorical labels to integer labels
        # Create a mapping from class names to integers
        class_names = list(prob_df_aligned.columns)
        class_to_int = {class_name: i for i, class_name in enumerate(class_names)}
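        # e.g. {'classA': 0, 'classB': 1} for a two-class problem (names illustrative)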

        print(f"  Class mapping: {class_to_int}")

        # Convert true labels to integers
        y_true_np = y_true.map(class_to_int).to_numpy()
        y_probs_np = prob_df_aligned.to_numpy()

        print(f"  Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}")
        print(f"  Unique true labels (integers): {set(y_true_np)}")
        print(f"  Class labels (columns): {class_names}")
        print(f"  Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}")

        # Check for any unmapped labels (will be NaN)
        if pd.isna(y_true_np).any():
            print("  Error: Some true labels could not be mapped to class columns")
            unmapped_labels = set(y_true[y_true.map(class_to_int).isna()])
            print(f"  Unmapped labels: {unmapped_labels}")
            print(f"  Available classes: {class_names}")
            continue

        try:
            print(f"  Generating precision-recall curve for '{target_value}'...")
            fig = plot_pr_curves(y_true_np, y_probs_np)

            # Create output filename with target value
            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f"  Saving precision-recall curve to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

        except Exception as e:
            print(f"  Error generating precision-recall curve for '{target_value}': {str(e)}")
            print(f"  Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}")
            print(f"  Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}")
            continue

    print("Precision-recall curves generated successfully!")


def generate_roc_curves(labels, args, output_dir, output_name_base):
    """Generate ROC curves"""
    print("Generating ROC curves...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        # Check if this is a regression problem (no class probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f"  Class labels found: {list(prob_columns)}")
        print(f"  Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
            print("  Detected regression problem - ROC curves not applicable")
            print(f"  Skipping '{target_value}' (use regression evaluation metrics instead)")
            continue

        # Debug: Check data quality
        total_rows = len(target_labels)
        missing_labels = target_labels['known_label'].isna().sum()
        missing_probs = target_labels['probability'].isna().sum()
        unique_samples = target_labels['sample_id'].nunique()
        unique_classes = target_labels['class_label'].nunique()

        print(f"  Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
        print(f"  Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")

        if missing_labels > 0:
            print(f"  Warning: Found {missing_labels} missing known_label values")
            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
            print(f"  Sample IDs with missing known_label: {list(missing_samples)}")

            # Remove rows with missing known_label
            target_labels = target_labels.dropna(subset=['known_label'])
            if target_labels.empty:
                print(f"  Error: No valid known_label data remaining for '{target_value}' - skipping")
                continue

        # 1. Pivot to wide format
        prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability')

        print(f"  After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes")
        print(f"  Class columns: {list(prob_df.columns)}")

        # Check for NaN values in probability data
        nan_counts = prob_df.isna().sum()
        if nan_counts.any():
            print(f"  NaN counts per class: {dict(nan_counts)}")
            print(f"  Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}")

            # Drop only rows where ALL probabilities are NaN
            all_nan_rows = prob_df.isna().all(axis=1)
            if all_nan_rows.any():
                print(f"  Dropping {all_nan_rows.sum()} samples with all NaN probabilities")
                prob_df = prob_df[~all_nan_rows]

            remaining_nans = prob_df.isna().sum().sum()
            if remaining_nans > 0:
                print(f"  Warning: {remaining_nans} individual NaN values remain - filling with 0")
                prob_df = prob_df.fillna(0)

            if prob_df.empty:
                print(f"  Error: No valid probability data remaining for '{target_value}' - skipping")
                continue

        # 2. Get true labels
        true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id')

        # 3. Align indices - only keep samples that exist in both datasets
        common_indices = prob_df.index.intersection(true_labels_df.index)
        if len(common_indices) == 0:
            print(f"  Error: No common sample_ids between probability and true label data for '{target_value}' - skipping")
            continue

        print(f"  Found {len(common_indices)} samples with both probability and true label data")

        # Filter both datasets to common indices
        prob_df_aligned = prob_df.loc[common_indices]
        y_true = true_labels_df.loc[common_indices]['known_label']

        # 4. Final check for NaN values
        if y_true.isna().any():
            print(f"  Error: True labels still contain NaN after alignment for '{target_value}' - skipping")
            continue

        if prob_df_aligned.isna().any().any():
            print(f"  Error: Probability data still contains NaN after alignment for '{target_value}' - skipping")
            continue

        # 5. Convert categorical labels to integer labels
        # Create a mapping from class names to integers
        class_names = list(prob_df_aligned.columns)
        class_to_int = {class_name: i for i, class_name in enumerate(class_names)}
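        # e.g. {'classA': 0, 'classB': 1} for a two-class problem (names illustrative)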

        print(f"  Class mapping: {class_to_int}")

        # Convert true labels to integers
        y_true_np = y_true.map(class_to_int).to_numpy()
        y_probs_np = prob_df_aligned.to_numpy()

        print(f"  Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}")
        print(f"  Unique true labels (integers): {set(y_true_np)}")
        print(f"  Class labels (columns): {class_names}")
        print(f"  Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}")

        # Check for any unmapped labels (will be NaN)
        if pd.isna(y_true_np).any():
            print("  Error: Some true labels could not be mapped to class columns")
            unmapped_labels = set(y_true[y_true.map(class_to_int).isna()])
            print(f"  Unmapped labels: {unmapped_labels}")
            print(f"  Available classes: {class_names}")
            continue

        try:
            print(f"  Generating ROC curve for '{target_value}'...")
            fig = plot_roc_curves(y_true_np, y_probs_np)

            # Create output filename with target value
            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f"  Saving ROC curve to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

        except Exception as e:
            print(f"  Error generating ROC curve for '{target_value}': {str(e)}")
            print(f"  Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}")
            print(f"  Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}")
            continue

    print("ROC curves generated successfully!")


def generate_box_plots(labels, args, output_dir, output_name_base):
    """Generate box plots for model predictions"""

    print("Generating box plots...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        if target_labels.empty:
            print(f"  Warning: No data found for target value '{target_value}' - skipping")
            continue

        # Check if this is a classification problem (has probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f"  Class labels found: {list(prob_columns)}")
        print(f"  Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
            print("  Detected regression problem - precision-recall curves not applicable")
            print(f"  Skipping '{target_value}' (use regression evaluation metrics instead)")
            continue

        # Debug: Check data quality
        total_rows = len(target_labels)
        missing_labels = target_labels['known_label'].isna().sum()
        missing_probs = target_labels['probability'].isna().sum()
        unique_samples = target_labels['sample_id'].nunique()
        unique_classes = target_labels['class_label'].nunique()

        print(f"  Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
        print(f"  Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")

        if missing_labels > 0:
            print(f"  Warning: Found {missing_labels} missing known_label values")
            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
            print(f"  Sample IDs with missing known_label: {list(missing_samples)}")

            # Remove rows with missing known_label
            target_labels = target_labels.dropna(subset=['known_label'])
            if target_labels.empty:
                print(f"  Error: No valid known_label data remaining for '{target_value}' - skipping")
                continue

        # Remove rows with missing data
        clean_data = target_labels.dropna(subset=['known_label', 'probability'])

        if clean_data.empty:
            print("    No valid data after cleaning - skipping")
            continue

        # Get unique classes
        classes = clean_data['class_label'].unique()

        for class_label in classes:
            print(f"    Generating box plot for class: {class_label}")

            # Filter for current class
            class_data = clean_data[clean_data['class_label'] == class_label]

            try:
                # Create the box plot
                fig = plot_boxplot(
                    categorical_x=class_data['known_label'],
                    numerical_y=class_data['probability'],
                    title_x='True Label',
                    title_y=f'Predicted Probability ({class_label})',
                )

                # Save the plot
                safe_class_name = str(class_label).replace('/', '_').replace('\\', '_').replace(' ', '_').replace(':', '_')
                safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
                output_filename = f"{output_name_base}_{safe_target_name}_{safe_class_name}.{args.format}"
                output_path = output_dir / output_filename

                print(f"      Saving box plot to: {output_path.absolute()}")
                fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')
                plt.close(fig)

            except Exception as e:
                print(f"      Error generating box plot for class '{class_label}': {str(e)}")
                continue


def main():
    """Main function to parse arguments and generate plots"""
    parser = argparse.ArgumentParser(description="Generate plots using flexynesis")

    # Required arguments
    parser.add_argument("--labels", type=str, required=False,
                        help="Path to labels file generated by flexynesis")

    # Plot type
    parser.add_argument("--plot_type", type=str, required=True,
                        choices=['dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot'],
                        help="Type of plot to generate: 'dimred' for dimensionality reduction, 'kaplan_meier' for survival analysis, 'cox' for Cox proportional hazards analysis, 'scatter' for scatter plots, 'concordance_heatmap' for label concordance heatmaps, 'pr_curve' for precision-recall curves, 'roc_curve' for ROC curves, or 'box_plot' for box plots.")

    # Arguments for dimensionality reduction
    parser.add_argument("--embeddings", type=str,
                        help="Path to input data embeddings file (CSV or tabular format). Required for dimred plots.")
    parser.add_argument("--method", type=str, default='pca', choices=['pca', 'umap'],
                        help="Transformation method ('pca' or 'umap'). Default is 'pca'. Used for dimred plots.")
    parser.add_argument("--target_variables", type=str, required=False,
                        help="Comma-separated list of target variables to plot.")

    # Arguments for Kaplan-Meier
    parser.add_argument("--survival_data", type=str,
                        help="Path to survival data file with columns: duration and event. Required for kaplan_meier plots.")
    parser.add_argument("--surv_time_var", type=str, required=False,
                        help="Column name for survival time")
    parser.add_argument("--surv_event_var", type=str, required=False,
                        help="Column name for survival event")
    parser.add_argument("--event_value", type=str, required=False,
                        help="Value in event column that represents an event (e.g., 'DECEASED')")

    # Arguments for Cox analysis
    parser.add_argument("--model", type=str,
                        help="Path to trained flexynesis model (pickle file). Required for cox plots.")
    parser.add_argument("--clinical_train", type=str,
                        help="Path to training dataset (pickle file). Required for cox plots.")
    parser.add_argument("--clinical_test", type=str,
                        help="Path to test dataset (pickle file). Required for cox plots.")
    parser.add_argument("--omics_train", type=str, default=None,
                        help="Path to training omics dataset. Optional for cox plots.")
    parser.add_argument("--omics_test", type=str, default=None,
                        help="Path to test omics dataset. Optional for cox plots.")
    parser.add_argument("--clinical_variables", type=str,
                        help="Comma-separated list of clinical variables to include in Cox model (e.g., 'AGE,SEX,HISTOLOGICAL_DIAGNOSIS,STUDY')")
    parser.add_argument("--top_features", type=int, default=20,
                        help="Number of top important features to include in Cox model. Default is 5")
    parser.add_argument("--crossval", action='store_true',
                        help="If True, performs K-fold cross-validation and returns average C-index. Default is False")
    parser.add_argument("--n_splits", type=int, default=5,
                        help="Number of folds for cross-validation. Default is 5")
    parser.add_argument("--random_state", type=int, default=42,
                        help="Random seed for reproducibility. Default is 42")

    # Arguments for scatter plot, heatmap, PR curves, ROC curves, and box plots
    parser.add_argument("--target_value", type=str, default=None,
                        help="Target value for scatter plot.")

    # Common arguments
    parser.add_argument("--output_dir", type=str, default='output',
                        help="Output directory. Default is 'output'")
    parser.add_argument("--output_name", type=str, default=None,
                        help="Output filename base")
    parser.add_argument("--format", type=str, default='jpg', choices=['png', 'pdf', 'svg', 'jpg'],
                        help="Output format for the plot. Default is 'jpg'")
    parser.add_argument("--dpi", type=int, default=300,
                        help="DPI for the output image. Default is 300")

    args = parser.parse_args()

    try:
        # validate plot type
        if not args.plot_type:
            raise ValueError("Please specify a plot type using --plot_type")
        if args.plot_type not in ['dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot']:
            raise ValueError(f"Invalid plot type: {args.plot_type}. Must be one of: 'dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot'")

        # Validate plot type requirements
        if args.plot_type in ['dimred']:
            if not args.embeddings:
                raise ValueError("--embeddings is required when plot_type is 'dimred'")
            if not os.path.isfile(args.embeddings):
                raise FileNotFoundError(f"embeddings file not found: {args.embeddings}")
            if not args.labels:
                raise ValueError("--labels is required for dimensionality reduction plots")
            if not args.method:
                raise ValueError("--method is required for dimensionality reduction plots")
            if not args.target_variables:
                raise ValueError("--target_variables is required for dimensionality reduction plots")

        if args.plot_type in ['kaplan_meier']:
            if not args.survival_data:
                raise ValueError("--survival_data is required when plot_type is 'kaplan_meier'")
            if not os.path.isfile(args.survival_data):
                raise FileNotFoundError(f"Survival data file not found: {args.survival_data}")
            if not args.labels:
                raise ValueError("--labels is required for Kaplan-Meier plots")
            if not args.surv_time_var:
                raise ValueError("--surv_time_var is required for Kaplan-Meier plots")
            if not args.surv_event_var:
                raise ValueError("--surv_event_var is required for Kaplan-Meier plots")
            if not args.event_value:
                raise ValueError("--event_value is required for Kaplan-Meier plots")

        if args.plot_type in ['cox']:
            if not args.model:
                raise ValueError("--model is required when plot_type is 'cox'")
            if not os.path.isfile(args.model):
                raise FileNotFoundError(f"Model file not found: {args.model}")
            if not args.clinical_train:
                raise ValueError("--clinical_train is required when plot_type is 'cox'")
            if not os.path.isfile(args.clinical_train):
                raise FileNotFoundError(f"Training dataset file not found: {args.clinical_train}")
            if not args.clinical_test:
                raise ValueError("--clinical_test is required when plot_type is 'cox'")
            if not os.path.isfile(args.clinical_test):
                raise FileNotFoundError(f"Test dataset file not found: {args.clinical_test}")
            if not args.omics_train:
                raise ValueError("--omics_train is required when plot_type is 'cox'")
            if not os.path.isfile(args.omics_train):
                raise FileNotFoundError(f"Training omics dataset file not found: {args.omics_train}")
            if not args.omics_test:
                raise ValueError("--omics_test is required when plot_type is 'cox'")
            if not os.path.isfile(args.omics_test):
                raise FileNotFoundError(f"Test omics dataset file not found: {args.omics_test}")
            if not args.surv_time_var:
                raise ValueError("--surv_time_var is required for Cox plots")
            if not args.surv_event_var:
                raise ValueError("--surv_event_var is required for Cox plots")
            if not args.clinical_variables:
                raise ValueError("--clinical_variables is required for Cox plots")
            if not isinstance(args.top_features, int) or args.top_features <= 0:
                raise ValueError("--top_features must be a positive integer")
            if not args.event_value:
                raise ValueError("--event_value is required for Cox plots")
            if not isinstance(args.n_splits, int) or args.n_splits <= 0:
                raise ValueError("--n_splits must be a positive integer")
            if not isinstance(args.random_state, int):
                raise ValueError("--random_state must be an integer")

        if args.plot_type in ['scatter']:
            if not args.labels:
                raise ValueError("--labels is required for scatter plots")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        if args.plot_type in ['concordance_heatmap']:
            if not args.labels:
                raise ValueError("--labels is required for concordance heatmap")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        if args.plot_type in ['pr_curve']:
            if not args.labels:
                raise ValueError("--labels is required for precision-recall curves")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        if args.plot_type in ['roc_curve']:
            if not args.labels:
                raise ValueError("--labels is required for ROC curves")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        if args.plot_type in ['box_plot']:
            if not args.labels:
                raise ValueError("--labels is required for box plots")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        # Validate other arguments
        if args.method not in ['pca', 'umap']:
            raise ValueError("Method must be 'pca' or 'umap'")

        # Create output directory
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Output directory: {output_dir.absolute()}")

        # Generate output filename base
        if args.output_name:
            output_name_base = args.output_name
        else:
            if args.plot_type == 'dimred':
                embeddings_name = Path(args.embeddings).stem
                output_name_base = f"{embeddings_name}_{args.method}"
            elif args.plot_type == 'kaplan_meier':
                survival_name = Path(args.survival_data).stem
                output_name_base = f"{survival_name}_km"
            elif args.plot_type == 'cox':
                model_name = Path(args.model).stem
                output_name_base = f"{model_name}_cox"
            elif args.plot_type == 'scatter':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_scatter"
            elif args.plot_type == 'concordance_heatmap':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_concordance"
            elif args.plot_type == 'pr_curve':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_pr_curves"
            elif args.plot_type == 'roc_curve':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_roc_curves"
            elif args.plot_type == 'box_plot':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_box_plot"

        # Generate plots based on type
        if args.plot_type in ['dimred']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)
            # Load embeddings data
            print(f"Loading embeddings from: {args.embeddings}")
            embeddings, sample_names = load_embeddings(args.embeddings)
            print(f"embeddings shape: {embeddings.shape}")

            # Match samples to embeddings
            matched_labels = match_samples_to_embeddings(sample_names, label_data)
            print(f"Successfully matched {len(matched_labels)} samples for dimensionality reduction")

            generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base)

        elif args.plot_type in ['kaplan_meier']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)
            # Load survival data
            print(f"Loading survival data from: {args.survival_data}")
            survival_data = load_survival_data(args.survival_data)
            print(f"Survival data shape: {survival_data.shape}")

            generate_km_plots(survival_data, label_data, args, output_dir, output_name_base)

        elif args.plot_type in ['cox']:
            # Load model and datasets
            print(f"Loading model from: {args.model}")
            model = load_model(args.model)
            print(f"Loading training dataset from: {args.clinical_train}")
            clinical_train = load_omics(args.clinical_train)
            print(f"Loading test dataset from: {args.clinical_test}")
            clinical_test = load_omics(args.clinical_test)
            print(f"Loading training omics dataset from: {args.omics_train}")
            omics_train = load_omics(args.omics_train)
            print(f"Loading test omics dataset from: {args.omics_test}")
            omics_test = load_omics(args.omics_test)

            generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base)

        elif args.plot_type in ['scatter']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_plot_scatter(label_data, args, output_dir, output_name_base)

        elif args.plot_type in ['concordance_heatmap']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_label_concordance_heatmap(label_data, args, output_dir, output_name_base)

        elif args.plot_type in ['pr_curve']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_pr_curves(label_data, args, output_dir, output_name_base)

        elif args.plot_type in ['roc_curve']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_roc_curves(label_data, args, output_dir, output_name_base)

        elif args.plot_type in ['box_plot']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_box_plots(label_data, args, output_dir, output_name_base)

        print("All plots generated successfully!")

    except (FileNotFoundError, ValueError, pd.errors.ParserError) as e:
        print(f"Error: {e}")
        return 1

    return 0


if __name__ == "__main__":
    exit(main())