diff flexynesis_plot.py @ 3:525c661a7fdc draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flexynesis commit b2463fb68d0ae54864d87718ee72f5e063aa4587
author bgruening
date Tue, 24 Jun 2025 05:55:40 +0000
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/flexynesis_plot.py	Tue Jun 24 05:55:40 2025 +0000
@@ -0,0 +1,1282 @@
+#!/usr/bin/env python
+"""Generate plots using flexynesis
+This script generates dimensionality reduction plots, Kaplan-Meier survival curves,
+and Cox proportional hazards models from data processed by flexynesis."""
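+
+# Example invocations (file names are hypothetical; all flags are defined in main() below):
+#   python flexynesis_plot.py --plot_type dimred --embeddings embeddings.tsv \
+#       --labels predicted_labels.tsv --target_variables SUBTYPE --method umap
+#   python flexynesis_plot.py --plot_type kaplan_meier --survival_data clinical.tsv \
+#       --labels predicted_labels.tsv --surv_time_var OS_MONTHS \
+#       --surv_event_var OS_STATUS --event_value '1:DECEASED'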
+
+import argparse
+import os
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import torch
+from flexynesis import (
+    build_cox_model,
+    get_important_features,
+    plot_dim_reduced,
+    plot_hazard_ratios,
+    plot_kaplan_meier_curves,
+    plot_pr_curves,
+    plot_roc_curves,
+    plot_scatter
+)
+from scipy.stats import kruskal, mannwhitneyu
+
+
+def load_embeddings(embeddings_path):
+    """Load embeddings from a file"""
+    try:
+        # Determine file extension
+        file_ext = Path(embeddings_path).suffix.lower()
+
+        if file_ext == '.csv':
+            df = pd.read_csv(embeddings_path, index_col=0)
+        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
+            df = pd.read_csv(embeddings_path, sep='\t', index_col=0)
+        else:
+            raise ValueError(f"Unsupported file extension: {file_ext}")
+
+        return df, df.index.tolist()
+
+    except Exception as e:
+        raise ValueError(f"Error loading embeddings from {embeddings_path}: {e}") from e
+
+
+def load_labels(labels_input):
+    """Load predicted labels from flexynesis"""
+    try:
+        # Determine file extension
+        file_ext = Path(labels_input).suffix.lower()
+
+        if file_ext == '.csv':
+            df = pd.read_csv(labels_input)
+        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
+            df = pd.read_csv(labels_input, sep='\t')
+        else:
+            raise ValueError(f"Unsupported file extension: {file_ext}")
+
+        # Check that the file has the expected flexynesis label columns
+        required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
+        if all(col in df.columns for col in required_cols):
+            return df
+        else:
+            raise ValueError(f"Labels file {labels_input} does not contain required columns: {required_cols}")
+
+    except Exception as e:
+        raise ValueError(f"Error loading labels from {labels_input}: {e}") from e
+
+
+def load_survival_data(survival_path):
+    """Load survival data from a file. First column should be sample_id"""
+    try:
+        # Determine file extension
+        file_ext = Path(survival_path).suffix.lower()
+
+        if file_ext == '.csv':
+            df = pd.read_csv(survival_path, index_col=0)
+        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
+            df = pd.read_csv(survival_path, sep='\t', index_col=0)
+        else:
+            raise ValueError(f"Unsupported file extension: {file_ext}")
+        return df
+
+    except Exception as e:
+        raise ValueError(f"Error loading survival data from {survival_path}: {e}") from e
+
+
+def load_omics(omics_path):
+    """Load omics data from a file. First column should be features"""
+    try:
+        # Determine file extension
+        file_ext = Path(omics_path).suffix.lower()
+
+        if file_ext == '.csv':
+            df = pd.read_csv(omics_path, index_col=0)
+        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
+            df = pd.read_csv(omics_path, sep='\t', index_col=0)
+        else:
+            raise ValueError(f"Unsupported file extension: {file_ext}")
+        return df
+
+    except Exception as e:
+        raise ValueError(f"Error loading omics data from {omics_path}: {e}") from e
+
+
+def load_model(model_path):
+    """Load flexynesis model from pickle file"""
+    try:
+        with open(model_path, 'rb') as f:
+            model = torch.load(f, weights_only=False)
+        return model
+    except Exception as e:
+        raise ValueError(f"Error loading model from {model_path}: {e}") from e
+
+
+def match_samples_to_embeddings(sample_names, label_data):
+    """Filter label data to match sample names in the embeddings"""
+    df_matched = label_data[label_data['sample_id'].isin(sample_names)]
+    return df_matched
+
+
+def detect_color_type(labels_series):
+    """Auto-detect whether target variables should be treated as categorical or numerical"""
+    # Remove NaN
+    clean_labels = labels_series.dropna()
+
+    if clean_labels.empty:
+        return 'categorical'  # default output if no labels
+
+    # Check if all values can be converted to numbers
+    try:
+        numeric_labels = pd.to_numeric(clean_labels, errors='coerce')
+
+        # If conversion failed -> categorical
+        if numeric_labels.isna().any():
+            return 'categorical'
+
+        # Check number of unique values
+        unique_count = len(clean_labels.unique())
+        total_count = len(clean_labels)
+
+        # If few unique values relative to total -> categorical
+        # Threshold: if unique values < 10 OR unique/total < 0.1
+        if unique_count < 10 or (unique_count / total_count) < 0.1:
+            return 'categorical'
+        else:
+            return 'numerical'
+
+    except Exception:
+        return 'categorical'
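+
+# Illustration of the heuristic above (hypothetical values): a series such as
+# pd.Series(['LumA', 'LumB', 'Basal']) fails numeric conversion and is treated as
+# categorical; pd.Series(range(100)) has many unique numeric values and is treated
+# as numerical; pd.Series([0, 1] * 50) is numeric but low-cardinality, so categorical.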
+
+
+def plot_label_concordance_heatmap(labels1, labels2, figsize=(12, 10)):
+    """
+    Plot a heatmap reflecting the concordance between two sets of labels using pandas crosstab.
+
+    Parameters:
+    - labels1: The first set of labels.
+    - labels2: The second set of labels.
+    """
+    # Compute the cross-tabulation
+    ct = pd.crosstab(pd.Series(labels1, name='Labels Set 1'), pd.Series(labels2, name='Labels Set 2'))
+    # Normalize the cross-tabulation row-wise so each row sums to 1
+    ct_normalized = ct.div(ct.sum(axis=1), axis=0)
+
+    # Plot the heatmap
+    plt.figure(figsize=figsize)
+    sns.heatmap(ct_normalized, annot=True, cmap='viridis', linewidths=.5)
+    plt.title('Concordance between label groups')
+
+    return plt.gcf()
+
+
+def plot_boxplot(categorical_x, numerical_y, title_x='Categories', title_y='Values', figsize=(10, 6), jittersize=4):
+    """
+    Create a boxplot to visualize the distribution of predicted probabilities across categories.
+    The x axis holds the true labels; the y axis holds the predicted probabilities for a specific class.
+    """
+    df = pd.DataFrame({title_x: categorical_x, title_y: numerical_y})
+
+    # Compute p-value
+    groups = df[title_x].unique()
+    if len(groups) == 2:
+        group1 = df[df[title_x] == groups[0]][title_y]
+        group2 = df[df[title_x] == groups[1]][title_y]
+        stat, p = mannwhitneyu(group1, group2, alternative='two-sided')
+        test_name = "Mann-Whitney U"
+    else:
+        group_data = [df[df[title_x] == group][title_y] for group in groups]
+        stat, p = kruskal(*group_data)
+        test_name = "Kruskal-Wallis"
+
+    # Create a boxplot with jittered points
+    plt.figure(figsize=figsize)
+    sns.boxplot(x=title_x, y=title_y, hue=title_x, data=df, palette='Set2', legend=False, fill=False)
+    sns.stripplot(x=title_x, y=title_y, data=df, color='black', size=jittersize, jitter=True, dodge=True, alpha=0.4)
+
+    # Labels and p-value annotation
+    plt.xlabel(title_x)
+    plt.ylabel(title_y)
+    plt.text(
+        x=-0.4,
+        y=plt.ylim()[1],
+        s=f'{test_name} p = {p:.3e}',
+        verticalalignment='top',
+        horizontalalignment='left',
+        fontsize=12,
+        bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='gray')
+    )
+
+    plt.tight_layout()
+    return plt.gcf()
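+
+# Minimal usage sketch (hypothetical values): plot_boxplot(['A', 'A', 'B', 'B'],
+# [0.9, 0.8, 0.2, 0.4]) draws two boxes with jittered points and annotates a
+# Mann-Whitney U p-value; with three or more groups a Kruskal-Wallis test is used.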
+
+
+def generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base):
+    """Generate dimensionality reduction plots"""
+
+    # Parse target variables
+    target_vars = [var.strip() for var in args.target_variables.split(',')]
+
+    print(f"Generating {args.method.upper()} plots for {len(target_vars)} target variable(s): {', '.join(target_vars)}")
+
+    # Check variables
+    available_vars = matched_labels['variable'].unique()
+    missing_vars = [var for var in target_vars if var not in available_vars]
+
+    if missing_vars:
+        print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}")
+        print(f"Available variables: {', '.join(available_vars)}")
+
+    # Filter to only process available variables
+    valid_vars = [var for var in target_vars if var in available_vars]
+
+    if not valid_vars:
+        raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}")
+
+    # Generate plots for each valid target variable
+    for var in valid_vars:
+        print(f"\nPlotting variable: {var}")
+
+        # Filter matched labels for current variable
+        var_labels = matched_labels[matched_labels['variable'] == var].copy()
+        var_labels = var_labels.drop_duplicates(subset='sample_id')
+
+        if var_labels.empty:
+            print(f"Warning: No data found for variable '{var}', skipping...")
+            continue
+
+        # Auto-detect color type
+        known_color_type = detect_color_type(var_labels['known_label'])
+        predicted_color_type = detect_color_type(var_labels['predicted_label'])
+
+        print(f"  Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}")
+
+        try:
+            # Plot 1: Known labels
+            print(f"  Creating known labels plot for {var}...")
+            fig_known = plot_dim_reduced(
+                matrix=embeddings,
+                labels=var_labels['known_label'],
+                method=args.method,
+                color_type=known_color_type
+            )
+
+            output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}"
+            print(f"  Saving known labels plot to: {output_path_known.name}")
+            fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')
+
+            # Plot 2: Predicted labels
+            print(f"  Creating predicted labels plot for {var}...")
+            fig_predicted = plot_dim_reduced(
+                matrix=embeddings,
+                labels=var_labels['predicted_label'],
+                method=args.method,
+                color_type=predicted_color_type
+            )
+
+            output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}"
+            print(f"  Saving predicted labels plot to: {output_path_predicted.name}")
+            fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight')
+
+            print(f"  ✓ Successfully created plots for variable '{var}'")
+
+        except Exception as e:
+            print(f"  ✗ Error creating plots for variable '{var}': {e}")
+            continue
+
+    print(f"\nDimensionality reduction plots completed for {len(valid_vars)} variable(s)!")
+
+
+def generate_km_plots(survival_data, label_data, args, output_dir, output_name_base):
+    """Generate Kaplan-Meier plots"""
+    print("Generating Kaplan-Meier curves of risk subtypes...")
+
+    # Reset index and rename the index column to sample_id
+    survival_data = survival_data.reset_index()
+    if survival_data.columns[0] != 'sample_id':
+        survival_data = survival_data.rename(columns={survival_data.columns[0]: 'sample_id'})
+
+    # Convert survival event column to binary (0/1) based on event_value
+    # Check if the event column exists
+    if args.surv_event_var not in survival_data.columns:
+        raise ValueError(f"Column '{args.surv_event_var}' not found in survival data")
+
+    # Convert to string for comparison to handle mixed types
+    survival_data[args.surv_event_var] = survival_data[args.surv_event_var].astype(str)
+    event_value_str = str(args.event_value)
+
+    # Create binary event column (1 if matches event_value, 0 otherwise)
+    survival_data[f'{args.surv_event_var}_binary'] = (
+        survival_data[args.surv_event_var] == event_value_str
+    ).astype(int)
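+    # e.g. with --event_value '1:DECEASED' (an encoding used in some clinical tables),
+    # matching rows become 1 and all others become 0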
+
+    # Keep label rows for the survival event variable whose class_label matches the event value
+    label_data['class_label'] = label_data['class_label'].astype(str)
+
+    label_data = label_data[(label_data['variable'] == args.surv_event_var) & (label_data['class_label'] == event_value_str)]
+
+    # check survival data
+    for col in [args.surv_time_var, args.surv_event_var]:
+        if col not in survival_data.columns:
+            raise ValueError(f"Column '{col}' not found in survival data")
+
+    # Merge survival data with labels
+    df_deceased = pd.merge(survival_data, label_data, on='sample_id', how='inner')
+
+    if df_deceased.empty:
+        raise ValueError("No matching samples found after merging survival and label data.")
+
+    # Get risk scores
+    risk_scores = df_deceased['probability'].values
+
+    # Compute groups (e.g., median split)
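+    # np.digitize against the single 0.5 quantile assigns bin 0 to scores below the
+    # median and bin 1 to scores at or above it, giving the two risk groups below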
+    quantiles = np.quantile(risk_scores, [0.5])
+    groups = np.digitize(risk_scores, quantiles)
+    group_labels = ['low_risk' if g == 0 else 'high_risk' for g in groups]
+
+    fig_known = plot_kaplan_meier_curves(
+        durations=df_deceased[args.surv_time_var],
+        events=df_deceased[f'{args.surv_event_var}_binary'],
+        categorical_variable=group_labels
+    )
+
+    output_path_known = output_dir / f"{output_name_base}_km_risk_subtypes.{args.format}"
+    print(f"Saving Kaplan-Meier plot to: {output_path_known.absolute()}")
+    fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')
+
+    print("Kaplan-Meier plot saved successfully!")
+
+
+def generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base):
+    """Generate Cox proportional hazards plots"""
+    print("Generating Cox proportional hazards analysis...")
+
+    # Parse clinical variables
+    clinical_vars = [var.strip() for var in args.clinical_variables.split(',')]
+
+    # Validate that survival variables are included
+    required_vars = [args.surv_time_var, args.surv_event_var]
+    for var in required_vars:
+        if var not in clinical_vars:
+            clinical_vars.append(var)
+
+    print(f"Using clinical variables: {', '.join(clinical_vars)}")
+
+    # filter datasets for clinical variables
+    if all(var in clinical_train.columns and var in clinical_test.columns for var in clinical_vars):
+        df_clin_train = clinical_train[clinical_vars]
+        df_clin_test = clinical_test[clinical_vars]
+        # Drop rows with NaN in clinical variables
+        df_clin_train = df_clin_train.dropna(subset=clinical_vars)
+        df_clin_test = df_clin_test.dropna(subset=clinical_vars)
+    else:
+        raise ValueError(f"Not all clinical variables found in datasets. Available in train dataset: {clinical_train.columns.tolist()}, Available in test dataset: {clinical_test.columns.tolist()}")
+
+    # Combine
+    df_clin = pd.concat([df_clin_train, df_clin_test], axis=0)
+
+    # Get top survival markers
+    print(f"Extracting top {args.top_features} important features for {args.surv_event_var}...")
+    try:
+        imp = get_important_features(model,
+                                     var=args.surv_event_var,
+                                     top=args.top_features
+                                     )['name'].unique().tolist()
+        print(f"Top features: {', '.join(imp)}")
+    except Exception as e:
+        raise ValueError(f"Error getting important features: {e}")
+
+    # Extract feature data from omics datasets
+    try:
+        omics_test = omics_test.loc[omics_test.index.isin(imp)]
+        omics_train = omics_train.loc[omics_train.index.isin(imp)]
+        # Drop features (rows) containing NaN in any sample
+        omics_test = omics_test.dropna()
+        omics_train = omics_train.dropna()
+
+        df_imp = pd.concat([omics_train, omics_test], axis=1)
+        df_imp = df_imp.T  # Transpose to have samples as rows
+
+        print(f"Feature data shape: {df_imp.shape}")
+    except Exception as e:
+        raise ValueError(f"Error extracting feature subset: {e}")
+
+    # Combine markers with clinical variables
+    df = pd.merge(df_imp, df_clin, left_index=True, right_index=True)
+    print(f"Combined data shape: {df.shape}")
+
+    # Remove samples without survival endpoints
+    initial_samples = len(df)
+    df = df[df[args.surv_event_var].notna()]
+    final_samples = len(df)
+    print(f"Removed {initial_samples - final_samples} samples without survival data")
+
+    if df.empty:
+        raise ValueError("No samples remain after filtering for survival data")
+
+    # Convert survival event column to binary (0/1) based on event_value
+    # Convert to string for comparison to handle mixed types
+    df[args.surv_event_var] = df[args.surv_event_var].astype(str)
+    event_value_str = str(args.event_value)
+
+    df[args.surv_event_var] = (
+        df[args.surv_event_var] == event_value_str
+    ).astype(int)
+
+    # Build Cox model
+    print(f"Building Cox model with time variable: {args.surv_time_var}, event variable: {args.surv_event_var}")
+    try:
+        coxm = build_cox_model(df,
+                               duration_col=args.surv_time_var,
+                               event_col=args.surv_event_var,
+                               crossval=args.crossval,
+                               n_splits=args.n_splits,
+                               random_state=args.random_state)
+        print("Cox model built successfully")
+    except Exception as e:
+        raise ValueError(f"Error building Cox model: {e}")
+
+    # Generate hazard ratios plot
+    try:
+        print("Generating hazard ratios plot...")
+        fig = plot_hazard_ratios(coxm)
+
+        output_path = output_dir / f"{output_name_base}_hazard_ratios.{args.format}"
+        print(f"Saving hazard ratios plot to: {output_path.absolute()}")
+        fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
+
+        print("Cox proportional hazards analysis completed successfully!")
+
+    except Exception as e:
+        raise ValueError(f"Error generating hazard ratios plot: {e}")
+
+
+def generate_plot_scatter(labels, args, output_dir, output_name_base):
+    """Generate scatter plot of known vs predicted labels"""
+    print("Generating scatter plots of known vs predicted labels...")
+
+    # Parse target values from comma-separated string
+    if args.target_value:
+        target_values = [val.strip() for val in args.target_value.split(',')]
+    else:
+        # If no target values specified, use all unique variables
+        target_values = labels['variable'].unique().tolist()
+
+    print(f"Processing target values: {target_values}")
+
+    successful_plots = 0
+    skipped_plots = 0
+
+    for target_value in target_values:
+        print(f"\nProcessing target value: '{target_value}'")
+
+        # Filter labels for the current target value
+        target_labels = labels[labels['variable'] == target_value]
+
+        if target_labels.empty:
+            print(f"  Warning: No data found for target value '{target_value}' - skipping")
+            skipped_plots += 1
+            continue
+
+        # Check if labels are numeric and convert
+        true_values = pd.to_numeric(target_labels['known_label'], errors='coerce')
+        predicted_values = pd.to_numeric(target_labels['predicted_label'], errors='coerce')
+
+        if true_values.isna().all() or predicted_values.isna().all():
+            print(f"No valid numeric values found for known or predicted labels in '{target_value}'")
+            skipped_plots += 1
+            continue
+
+        try:
+            print(f"  Generating scatter plot for '{target_value}'...")
+            fig = plot_scatter(true_values, predicted_values)
+
+            # Create output filename with target value
+            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
+            if len(target_values) > 1:
+                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
+            else:
+                output_filename = f"{output_name_base}.{args.format}"
+
+            output_path = output_dir / output_filename
+            print(f"  Saving scatter plot to: {output_path.absolute()}")
+            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
+
+            successful_plots += 1
+            print(f"  Scatter plot for '{target_value}' generated successfully!")
+
+        except Exception as e:
+            print(f"  Error generating plot for '{target_value}': {str(e)}")
+            skipped_plots += 1
+
+    # Summary
+    print("  Summary:")
+    print(f"  Successfully generated: {successful_plots} plots")
+    print(f"  Skipped: {skipped_plots} plots")
+
+    if successful_plots == 0:
+        raise ValueError("No scatter plots could be generated. Check your data and target values.")
+
+    print("Scatter plot generation completed!")
+
+
+def generate_label_concordance_heatmap(labels, args, output_dir, output_name_base):
+    """Generate label concordance heatmap"""
+    print("Generating label concordance heatmaps...")
+
+    # Parse target values from comma-separated string
+    if args.target_value:
+        target_values = [val.strip() for val in args.target_value.split(',')]
+    else:
+        # If no target values specified, use all unique variables
+        target_values = labels['variable'].unique().tolist()
+
+    print(f"Processing target values: {target_values}")
+
+    for target_value in target_values:
+        print(f"\nProcessing target value: '{target_value}'")
+
+        # Filter labels for the current target value
+        target_labels = labels[labels['variable'] == target_value]
+
+        if target_labels.empty:
+            print(f"  Warning: No data found for target value '{target_value}' - skipping")
+            continue
+
+        true_values = target_labels['known_label'].tolist()
+        predicted_values = target_labels['predicted_label'].tolist()
+
+        try:
+            print(f"  Generating heatmap for '{target_value}'...")
+            fig = plot_label_concordance_heatmap(true_values, predicted_values)
+
+            # Create output filename with target value
+            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
+            if len(target_values) > 1:
+                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
+            else:
+                output_filename = f"{output_name_base}.{args.format}"
+
+            output_path = output_dir / output_filename
+            print(f"  Saving heatmap to: {output_path.absolute()}")
+            fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')
+            plt.close(fig)
+
+        except Exception as e:
+            print(f"  Error generating heatmap for '{target_value}': {str(e)}")
+            continue
+
+    print("Label concordance heatmap generated successfully!")
+
+
+def generate_pr_curves(labels, args, output_dir, output_name_base):
+    """Generate precision-recall curves"""
+    print("Generating precision-recall curves...")
+
+    # Parse target values from comma-separated string
+    if args.target_value:
+        target_values = [val.strip() for val in args.target_value.split(',')]
+    else:
+        # If no target values specified, use all unique variables
+        target_values = labels['variable'].unique().tolist()
+
+    print(f"Processing target values: {target_values}")
+
+    for target_value in target_values:
+        print(f"\nProcessing target value: '{target_value}'")
+
+        # Filter labels for the current target value
+        target_labels = labels[labels['variable'] == target_value]
+
+        # Check if this is a regression problem (no class probabilities)
+        prob_columns = target_labels['class_label'].unique()
+        non_na_probs = target_labels['probability'].notna().sum()
+
+        print(f"  Class labels found: {list(prob_columns)}")
+        print(f"  Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")
+
+        # If most probabilities are NaN, this is likely a regression problem
+        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
+            print("  Detected regression problem - precision-recall curves not applicable")
+            print(f"  Skipping '{target_value}' (use regression evaluation metrics instead)")
+            continue
+
+        # Debug: Check data quality
+        total_rows = len(target_labels)
+        missing_labels = target_labels['known_label'].isna().sum()
+        missing_probs = target_labels['probability'].isna().sum()
+        unique_samples = target_labels['sample_id'].nunique()
+        unique_classes = target_labels['class_label'].nunique()
+
+        print(f"  Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
+        print(f"  Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")
+
+        if missing_labels > 0:
+            print(f"  Warning: Found {missing_labels} missing known_label values")
+            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
+            print(f"  Sample IDs with missing known_label: {list(missing_samples)}")
+
+            # Remove rows with missing known_label
+            target_labels = target_labels.dropna(subset=['known_label'])
+            if target_labels.empty:
+                print(f"  Error: No valid known_label data remaining for '{target_value}' - skipping")
+                continue
+
+        # 1. Pivot to wide format
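+        # (rows = sample_id, one column per class_label, cells = predicted probability)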
+        prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability')
+
+        print(f"  After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes")
+        print(f"  Class columns: {list(prob_df.columns)}")
+
+        # Check for NaN values in probability data
+        nan_counts = prob_df.isna().sum()
+        if nan_counts.any():
+            print(f"  NaN counts per class: {dict(nan_counts)}")
+            print(f"  Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}")
+
+            # Drop only rows where ALL probabilities are NaN
+            all_nan_rows = prob_df.isna().all(axis=1)
+            if all_nan_rows.any():
+                print(f"  Dropping {all_nan_rows.sum()} samples with all NaN probabilities")
+                prob_df = prob_df[~all_nan_rows]
+
+            remaining_nans = prob_df.isna().sum().sum()
+            if remaining_nans > 0:
+                print(f"  Warning: {remaining_nans} individual NaN values remain - filling with 0")
+                prob_df = prob_df.fillna(0)
+
+            if prob_df.empty:
+                print(f"  Error: No valid probability data remaining for '{target_value}' - skipping")
+                continue
+
+        # 2. Get true labels
+        true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id')
+
+        # 3. Align indices - only keep samples that exist in both datasets
+        common_indices = prob_df.index.intersection(true_labels_df.index)
+        if len(common_indices) == 0:
+            print(f"  Error: No common sample_ids between probability and true label data for '{target_value}' - skipping")
+            continue
+
+        print(f"  Found {len(common_indices)} samples with both probability and true label data")
+
+        # Filter both datasets to common indices
+        prob_df_aligned = prob_df.loc[common_indices]
+        y_true = true_labels_df.loc[common_indices]['known_label']
+
+        # 4. Final check for NaN values
+        if y_true.isna().any():
+            print(f"  Error: True labels still contain NaN after alignment for '{target_value}' - skipping")
+            continue
+
+        if prob_df_aligned.isna().any().any():
+            print(f"  Error: Probability data still contains NaN after alignment for '{target_value}' - skipping")
+            continue
+
+        # 5. Convert categorical labels to integer labels
+        # Create a mapping from class names to integers
+        class_names = list(prob_df_aligned.columns)
+        class_to_int = {class_name: i for i, class_name in enumerate(class_names)}
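+        # e.g. columns ['0:LIVING', '1:DECEASED'] (hypothetical class labels) map to
+        # integers {0, 1} in the same order as the probability matrix columns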
+
+        print(f"  Class mapping: {class_to_int}")
+
+        # Convert true labels to integers
+        y_true_np = y_true.map(class_to_int).to_numpy()
+        y_probs_np = prob_df_aligned.to_numpy()
+
+        print(f"  Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}")
+        print(f"  Unique true labels (integers): {set(y_true_np)}")
+        print(f"  Class labels (columns): {class_names}")
+        print(f"  Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}")
+
+        # Check for any unmapped labels (will be NaN)
+        if pd.isna(y_true_np).any():
+            print("  Error: Some true labels could not be mapped to class columns")
+            unmapped_labels = set(y_true[y_true.map(class_to_int).isna()])
+            print(f"  Unmapped labels: {unmapped_labels}")
+            print(f"  Available classes: {class_names}")
+            continue
+
+        try:
+            print(f"  Generating precision-recall curve for '{target_value}'...")
+            fig = plot_pr_curves(y_true_np, y_probs_np)
+
+            # Create output filename with target value
+            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
+            if len(target_values) > 1:
+                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
+            else:
+                output_filename = f"{output_name_base}.{args.format}"
+
+            output_path = output_dir / output_filename
+            print(f"  Saving precision-recall curve to: {output_path.absolute()}")
+            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
+
+        except Exception as e:
+            print(f"  Error generating precision-recall curve for '{target_value}': {str(e)}")
+            print(f"  Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}")
+            print(f"  Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}")
+            continue
+
+    print("Precision-recall curves generated successfully!")
+
+
+def generate_roc_curves(labels, args, output_dir, output_name_base):
+    """Generate ROC curves"""
+    print("Generating ROC curves...")
+
+    # Parse target values from comma-separated string
+    if args.target_value:
+        target_values = [val.strip() for val in args.target_value.split(',')]
+    else:
+        # If no target values specified, use all unique variables
+        target_values = labels['variable'].unique().tolist()
+
+    print(f"Processing target values: {target_values}")
+
+    for target_value in target_values:
+        print(f"\nProcessing target value: '{target_value}'")
+
+        # Filter labels for the current target value
+        target_labels = labels[labels['variable'] == target_value]
+
+        # Check if this is a regression problem (no class probabilities)
+        prob_columns = target_labels['class_label'].unique()
+        non_na_probs = target_labels['probability'].notna().sum()
+
+        print(f"  Class labels found: {list(prob_columns)}")
+        print(f"  Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")
+
+        # If most probabilities are NaN, this is likely a regression problem
+        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
+            print("  Detected regression problem - ROC curves not applicable")
+            print(f"  Skipping '{target_value}' (use regression evaluation metrics instead)")
+            continue
+
+        # Debug: Check data quality
+        total_rows = len(target_labels)
+        missing_labels = target_labels['known_label'].isna().sum()
+        missing_probs = target_labels['probability'].isna().sum()
+        unique_samples = target_labels['sample_id'].nunique()
+        unique_classes = target_labels['class_label'].nunique()
+
+        print(f"  Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
+        print(f"  Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")
+
+        if missing_labels > 0:
+            print(f"  Warning: Found {missing_labels} missing known_label values")
+            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
+            print(f"  Sample IDs with missing known_label: {list(missing_samples)}")
+
+            # Remove rows with missing known_label
+            target_labels = target_labels.dropna(subset=['known_label'])
+            if target_labels.empty:
+                print(f"  Error: No valid known_label data remaining for '{target_value}' - skipping")
+                continue
+
+        # 1. Pivot to wide format
+        prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability')
+
+        print(f"  After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes")
+        print(f"  Class columns: {list(prob_df.columns)}")
+
+        # Check for NaN values in probability data
+        nan_counts = prob_df.isna().sum()
+        if nan_counts.any():
+            print(f"  NaN counts per class: {dict(nan_counts)}")
+            print(f"  Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}")
+
+            # Drop only rows where ALL probabilities are NaN
+            all_nan_rows = prob_df.isna().all(axis=1)
+            if all_nan_rows.any():
+                print(f"  Dropping {all_nan_rows.sum()} samples with all NaN probabilities")
+                prob_df = prob_df[~all_nan_rows]
+
+            remaining_nans = prob_df.isna().sum().sum()
+            if remaining_nans > 0:
+                print(f"  Warning: {remaining_nans} individual NaN values remain - filling with 0")
+                prob_df = prob_df.fillna(0)
+
+            if prob_df.empty:
+                print(f"  Error: No valid probability data remaining for '{target_value}' - skipping")
+                continue
+
+        # 2. Get true labels
+        true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id')
+
+        # 3. Align indices - only keep samples that exist in both datasets
+        common_indices = prob_df.index.intersection(true_labels_df.index)
+        if len(common_indices) == 0:
+            print(f"  Error: No common sample_ids between probability and true label data for '{target_value}' - skipping")
+            continue
+
+        print(f"  Found {len(common_indices)} samples with both probability and true label data")
+
+        # Filter both datasets to common indices
+        prob_df_aligned = prob_df.loc[common_indices]
+        y_true = true_labels_df.loc[common_indices]['known_label']
+
+        # 4. Final check for NaN values
+        if y_true.isna().any():
+            print(f"  Error: True labels still contain NaN after alignment for '{target_value}' - skipping")
+            continue
+
+        if prob_df_aligned.isna().any().any():
+            print(f"  Error: Probability data still contains NaN after alignment for '{target_value}' - skipping")
+            continue
+
+        # 5. Convert categorical labels to integer labels
+        # Create a mapping from class names to integers
+        class_names = list(prob_df_aligned.columns)
+        class_to_int = {class_name: i for i, class_name in enumerate(class_names)}
+
+        print(f"  Class mapping: {class_to_int}")
+
+        # Convert true labels to integers
+        y_true_np = y_true.map(class_to_int).to_numpy()
+        y_probs_np = prob_df_aligned.to_numpy()
+
+        print(f"  Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}")
+        print(f"  Unique true labels (integers): {set(y_true_np)}")
+        print(f"  Class labels (columns): {class_names}")
+        print(f"  Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}")
+
+        # Check for any unmapped labels (will be NaN)
+        if pd.isna(y_true_np).any():
+            print("  Error: Some true labels could not be mapped to class columns")
+            unmapped_labels = set(y_true[y_true.map(class_to_int).isna()])
+            print(f"  Unmapped labels: {unmapped_labels}")
+            print(f"  Available classes: {class_names}")
+            continue
+
+        try:
+            print(f"  Generating ROC curve for '{target_value}'...")
+            fig = plot_roc_curves(y_true_np, y_probs_np)
+
+            # Create output filename with target value
+            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
+            if len(target_values) > 1:
+                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
+            else:
+                output_filename = f"{output_name_base}.{args.format}"
+
+            output_path = output_dir / output_filename
+            print(f"  Saving ROC curve to: {output_path.absolute()}")
+            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
+
+        except Exception as e:
+            print(f"  Error generating ROC curve for '{target_value}': {str(e)}")
+            print(f"  Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}")
+            print(f"  Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}")
+            continue
+
+    print("ROC curves generated successfully!")
+
+
+def generate_box_plots(labels, args, output_dir, output_name_base):
+    """Generate box plots for model predictions"""
+
+    print("Generating box plots...")
+
+    # Parse target values from comma-separated string
+    if args.target_value:
+        target_values = [val.strip() for val in args.target_value.split(',')]
+    else:
+        # If no target values specified, use all unique variables
+        target_values = labels['variable'].unique().tolist()
+
+    print(f"Processing target values: {target_values}")
+
+    for target_value in target_values:
+        print(f"\nProcessing target value: '{target_value}'")
+
+        # Filter labels for the current target value
+        target_labels = labels[labels['variable'] == target_value]
+
+        if target_labels.empty:
+            print(f"  Warning: No data found for target value '{target_value}' - skipping")
+            continue
+
+        # Check if this is a classification problem (has probabilities)
+        prob_columns = target_labels['class_label'].unique()
+        non_na_probs = target_labels['probability'].notna().sum()
+
+        print(f"  Class labels found: {list(prob_columns)}")
+        print(f"  Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")
+
+        # If most probabilities are NaN, this is likely a regression problem
+        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
+            print("  Detected regression problem - precision-recall curves not applicable")
+            print(f"  Skipping '{target_value}' (use regression evaluation metrics instead)")
+            continue
+
+        # Debug: Check data quality
+        total_rows = len(target_labels)
+        missing_labels = target_labels['known_label'].isna().sum()
+        missing_probs = target_labels['probability'].isna().sum()
+        unique_samples = target_labels['sample_id'].nunique()
+        unique_classes = target_labels['class_label'].nunique()
+
+        print(f"  Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
+        print(f"  Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")
+
+        if missing_labels > 0:
+            print(f"  Warning: Found {missing_labels} missing known_label values")
+            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
+            print(f"  Sample IDs with missing known_label: {list(missing_samples)}")
+
+            # Remove rows with missing known_label
+            target_labels = target_labels.dropna(subset=['known_label'])
+            if target_labels.empty:
+                print(f"  Error: No valid known_label data remaining for '{target_value}' - skipping")
+                continue
+
+        # Remove rows with missing data
+        clean_data = target_labels.dropna(subset=['known_label', 'probability'])
+
+        if clean_data.empty:
+            print("    No valid data after cleaning - skipping")
+            continue
+
+        # Get unique classes
+        classes = clean_data['class_label'].unique()
+
+        for class_label in classes:
+            print(f"    Generating box plot for class: {class_label}")
+
+            # Filter for current class
+            class_data = clean_data[clean_data['class_label'] == class_label]
+
+            try:
+                # Create the box plot
+                fig = plot_boxplot(
+                    categorical_x=class_data['known_label'],
+                    numerical_y=class_data['probability'],
+                    title_x='True Label',
+                    title_y=f'Predicted Probability ({class_label})',
+                )
+
+                # Save the plot
+                safe_class_name = str(class_label).replace('/', '_').replace('\\', '_').replace(' ', '_').replace(':', '_')
+                safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
+                output_filename = f"{output_name_base}_{safe_target_name}_{safe_class_name}.{args.format}"
+                output_path = output_dir / output_filename
+
+                print(f"      Saving box plot to: {output_path.absolute()}")
+                fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')
+                plt.close(fig)
+
+            except Exception as e:
+                print(f"      Error generating box plot for class '{class_label}': {str(e)}")
+                continue
+
+
+def main():
+    """Main function to parse arguments and generate plots"""
+    parser = argparse.ArgumentParser(description="Generate plots using flexynesis")
+
+    # Required arguments
+    parser.add_argument("--labels", type=str, required=False,
+                        help="Path to labels file generated by flexynesis")
+
+    # Plot type
+    parser.add_argument("--plot_type", type=str, required=True,
+                        choices=['dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot'],
+                        help="Type of plot to generate: 'dimred' for dimensionality reduction, 'kaplan_meier' for survival analysis, 'cox' for Cox proportional hazards analysis, 'scatter' for scatter plots, 'concordance_heatmap' for label concordance heatmaps, 'pr_curve' for precision-recall curves, 'roc_curve' for ROC curves, or 'box_plot' for box plots.")
+
+    # Arguments for dimensionality reduction
+    parser.add_argument("--embeddings", type=str,
+                        help="Path to input data embeddings file (CSV or tabular format). Required for dimred plots.")
+    parser.add_argument("--method", type=str, default='pca', choices=['pca', 'umap'],
+                        help="Transformation method ('pca' or 'umap'). Default is 'pca'. Used for dimred plots.")
+    parser.add_argument("--target_variables", type=str, required=False,
+                        help="Comma-separated list of target variables to plot.")
+
+    # Arguments for Kaplan-Meier
+    parser.add_argument("--survival_data", type=str,
+                        help="Path to survival data file with columns: duration and event. Required for kaplan_meier plots.")
+    parser.add_argument("--surv_time_var", type=str, required=False,
+                        help="Column name for survival time")
+    parser.add_argument("--surv_event_var", type=str, required=False,
+                        help="Column name for survival event")
+    parser.add_argument("--event_value", type=str, required=False,
+                        help="Value in event column that represents an event (e.g., 'DECEASED')")
+
+    # Arguments for Cox analysis
+    parser.add_argument("--model", type=str,
+                        help="Path to trained flexynesis model (pickle file). Required for cox plots.")
+    parser.add_argument("--clinical_train", type=str,
+                        help="Path to training dataset (pickle file). Required for cox plots.")
+    parser.add_argument("--clinical_test", type=str,
+                        help="Path to test dataset (pickle file). Required for cox plots.")
+    parser.add_argument("--omics_train", type=str, default=None,
+                        help="Path to training omics dataset. Optional for cox plots.")
+    parser.add_argument("--omics_test", type=str, default=None,
+                        help="Path to test omics dataset. Optional for cox plots.")
+    parser.add_argument("--clinical_variables", type=str,
+                        help="Comma-separated list of clinical variables to include in Cox model (e.g., 'AGE,SEX,HISTOLOGICAL_DIAGNOSIS,STUDY')")
+    parser.add_argument("--top_features", type=int, default=20,
+                        help="Number of top important features to include in Cox model. Default is 5")
+    parser.add_argument("--crossval", action='store_true',
+                        help="If True, performs K-fold cross-validation and returns average C-index. Default is False")
+    parser.add_argument("--n_splits", type=int, default=5,
+                        help="Number of folds for cross-validation. Default is 5")
+    parser.add_argument("--random_state", type=int, default=42,
+                        help="Random seed for reproducibility. Default is 42")
+
+    # Arguments for scatter plot, heatmap, PR curves, ROC curves, and box plots
+    parser.add_argument("--target_value", type=str, default=None,
+                        help="Target value for scatter plot.")
+
+    # Common arguments
+    parser.add_argument("--output_dir", type=str, default='output',
+                        help="Output directory. Default is 'output'")
+    parser.add_argument("--output_name", type=str, default=None,
+                        help="Output filename base")
+    parser.add_argument("--format", type=str, default='jpg', choices=['png', 'pdf', 'svg', 'jpg'],
+                        help="Output format for the plot. Default is 'jpg'")
+    parser.add_argument("--dpi", type=int, default=300,
+                        help="DPI for the output image. Default is 300")
+
+    args = parser.parse_args()
+
+    try:
+        # validate plot type
+        if not args.plot_type:
+            raise ValueError("Please specify a plot type using --plot_type")
+        if args.plot_type not in ['dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot']:
+            raise ValueError(f"Invalid plot type: {args.plot_type}. Must be one of: 'dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot'")
+
+        # Validate plot type requirements
+        if args.plot_type in ['dimred']:
+            if not args.embeddings:
+                raise ValueError("--embeddings is required when plot_type is 'dimred'")
+            if not os.path.isfile(args.embeddings):
+                raise FileNotFoundError(f"embeddings file not found: {args.embeddings}")
+            if not args.labels:
+                raise ValueError("--labels is required for dimensionality reduction plots")
+            if not args.method:
+                raise ValueError("--method is required for dimensionality reduction plots")
+            if not args.target_variables:
+                raise ValueError("--target_variables is required for dimensionality reduction plots")
+
+        if args.plot_type in ['kaplan_meier']:
+            if not args.survival_data:
+                raise ValueError("--survival_data is required when plot_type is 'kaplan_meier'")
+            if not os.path.isfile(args.survival_data):
+                raise FileNotFoundError(f"Survival data file not found: {args.survival_data}")
+            if not args.labels:
+                raise ValueError("--labels is required for Kaplan-Meier plots")
+            if not args.surv_time_var:
+                raise ValueError("--surv_time_var is required for Kaplan-Meier plots")
+            if not args.surv_event_var:
+                raise ValueError("--surv_event_var is required for Kaplan-Meier plots")
+            if not args.event_value:
+                raise ValueError("--event_value is required for Kaplan-Meier plots")
+
+        if args.plot_type in ['cox']:
+            if not args.model:
+                raise ValueError("--model is required when plot_type is 'cox'")
+            if not os.path.isfile(args.model):
+                raise FileNotFoundError(f"Model file not found: {args.model}")
+            if not args.clinical_train:
+                raise ValueError("--clinical_train is required when plot_type is 'cox'")
+            if not os.path.isfile(args.clinical_train):
+                raise FileNotFoundError(f"Training dataset file not found: {args.clinical_train}")
+            if not args.clinical_test:
+                raise ValueError("--clinical_test is required when plot_type is 'cox'")
+            if not os.path.isfile(args.clinical_test):
+                raise FileNotFoundError(f"Test dataset file not found: {args.clinical_test}")
+            if not args.omics_train:
+                raise ValueError("--omics_train is required when plot_type is 'cox'")
+            if not os.path.isfile(args.omics_train):
+                raise FileNotFoundError(f"Training omics dataset file not found: {args.omics_train}")
+            if not args.omics_test:
+                raise ValueError("--omics_test is required when plot_type is 'cox'")
+            if not os.path.isfile(args.omics_test):
+                raise FileNotFoundError(f"Test omics dataset file not found: {args.omics_test}")
+            if not args.surv_time_var:
+                raise ValueError("--surv_time_var is required for Cox plots")
+            if not args.surv_event_var:
+                raise ValueError("--surv_event_var is required for Cox plots")
+            if not args.clinical_variables:
+                raise ValueError("--clinical_variables is required for Cox plots")
+            if not isinstance(args.top_features, int) or args.top_features <= 0:
+                raise ValueError("--top_features must be a positive integer")
+            if not args.event_value:
+                raise ValueError("--event_value is required for Kaplan-Meier plots")
+            if not isinstance(args.n_splits, int) or args.n_splits <= 0:
+                raise ValueError("--n_splits must be a positive integer")
+            if not isinstance(args.random_state, int):
+                raise ValueError("--random_state must be an integer")
+
+        if args.plot_type in ['scatter']:
+            if not args.labels:
+                raise ValueError("--labels is required for scatter plots")
+            if not args.target_value:
+                print("--target_value is not specified, using all unique variables from labels")
+            if not os.path.isfile(args.labels):
+                raise FileNotFoundError(f"Labels file not found: {args.labels}")
+
+        if args.plot_type in ['concordance_heatmap']:
+            if not args.labels:
+                raise ValueError("--labels is required for concordance heatmap")
+            if not args.target_value:
+                print("--target_value is not specified, using all unique variables from labels")
+            if not os.path.isfile(args.labels):
+                raise FileNotFoundError(f"Labels file not found: {args.labels}")
+
+        if args.plot_type in ['pr_curve']:
+            if not args.labels:
+                raise ValueError("--labels is required for precision-recall curves")
+            if not args.target_value:
+                print("--target_value is not specified, using all unique variables from labels")
+            if not os.path.isfile(args.labels):
+                raise FileNotFoundError(f"Labels file not found: {args.labels}")
+
+        if args.plot_type in ['roc_curve']:
+            if not args.labels:
+                raise ValueError("--labels is required for ROC curves")
+            if not args.target_value:
+                print("--target_value is not specified, using all unique variables from labels")
+            if not os.path.isfile(args.labels):
+                raise FileNotFoundError(f"Labels file not found: {args.labels}")
+
+        if args.plot_type in ['box_plot']:
+            if not args.labels:
+                raise ValueError("--labels is required for box plots")
+            if not args.target_value:
+                print("--target_value is not specified, using all unique variables from labels")
+            if not os.path.isfile(args.labels):
+                raise FileNotFoundError(f"Labels file not found: {args.labels}")
+
+        # Validate other arguments
+        if args.method not in ['pca', 'umap']:
+            raise ValueError("Method must be 'pca' or 'umap'")
+
+        # Create output directory
+        output_dir = Path(args.output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        print(f"Output directory: {output_dir.absolute()}")
+
+        # Generate output filename base
+        if args.output_name:
+            output_name_base = args.output_name
+        else:
+            if args.plot_type == 'dimred':
+                embeddings_name = Path(args.embeddings).stem
+                output_name_base = f"{embeddings_name}_{args.method}"
+            elif args.plot_type == 'kaplan_meier':
+                survival_name = Path(args.survival_data).stem
+                output_name_base = f"{survival_name}_km"
+            elif args.plot_type == 'cox':
+                model_name = Path(args.model).stem
+                output_name_base = f"{model_name}_cox"
+            elif args.plot_type == 'scatter':
+                labels_name = Path(args.labels).stem
+                output_name_base = f"{labels_name}_scatter"
+            elif args.plot_type == 'concordance_heatmap':
+                labels_name = Path(args.labels).stem
+                output_name_base = f"{labels_name}_concordance"
+            elif args.plot_type == 'pr_curve':
+                labels_name = Path(args.labels).stem
+                output_name_base = f"{labels_name}_pr_curves"
+            elif args.plot_type == 'roc_curve':
+                labels_name = Path(args.labels).stem
+                output_name_base = f"{labels_name}_roc_curves"
+            elif args.plot_type == 'box_plot':
+                labels_name = Path(args.labels).stem
+                output_name_base = f"{labels_name}_box_plot"
+
+        # Generate plots based on type
+        if args.plot_type in ['dimred']:
+            # Load labels
+            print(f"Loading labels from: {args.labels}")
+            label_data = load_labels(args.labels)
+            # Load embeddings data
+            print(f"Loading embeddings from: {args.embeddings}")
+            embeddings, sample_names = load_embeddings(args.embeddings)
+            print(f"embeddings shape: {embeddings.shape}")
+
+            # Match samples to embeddings
+            matched_labels = match_samples_to_embeddings(sample_names, label_data)
+            print(f"Successfully matched {len(matched_labels)} samples for dimensionality reduction")
+
+            generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base)
+
+        elif args.plot_type in ['kaplan_meier']:
+            # Load labels
+            print(f"Loading labels from: {args.labels}")
+            label_data = load_labels(args.labels)
+            # Load survival data
+            print(f"Loading survival data from: {args.survival_data}")
+            survival_data = load_survival_data(args.survival_data)
+            print(f"Survival data shape: {survival_data.shape}")
+
+            generate_km_plots(survival_data, label_data, args, output_dir, output_name_base)
+
+        elif args.plot_type in ['cox']:
+            # Load model and datasets
+            print(f"Loading model from: {args.model}")
+            model = load_model(args.model)
+            print(f"Loading training dataset from: {args.clinical_train}")
+            clinical_train = load_omics(args.clinical_train)
+            print(f"Loading test dataset from: {args.clinical_test}")
+            clinical_test = load_omics(args.clinical_test)
+            print(f"Loading training omics dataset from: {args.omics_train}")
+            omics_train = load_omics(args.omics_train)
+            print(f"Loading test omics dataset from: {args.omics_test}")
+            omics_test = load_omics(args.omics_test)
+
+            generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base)
+
+        elif args.plot_type in ['scatter']:
+            # Load labels
+            print(f"Loading labels from: {args.labels}")
+            label_data = load_labels(args.labels)
+
+            generate_plot_scatter(label_data, args, output_dir, output_name_base)
+
+        elif args.plot_type in ['concordance_heatmap']:
+            # Load labels
+            print(f"Loading labels from: {args.labels}")
+            label_data = load_labels(args.labels)
+
+            generate_label_concordance_heatmap(label_data, args, output_dir, output_name_base)
+
+        elif args.plot_type in ['pr_curve']:
+            # Load labels
+            print(f"Loading labels from: {args.labels}")
+            label_data = load_labels(args.labels)
+
+            generate_pr_curves(label_data, args, output_dir, output_name_base)
+
+        elif args.plot_type in ['roc_curve']:
+            # Load labels
+            print(f"Loading labels from: {args.labels}")
+            label_data = load_labels(args.labels)
+
+            generate_roc_curves(label_data, args, output_dir, output_name_base)
+
+        elif args.plot_type in ['box_plot']:
+            # Load labels
+            print(f"Loading labels from: {args.labels}")
+            label_data = load_labels(args.labels)
+
+            generate_box_plots(label_data, args, output_dir, output_name_base)
+
+        print("All plots generated successfully!")
+
+    except (FileNotFoundError, ValueError, pd.errors.ParserError) as e:
+        print(f"Error: {e}")
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    exit(main())