diff flexynesis_plot.py @ 6:33816f44fc7d draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flexynesis commit 6b520305ec30e6dc37eba92c67a5368cea0fc5ad
author bgruening
date Wed, 23 Jul 2025 07:49:41 +0000
parents 466b593fd87e
children
line wrap: on
line diff
--- a/flexynesis_plot.py	Fri Jul 04 14:57:40 2025 +0000
+++ b/flexynesis_plot.py	Wed Jul 23 07:49:41 2025 +0000
@@ -11,10 +11,8 @@
 import numpy as np
 import pandas as pd
 import seaborn as sns
-import torch
 from flexynesis import (
     build_cox_model,
-    get_important_features,
     plot_dim_reduced,
     plot_hazard_ratios,
     plot_kaplan_meier_curves,
@@ -55,35 +53,13 @@
         elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
             df = pd.read_csv(labels_input, sep='\t')
 
-        # Check if this is the specific format with sample_id, known_label, predicted_label
-        required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
-        if all(col in df.columns for col in required_cols):
-            return df
-        else:
-            raise ValueError(f"Labels file {labels_input} does not contain required columns: {required_cols}")
+        print(f"available columns: {df.columns.tolist()}")
+        return df
 
     except Exception as e:
         raise ValueError(f"Error loading labels from {labels_input}: {e}") from e
 
 
-def load_survival_data(survival_path):
-    """Load survival data from a file. First column should be sample_id"""
-    try:
-        # Determine file extension
-        file_ext = Path(survival_path).suffix.lower()
-
-        if file_ext == '.csv':
-            df = pd.read_csv(survival_path, index_col=0)
-        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
-            df = pd.read_csv(survival_path, sep='\t', index_col=0)
-        else:
-            raise ValueError(f"Unsupported file extension: {file_ext}")
-        return df
-
-    except Exception as e:
-        raise ValueError(f"Error loading survival data from {survival_path}: {e}") from e
-
-
 def load_omics(omics_path):
     """Load omics data from a file. First column should be features"""
     try:
@@ -102,19 +78,17 @@
         raise ValueError(f"Error loading omics data from {omics_path}: {e}") from e
 
 
-def load_model(model_path):
-    """Load flexynesis model from pickle file"""
-    try:
-        with open(model_path, 'rb') as f:
-            model = torch.load(f, weights_only=False)
-        return model
-    except Exception as e:
-        raise ValueError(f"Error loading model from {model_path}: {e}") from e
+def match_samples_to_embeddings(sample_names, labels):
+    """Filter label data to match sample names in the embeddings"""
+    # Create a DataFrame from sample_names to preserve order
+    sample_df = pd.DataFrame({'sample_names': sample_names})
 
+    # left_join
+    first_column = labels.columns[0]
+    df_matched = sample_df.merge(labels, left_on='sample_names', right_on=first_column, how='left')
 
-def match_samples_to_embeddings(sample_names, label_data):
-    """Filter label data to match sample names in the embeddings"""
-    df_matched = label_data[label_data['sample_id'].isin(sample_names)]
+    # remove sample_names to keep the initial structure
+    df_matched = df_matched.drop('sample_names', axis=1)
     return df_matched
 
 
@@ -214,124 +188,149 @@
 def generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base):
     """Generate dimensionality reduction plots"""
 
-    # Parse target values from comma-separated string
-    if args.target_value:
-        target_values = [val.strip() for val in args.target_value.split(',')]
-    else:
-        # If no target values specified, use all unique variables
-        target_values = matched_labels['variable'].unique().tolist()
+    # Check if this is the specific format with sample_id, known_label, predicted_label
+    required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
+    is_flexynesis_format = all(col in matched_labels.columns for col in required_cols)
+
+    if not args.color:
+        if is_flexynesis_format:
+            print("Detected flexynesis labels format")
+            print(f"Generating {args.method.upper()} plots for known and predicted labels...")
+        else:
+            print("Labels are not in flexynesis format (Custom labels), please specify a color variable with --color")
+
+        # Parse target values from comma-separated string
+        if args.target_value:
+            target_values = [val.strip() for val in args.target_value.split(',')]
+        else:
+            # If no target values specified, use all unique variables
+            target_values = matched_labels['variable'].unique().tolist()
+
+        print(f"Generating {args.method.upper()} plots for {len(target_values)} target variable(s): {', '.join(target_values)}")
 
-    print(f"Generating {args.method.upper()} plots for {len(target_values)} target variable(s): {', '.join(target_values)}")
+        # Check variables
+        available_vars = matched_labels['variable'].unique()
+        missing_vars = [var for var in target_values if var not in available_vars]
+
+        if missing_vars:
+            print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}")
+            print(f"Available variables: {', '.join(available_vars)}")
+
+        # Filter to only process available variables
+        valid_vars = [var for var in target_values if var in available_vars]
 
-    # Check variables
-    available_vars = matched_labels['variable'].unique()
-    missing_vars = [var for var in target_values if var not in available_vars]
+        if not valid_vars:
+            raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}")
+
+        # Generate plots for each valid target variable
+        for var in valid_vars:
+            print(f"\nPlotting variable: {var}")
 
-    if missing_vars:
-        print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}")
-        print(f"Available variables: {', '.join(available_vars)}")
+            # Filter matched labels for current variable
+            var_labels = matched_labels[matched_labels['variable'] == var].copy()
+            var_labels = var_labels.drop_duplicates(subset='sample_id')
+
+            if var_labels.empty:
+                print(f"Warning: No data found for variable '{var}', skipping...")
+                continue
 
-    # Filter to only process available variables
-    valid_vars = [var for var in target_values if var in available_vars]
+            # Auto-detect color type
+            known_color_type = detect_color_type(var_labels['known_label'])
+            predicted_color_type = detect_color_type(var_labels['predicted_label'])
+
+            print(f"  Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}")
 
-    if not valid_vars:
-        raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}")
+            try:
+                # Plot 1: Known labels
+                print(f"  Creating known labels plot for {var}...")
+                fig_known = plot_dim_reduced(
+                    matrix=embeddings,
+                    labels=var_labels['known_label'],
+                    method=args.method,
+                    color_type=known_color_type
+                )
+
+                output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}"
+                print(f"  Saving known labels plot to: {output_path_known.name}")
+                fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')
 
-    # Generate plots for each valid target variable
-    for var in valid_vars:
-        print(f"\nPlotting variable: {var}")
+                # Plot 2: Predicted labels
+                print(f"  Creating predicted labels plot for {var}...")
+                fig_predicted = plot_dim_reduced(
+                    matrix=embeddings,
+                    labels=var_labels['predicted_label'],
+                    method=args.method,
+                    color_type=predicted_color_type
+                )
+
+                output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}"
+                print(f"  Saving predicted labels plot to: {output_path_predicted.name}")
+                fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight')
 
-        # Filter matched labels for current variable
-        var_labels = matched_labels[matched_labels['variable'] == var].copy()
-        var_labels = var_labels.drop_duplicates(subset='sample_id')
+                print(f"  ✓ Successfully created plots for variable '{var}'")
+
+            except Exception as e:
+                print(f"  ✗ Error creating plots for variable '{var}': {e}")
+                continue
 
-        if var_labels.empty:
-            print(f"Warning: No data found for variable '{var}', skipping...")
-            continue
+        print(f"\nDimensionality reduction plots completed for {len(valid_vars)} variable(s)!")
+
+    else:
+        # check if the color variable exists in matched_labels
+        if args.color not in matched_labels.columns:
+            raise ValueError(f"Color variable '{args.color}' not found in matched labels. Available columns: {matched_labels.columns.tolist()}")
 
         # Auto-detect color type
-        known_color_type = detect_color_type(var_labels['known_label'])
-        predicted_color_type = detect_color_type(var_labels['predicted_label'])
-
-        print(f"  Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}")
+        color_type = detect_color_type(matched_labels[args.color])
 
-        try:
-            # Plot 1: Known labels
-            print(f"  Creating known labels plot for {var}...")
-            fig_known = plot_dim_reduced(
-                matrix=embeddings,
-                labels=var_labels['known_label'],
-                method=args.method,
-                color_type=known_color_type
-            )
-
-            output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}"
-            print(f"  Saving known labels plot to: {output_path_known.name}")
-            fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')
+        print(f"  Auto-detected color type: {color_type}")
 
-            # Plot 2: Predicted labels
-            print(f"  Creating predicted labels plot for {var}...")
-            fig_predicted = plot_dim_reduced(
-                matrix=embeddings,
-                labels=var_labels['predicted_label'],
-                method=args.method,
-                color_type=predicted_color_type
-            )
+        # Plot: Specified color column
+        print(f"  Creating plot for {args.color}...")
+        fig = plot_dim_reduced(
+            matrix=embeddings,
+            labels=matched_labels[args.color],
+            method=args.method,
+            color_type=color_type
+        )
 
-            output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}"
-            print(f"  Saving predicted labels plot to: {output_path_predicted.name}")
-            fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight')
-
-            print(f"  ✓ Successfully created plots for variable '{var}'")
+        output_path = output_dir / f"{output_name_base}_{args.color}.{args.format}"
+        print(f"  Saving plot to: {output_path.name}")
+        fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
 
-        except Exception as e:
-            print(f"  ✗ Error creating plots for variable '{var}': {e}")
-            continue
-
-    print(f"\nDimensionality reduction plots completed for {len(valid_vars)} variable(s)!")
+        print(f"  ✓ Successfully created plot for variable '{args.color}'")
 
 
-def generate_km_plots(survival_data, label_data, args, output_dir, output_name_base):
+def generate_km_plots(survival_data, labels, args, output_dir, output_name_base):
     """Generate Kaplan-Meier plots"""
+
+    # Check if this is the specific format with sample_id, known_label, predicted_label
+    required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
+    is_flexynesis_format = all(col in labels.columns for col in required_cols)
+
+    if not is_flexynesis_format:
+        raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.")
+
     print("Generating Kaplan-Meier curves of risk subtypes...")
 
-    # Reset index and rename the index column to sample_id
-    survival_data = survival_data.reset_index()
     if survival_data.columns[0] != 'sample_id':
         survival_data = survival_data.rename(columns={survival_data.columns[0]: 'sample_id'})
 
-    # Convert survival event column to binary (0/1) based on event_value
     # Check if the event column exists
     if args.surv_event_var not in survival_data.columns:
         raise ValueError(f"Column '{args.surv_event_var}' not found in survival data")
 
-    # Convert to string for comparison to handle mixed types
-    survival_data[args.surv_event_var] = survival_data[args.surv_event_var].astype(str)
-    event_value_str = str(args.event_value)
-
-    # Create binary event column (1 if matches event_value, 0 otherwise)
-    survival_data[f'{args.surv_event_var}_binary'] = (
-        survival_data[args.surv_event_var] == event_value_str
-    ).astype(int)
-
-    # Filter for survival category and class_label == '1:DECEASED'
-    label_data['class_label'] = label_data['class_label'].astype(str)
-
-    label_data = label_data[(label_data['variable'] == args.surv_event_var) & (label_data['class_label'] == event_value_str)]
-
-    # check survival data
-    for col in [args.surv_time_var, args.surv_event_var]:
-        if col not in survival_data.columns:
-            raise ValueError(f"Column '{col}' not found in survival data")
+    labels = labels[(labels['variable'] == args.surv_event_var)]
 
     # Merge survival data with labels
-    df_deceased = pd.merge(survival_data, label_data, on='sample_id', how='inner')
+    df_deceased = pd.merge(survival_data, labels, on='sample_id', how='inner')
+    df_deceased = df_deceased.dropna(subset=[args.surv_time_var, args.surv_event_var])
 
     if df_deceased.empty:
         raise ValueError("No matching samples found after merging survival and label data.")
 
     # Get risk scores
-    risk_scores = df_deceased['probability'].values
+    risk_scores = df_deceased['predicted_label'].values
 
     # Compute groups (e.g., median split)
     quantiles = np.quantile(risk_scores, [0.5])
@@ -340,7 +339,7 @@
 
     fig_known = plot_kaplan_meier_curves(
         durations=df_deceased[args.surv_time_var],
-        events=df_deceased[f'{args.surv_event_var}_binary'],
+        events=df_deceased[args.surv_event_var],
         categorical_variable=group_labels
     )
 
@@ -351,12 +350,21 @@
     print("Kaplan-Meier plot saved successfully!")
 
 
-def generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base):
+def generate_cox_plots(important_features, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base):
     """Generate Cox proportional hazards plots"""
     print("Generating Cox proportional hazards analysis...")
 
+    # Check if this is the specific format with target_variable, importance
+    required_cols = ['target_variable', 'layer', 'importance']
+    is_flexynesis_format = all(col in important_features.columns for col in required_cols)
+
+    if not is_flexynesis_format:
+        raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid important_features file with the required columns, {required_cols}.")
+
     # Parse clinical variables
-    clinical_vars = [var.strip() for var in args.clinical_variables.split(',')]
+    clinical_vars = []
+    if args.clinical_variables:
+        clinical_vars = [var.strip() for var in args.clinical_variables.split(',')]
 
     # Validate that survival variables are included
     required_vars = [args.surv_time_var, args.surv_event_var]
@@ -382,10 +390,22 @@
     # Get top survival markers
     print(f"Extracting top {args.top_features} important features for {args.surv_event_var}...")
     try:
-        imp = get_important_features(model,
-                                     var=args.surv_event_var,
-                                     top=args.top_features
-                                     )['name'].unique().tolist()
+        print(f"Loading {args.top_features} important features from: {args.important_features}")
+        imp_features = load_labels(args.important_features)
+        imp_features = imp_features[imp_features['target_variable'] == args.surv_event_var]
+        if args.layer not in imp_features['layer'].unique():
+            print(f"Available class labels: {imp_features['layer'].unique()}")
+            raise ValueError(f"Class label '{args.layer}' not found in important features data: {args.important_features}")
+        imp_features = imp_features[imp_features['layer'] == args.layer]
+        if imp_features.empty:
+            raise ValueError(f"No important features found for target variable '{args.surv_event_var}' in {args.important_features}")
+        imp_features = imp_features.sort_values(by='importance', ascending=False)
+
+        if len(imp_features) < args.top_features:
+            raise ValueError(f"Requested top {args.top_features} features, but only {len(imp_features)} available in {args.important_features}")
+
+        imp = imp_features['name'].unique().tolist()[0:args.top_features]
+
         print(f"Top features: {', '.join(imp)}")
     except Exception as e:
         raise ValueError(f"Error getting important features: {e}")
@@ -418,15 +438,6 @@
     if df.empty:
         raise ValueError("No samples remain after filtering for survival data")
 
-    # Convert survival event column to binary (0/1) based on event_value
-    # Convert to string for comparison to handle mixed types
-    df[args.surv_event_var] = df[args.surv_event_var].astype(str)
-    event_value_str = str(args.event_value)
-
-    df[f'{args.surv_event_var}'] = (
-        df[args.surv_event_var] == event_value_str
-    ).astype(int)
-
     # Build Cox model
     print(f"Building Cox model with time variable: {args.surv_time_var}, event variable: {args.surv_event_var}")
     try:
@@ -459,124 +470,190 @@
     """Generate scatter plot of known vs predicted labels"""
     print("Generating scatter plots of known vs predicted labels...")
 
-    # Parse target values from comma-separated string
-    if args.target_value:
-        target_values = [val.strip() for val in args.target_value.split(',')]
-    else:
-        # If no target values specified, use all unique variables
-        target_values = labels['variable'].unique().tolist()
+    # Check if this is the specific format with sample_id, known_label, predicted_label
+    required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
+    is_flexynesis_format = all(col in labels.columns for col in required_cols)
+
+    if is_flexynesis_format:
+        # Parse target values from comma-separated string
+        if args.target_value:
+            target_values = [val.strip() for val in args.target_value.split(',')]
+        else:
+            # If no target values specified, use all unique variables
+            target_values = labels['variable'].unique().tolist()
+
+        print(f"Processing target values: {target_values}")
 
-    print(f"Processing target values: {target_values}")
+        successful_plots = 0
+        skipped_plots = 0
+
+        for target_value in target_values:
+            print(f"\nProcessing target value: '{target_value}'")
+
+            # Filter labels for the current target value
+            target_labels = labels[labels['variable'] == target_value]
+
+            if target_labels.empty:
+                print(f"  Warning: No data found for target value '{target_value}' - skipping")
+                skipped_plots += 1
+                continue
+
+            # Check if labels are numeric and convert
+            true_values = pd.to_numeric(target_labels['known_label'], errors='coerce')
+            predicted_values = pd.to_numeric(target_labels['predicted_label'], errors='coerce')
 
-    successful_plots = 0
-    skipped_plots = 0
+            if true_values.isna().all() or predicted_values.isna().all():
+                print(f"No valid numeric values found for known or predicted labels in '{target_value}'")
+                skipped_plots += 1
+                continue
+
+            try:
+                print(f"  Generating scatter plot for '{target_value}'...")
+                fig = plot_scatter(true_values, predicted_values)
 
-    for target_value in target_values:
-        print(f"\nProcessing target value: '{target_value}'")
+                # Create output filename with target value
+                safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
+                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
+
+                output_path = output_dir / output_filename
+                print(f"  Saving scatter plot to: {output_path.absolute()}")
+                fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
+
+                successful_plots += 1
+                print(f"  Scatter plot for '{target_value}' generated successfully!")
 
-        # Filter labels for the current target value
-        target_labels = labels[labels['variable'] == target_value]
+            except Exception as e:
+                print(f"  Error generating plot for '{target_value}': {str(e)}")
+                skipped_plots += 1
+
+        # Summary
+        print("  Summary:")
+        print(f"  Successfully generated: {successful_plots} plots")
+        print(f"  Skipped: {skipped_plots} plots")
 
-        if target_labels.empty:
-            print(f"  Warning: No data found for target value '{target_value}' - skipping")
-            skipped_plots += 1
-            continue
+        if successful_plots == 0:
+            raise ValueError("No scatter plots could be generated. Check your data and target values.")
+
+        print("Scatter plot generation completed!")
+
+    if not is_flexynesis_format:
+        print("Labels are not in flexynesis format (Custom labels)")
+
+        if not args.true_label or not args.predicted_label:
+            raise ValueError("For custom labels, please specify --true_label and --predicted_label arguments.")
 
         # Check if labels are numeric and convert
-        true_values = pd.to_numeric(target_labels['known_label'], errors='coerce')
-        predicted_values = pd.to_numeric(target_labels['predicted_label'], errors='coerce')
+        true_values = pd.to_numeric(labels[args.true_label], errors='coerce')
+        predicted_values = pd.to_numeric(labels[args.predicted_label], errors='coerce')
 
         if true_values.isna().all() or predicted_values.isna().all():
-            print(f"No valid numeric values found for known or predicted labels in '{target_value}'")
-            skipped_plots += 1
-            continue
+            print("No valid numeric values found for known or predicted labels")
 
         try:
-            print(f"  Generating scatter plot for '{target_value}'...")
+            print("  Generating scatter plot...")
             fig = plot_scatter(true_values, predicted_values)
 
             # Create output filename with target value
-            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
-            if len(target_values) > 1:
-                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
-            else:
-                output_filename = f"{output_name_base}.{args.format}"
+            output_filename = f"{output_name_base}.{args.format}"
 
             output_path = output_dir / output_filename
             print(f"  Saving scatter plot to: {output_path.absolute()}")
             fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
 
-            successful_plots += 1
-            print(f"  Scatter plot for '{target_value}' generated successfully!")
-
         except Exception as e:
-            print(f"  Error generating plot for '{target_value}': {str(e)}")
-            skipped_plots += 1
+            print(f"  Error generating plot: {str(e)}")
 
-    # Summary
-    print("  Summary:")
-    print(f"  Successfully generated: {successful_plots} plots")
-    print(f"  Skipped: {skipped_plots} plots")
-
-    if successful_plots == 0:
-        raise ValueError("No scatter plots could be generated. Check your data and target values.")
-
-    print("Scatter plot generation completed!")
+        print("Scatter plot generation completed!")
 
 
 def generate_label_concordance_heatmap(labels, args, output_dir, output_name_base):
     """Generate label concordance heatmap"""
     print("Generating label concordance heatmaps...")
 
-    # Parse target values from comma-separated string
-    if args.target_value:
-        target_values = [val.strip() for val in args.target_value.split(',')]
-    else:
-        # If no target values specified, use all unique variables
-        target_values = labels['variable'].unique().tolist()
+    # Check if this is the specific format with sample_id, known_label, predicted_label
+    required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
+    is_flexynesis_format = all(col in labels.columns for col in required_cols)
+
+    if is_flexynesis_format:
+        # Parse target values from comma-separated string
+        if args.target_value:
+            target_values = [val.strip() for val in args.target_value.split(',')]
+        else:
+            # If no target values specified, use all unique variables
+            target_values = labels['variable'].unique().tolist()
 
-    print(f"Processing target values: {target_values}")
+        print(f"Processing target values: {target_values}")
+
+        for target_value in target_values:
+            print(f"\nProcessing target value: '{target_value}'")
+
+            # Filter labels for the current target value
+            target_labels = labels[labels['variable'] == target_value]
+
+            if target_labels.empty:
+                print(f"  Warning: No data found for target value '{target_value}' - skipping")
+                continue
+
+            true_values = target_labels['known_label'].tolist()
+            predicted_values = target_labels['predicted_label'].tolist()
 
-    for target_value in target_values:
-        print(f"\nProcessing target value: '{target_value}'")
+            try:
+                print(f"  Generating heatmap for '{target_value}'...")
+                fig = plot_label_concordance_heatmap(true_values, predicted_values)
+                plt.close(fig)
 
-        # Filter labels for the current target value
-        target_labels = labels[labels['variable'] == target_value]
+                # Create output filename with target value
+                safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
+                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
+
+                output_path = output_dir / output_filename
+                print(f"  Saving heatmap to: {output_path.absolute()}")
+                fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')
 
-        if target_labels.empty:
-            print(f"  Warning: No data found for target value '{target_value}' - skipping")
-            continue
+            except Exception as e:
+                print(f"  Error generating heatmap for '{target_value}': {str(e)}")
+                continue
+
+        print("Label concordance heatmap generated successfully!")
 
-        true_values = target_labels['known_label'].tolist()
-        predicted_values = target_labels['predicted_label'].tolist()
+    if not is_flexynesis_format:
+        print("Labels are not in flexynesis format (Custom labels)")
+
+        if not args.true_label or not args.predicted_label:
+            raise ValueError("For custom labels, please specify --true_label and --predicted_label arguments.")
+
+        true_values = labels[args.true_label].tolist()
+        predicted_values = labels[args.predicted_label].tolist()
 
         try:
-            print(f"  Generating heatmap for '{target_value}'...")
+            print("  Generating heatmap for...")
             fig = plot_label_concordance_heatmap(true_values, predicted_values)
             plt.close(fig)
 
             # Create output filename with target value
-            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
-            if len(target_values) > 1:
-                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
-            else:
-                output_filename = f"{output_name_base}.{args.format}"
+            output_filename = f"{output_name_base}.{args.format}"
 
             output_path = output_dir / output_filename
             print(f"  Saving heatmap to: {output_path.absolute()}")
             fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')
 
         except Exception as e:
-            print(f"  Error generating heatmap for '{target_value}': {str(e)}")
-            continue
+            print(f"  Error generating heatmap': {str(e)}")
 
-    print("Label concordance heatmap generated successfully!")
+        print("Label concordance heatmap generated successfully!")
 
 
 def generate_pr_curves(labels, args, output_dir, output_name_base):
     """Generate precision-recall curves"""
     print("Generating precision-recall curves...")
 
+    # Check if this is the specific format with sample_id, known_label, predicted_label
+    required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
+    is_flexynesis_format = all(col in labels.columns for col in required_cols)
+
+    if not is_flexynesis_format:
+        raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.")
+
     # Parse target values from comma-separated string
     if args.target_value:
         target_values = [val.strip() for val in args.target_value.split(',')]
@@ -707,10 +784,7 @@
 
             # Create output filename with target value
             safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
-            if len(target_values) > 1:
-                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
-            else:
-                output_filename = f"{output_name_base}.{args.format}"
+            output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
 
             output_path = output_dir / output_filename
             print(f"  Saving precision-recall curve to: {output_path.absolute()}")
@@ -729,6 +803,13 @@
     """Generate ROC curves"""
     print("Generating ROC curves...")
 
+    # Check if this is the specific format with sample_id, known_label, predicted_label
+    required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
+    is_flexynesis_format = all(col in labels.columns for col in required_cols)
+
+    if not is_flexynesis_format:
+        raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.")
+
     # Parse target values from comma-separated string
     if args.target_value:
         target_values = [val.strip() for val in args.target_value.split(',')]
@@ -859,10 +940,7 @@
 
             # Create output filename with target value
             safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
-            if len(target_values) > 1:
-                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
-            else:
-                output_filename = f"{output_name_base}.{args.format}"
+            output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
 
             output_path = output_dir / output_filename
             print(f"  Saving ROC curve to: {output_path.absolute()}")
@@ -880,6 +958,13 @@
 def generate_box_plots(labels, args, output_dir, output_name_base):
     """Generate box plots for model predictions"""
 
+    # Check if this is the specific format with sample_id, known_label, predicted_label
+    required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
+    is_flexynesis_format = all(col in labels.columns for col in required_cols)
+
+    if not is_flexynesis_format:
+        raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.")
+
     print("Generating box plots...")
 
     # Parse target values from comma-separated string
@@ -993,6 +1078,8 @@
                         help="Path to input data embeddings file (CSV or tabular format). Required for dimred plots.")
     parser.add_argument("--method", type=str, default='pca', choices=['pca', 'umap'],
                         help="Transformation method ('pca' or 'umap'). Default is 'pca'. Used for dimred plots.")
+    parser.add_argument("--color", type=str, default=None,
+                        help="User-defined color for the plot.")
 
     # Arguments for Kaplan-Meier
     parser.add_argument("--survival_data", type=str,
@@ -1001,12 +1088,10 @@
                         help="Column name for survival time")
     parser.add_argument("--surv_event_var", type=str, required=False,
                         help="Column name for survival event")
-    parser.add_argument("--event_value", type=str, required=False,
-                        help="Value in event column that represents an event (e.g., 'DECEASED')")
 
     # Arguments for Cox analysis
-    parser.add_argument("--model", type=str,
-                        help="Path to trained flexynesis model (pickle file). Required for cox plots.")
+    parser.add_argument("--important_features", type=str,
+                        help="Path to calculated feature importance file. Required for cox plots.")
     parser.add_argument("--clinical_train", type=str,
                         help="Path to training dataset (pickle file). Required for cox plots.")
     parser.add_argument("--clinical_test", type=str,
@@ -1025,11 +1110,18 @@
                         help="Number of folds for cross-validation. Default is 5")
     parser.add_argument("--random_state", type=int, default=42,
                         help="Random seed for reproducibility. Default is 42")
+    parser.add_argument("--layer", type=str, default=None,
+                        help="Class label for filtering important features.")
 
     # Arguments for dimred, scatter plot, heatmap, PR curves, ROC curves, and box plots
     parser.add_argument("--target_value", type=str, default=None,
                         help="Target value for scatter plot.")
 
+    # Arguments for scatter plots and concordance heatmaps
+    parser.add_argument("--true_label", type=str, default=None,
+                        help="Column name for true labels in scatter plots and concordance heatmaps.")
+    parser.add_argument("--predicted_label", type=str, default=None,
+                        help="Column name for predicted labels in scatter plots and concordance heatmaps.")
     # Common arguments
     parser.add_argument("--output_dir", type=str, default='output',
                         help="Output directory. Default is 'output'")
@@ -1073,14 +1165,12 @@
                 raise ValueError("--surv_time_var is required for Kaplan-Meier plots")
             if not args.surv_event_var:
                 raise ValueError("--surv_event_var is required for Kaplan-Meier plots")
-            if not args.event_value:
-                raise ValueError("--event_value is required for Kaplan-Meier plots")
 
         if args.plot_type in ['cox']:
-            if not args.model:
-                raise ValueError("--model is required when plot_type is 'cox'")
-            if not os.path.isfile(args.model):
-                raise FileNotFoundError(f"Model file not found: {args.model}")
+            if not args.important_features:
+                raise ValueError("--important_features is required when plot_type is 'cox'")
+            if not os.path.isfile(args.important_features):
+                raise FileNotFoundError(f"Important features file not found: {args.important_features}")
             if not args.clinical_train:
                 raise ValueError("--clinical_train is required when plot_type is 'cox'")
             if not os.path.isfile(args.clinical_train):
@@ -1102,17 +1192,17 @@
             if not args.surv_event_var:
                 raise ValueError("--surv_event_var is required for Cox plots")
             if not args.clinical_variables:
-                raise ValueError("--clinical_variables is required for Cox plots")
+                print("--clinical_variables is not set for Cox plots")
             if not isinstance(args.top_features, int) or args.top_features <= 0:
                 raise ValueError("--top_features must be a positive integer")
-            if not args.event_value:
-                raise ValueError("--event_value is required for Kaplan-Meier plots")
             if not args.crossval:
                 args.crossval = False
             if not isinstance(args.n_splits, int) or args.n_splits <= 0:
                 raise ValueError("--n_splits must be a positive integer")
             if not isinstance(args.random_state, int):
                 raise ValueError("--random_state must be an integer")
+            if not args.layer:
+                print("--layer is not specified, using all classes from labels")
 
         if args.plot_type in ['scatter']:
             if not args.labels:
@@ -1174,7 +1264,7 @@
                 survival_name = Path(args.survival_data).stem
                 output_name_base = f"{survival_name}_km"
             elif args.plot_type == 'cox':
-                model_name = Path(args.model).stem
+                model_name = Path(args.important_features).stem
                 output_name_base = f"{model_name}_cox"
             elif args.plot_type == 'scatter':
                 labels_name = Path(args.labels).stem
@@ -1196,33 +1286,34 @@
         if args.plot_type in ['dimred']:
             # Load labels
             print(f"Loading labels from: {args.labels}")
-            label_data = load_labels(args.labels)
+            labels = load_labels(args.labels)
             # Load embeddings data
             print(f"Loading embeddings from: {args.embeddings}")
             embeddings, sample_names = load_embeddings(args.embeddings)
             print(f"embeddings shape: {embeddings.shape}")
 
             # Match samples to embeddings
-            matched_labels = match_samples_to_embeddings(sample_names, label_data)
+            matched_labels = match_samples_to_embeddings(sample_names, labels)
             print(f"Successfully matched {len(matched_labels)} samples for dimensionality reduction")
-
+            print(f"Matched labels shape: {matched_labels.shape}")
+            print(f"Columns in matched labels: {matched_labels.columns.tolist()}")
             generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base)
 
         elif args.plot_type in ['kaplan_meier']:
             # Load labels
             print(f"Loading labels from: {args.labels}")
-            label_data = load_labels(args.labels)
+            labels = load_labels(args.labels)
             # Load survival data
             print(f"Loading survival data from: {args.survival_data}")
-            survival_data = load_survival_data(args.survival_data)
+            survival_data = load_labels(args.survival_data)
             print(f"Survival data shape: {survival_data.shape}")
 
-            generate_km_plots(survival_data, label_data, args, output_dir, output_name_base)
+            generate_km_plots(survival_data, labels, args, output_dir, output_name_base)
 
         elif args.plot_type in ['cox']:
-            # Load model and datasets
-            print(f"Loading model from: {args.model}")
-            model = load_model(args.model)
+            # Load important_features and datasets
+            print(f"Loading important features from: {args.important_features}")
+            important_features = load_labels(args.important_features)
             print(f"Loading training dataset from: {args.clinical_train}")
             clinical_train = load_omics(args.clinical_train)
             print(f"Loading test dataset from: {args.clinical_test}")
@@ -1232,42 +1323,42 @@
             print(f"Loading test omics dataset from: {args.omics_test}")
             omics_test = load_omics(args.omics_test)
 
-            generate_cox_plots(model, clinical_train, clinical_test, omics_test, omics_train, args, output_dir, output_name_base)
+            generate_cox_plots(important_features, clinical_train, clinical_test, omics_test, omics_train, args, output_dir, output_name_base)
 
         elif args.plot_type in ['scatter']:
             # Load labels
             print(f"Loading labels from: {args.labels}")
-            label_data = load_labels(args.labels)
+            labels = load_labels(args.labels)
 
-            generate_plot_scatter(label_data, args, output_dir, output_name_base)
+            generate_plot_scatter(labels, args, output_dir, output_name_base)
 
         elif args.plot_type in ['concordance_heatmap']:
             # Load labels
             print(f"Loading labels from: {args.labels}")
-            label_data = load_labels(args.labels)
+            labels = load_labels(args.labels)
 
-            generate_label_concordance_heatmap(label_data, args, output_dir, output_name_base)
+            generate_label_concordance_heatmap(labels, args, output_dir, output_name_base)
 
         elif args.plot_type in ['pr_curve']:
             # Load labels
             print(f"Loading labels from: {args.labels}")
-            label_data = load_labels(args.labels)
+            labels = load_labels(args.labels)
 
-            generate_pr_curves(label_data, args, output_dir, output_name_base)
+            generate_pr_curves(labels, args, output_dir, output_name_base)
 
         elif args.plot_type in ['roc_curve']:
             # Load labels
             print(f"Loading labels from: {args.labels}")
-            label_data = load_labels(args.labels)
+            labels = load_labels(args.labels)
 
-            generate_roc_curves(label_data, args, output_dir, output_name_base)
+            generate_roc_curves(labels, args, output_dir, output_name_base)
 
         elif args.plot_type in ['box_plot']:
             # Load labels
             print(f"Loading labels from: {args.labels}")
-            label_data = load_labels(args.labels)
+            labels = load_labels(args.labels)
 
-            generate_box_plots(label_data, args, output_dir, output_name_base)
+            generate_box_plots(labels, args, output_dir, output_name_base)
 
         print("All plots generated successfully!")