comparison flexynesis_plot.py @ 6:33816f44fc7d draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flexynesis commit 6b520305ec30e6dc37eba92c67a5368cea0fc5ad
author bgruening
date Wed, 23 Jul 2025 07:49:41 +0000
parents 466b593fd87e
children
comparison
equal deleted inserted replaced
5:466b593fd87e 6:33816f44fc7d
9 9
10 import matplotlib.pyplot as plt 10 import matplotlib.pyplot as plt
11 import numpy as np 11 import numpy as np
12 import pandas as pd 12 import pandas as pd
13 import seaborn as sns 13 import seaborn as sns
14 import torch
15 from flexynesis import ( 14 from flexynesis import (
16 build_cox_model, 15 build_cox_model,
17 get_important_features,
18 plot_dim_reduced, 16 plot_dim_reduced,
19 plot_hazard_ratios, 17 plot_hazard_ratios,
20 plot_kaplan_meier_curves, 18 plot_kaplan_meier_curves,
21 plot_pr_curves, 19 plot_pr_curves,
22 plot_roc_curves, 20 plot_roc_curves,
53 if file_ext == '.csv': 51 if file_ext == '.csv':
54 df = pd.read_csv(labels_input) 52 df = pd.read_csv(labels_input)
55 elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']: 53 elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
56 df = pd.read_csv(labels_input, sep='\t') 54 df = pd.read_csv(labels_input, sep='\t')
57 55
58 # Check if this is the specific format with sample_id, known_label, predicted_label 56 print(f"available columns: {df.columns.tolist()}")
59 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] 57 return df
60 if all(col in df.columns for col in required_cols):
61 return df
62 else:
63 raise ValueError(f"Labels file {labels_input} does not contain required columns: {required_cols}")
64 58
65 except Exception as e: 59 except Exception as e:
66 raise ValueError(f"Error loading labels from {labels_input}: {e}") from e 60 raise ValueError(f"Error loading labels from {labels_input}: {e}") from e
67
68
69 def load_survival_data(survival_path):
70 """Load survival data from a file. First column should be sample_id"""
71 try:
72 # Determine file extension
73 file_ext = Path(survival_path).suffix.lower()
74
75 if file_ext == '.csv':
76 df = pd.read_csv(survival_path, index_col=0)
77 elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
78 df = pd.read_csv(survival_path, sep='\t', index_col=0)
79 else:
80 raise ValueError(f"Unsupported file extension: {file_ext}")
81 return df
82
83 except Exception as e:
84 raise ValueError(f"Error loading survival data from {survival_path}: {e}") from e
85 61
86 62
87 def load_omics(omics_path): 63 def load_omics(omics_path):
88 """Load omics data from a file. First column should be features""" 64 """Load omics data from a file. First column should be features"""
89 try: 65 try:
100 76
101 except Exception as e: 77 except Exception as e:
102 raise ValueError(f"Error loading omics data from {omics_path}: {e}") from e 78 raise ValueError(f"Error loading omics data from {omics_path}: {e}") from e
103 79
104 80
105 def load_model(model_path): 81 def match_samples_to_embeddings(sample_names, labels):
106 """Load flexynesis model from pickle file"""
107 try:
108 with open(model_path, 'rb') as f:
109 model = torch.load(f, weights_only=False)
110 return model
111 except Exception as e:
112 raise ValueError(f"Error loading model from {model_path}: {e}") from e
113
114
115 def match_samples_to_embeddings(sample_names, label_data):
116 """Filter label data to match sample names in the embeddings""" 82 """Filter label data to match sample names in the embeddings"""
117 df_matched = label_data[label_data['sample_id'].isin(sample_names)] 83 # Create a DataFrame from sample_names to preserve order
84 sample_df = pd.DataFrame({'sample_names': sample_names})
85
86 # left_join
87 first_column = labels.columns[0]
88 df_matched = sample_df.merge(labels, left_on='sample_names', right_on=first_column, how='left')
89
90 # remove sample_names to keep the initial structure
91 df_matched = df_matched.drop('sample_names', axis=1)
118 return df_matched 92 return df_matched
119 93
120 94
121 def detect_color_type(labels_series): 95 def detect_color_type(labels_series):
122 """Auto-detect whether target variables should be treated as categorical or numerical""" 96 """Auto-detect whether target variables should be treated as categorical or numerical"""
212 186
213 187
214 def generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base): 188 def generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base):
215 """Generate dimensionality reduction plots""" 189 """Generate dimensionality reduction plots"""
216 190
217 # Parse target values from comma-separated string 191 # Check if this is the specific format with sample_id, known_label, predicted_label
218 if args.target_value: 192 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
219 target_values = [val.strip() for val in args.target_value.split(',')] 193 is_flexynesis_format = all(col in matched_labels.columns for col in required_cols)
194
195 if not args.color:
196 if is_flexynesis_format:
197 print("Detected flexynesis labels format")
198 print(f"Generating {args.method.upper()} plots for known and predicted labels...")
199 else:
200 print("Labels are not in flexynesis format (Custom labels), please specify a color variable with --color")
201
202 # Parse target values from comma-separated string
203 if args.target_value:
204 target_values = [val.strip() for val in args.target_value.split(',')]
205 else:
206 # If no target values specified, use all unique variables
207 target_values = matched_labels['variable'].unique().tolist()
208
209 print(f"Generating {args.method.upper()} plots for {len(target_values)} target variable(s): {', '.join(target_values)}")
210
211 # Check variables
212 available_vars = matched_labels['variable'].unique()
213 missing_vars = [var for var in target_values if var not in available_vars]
214
215 if missing_vars:
216 print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}")
217 print(f"Available variables: {', '.join(available_vars)}")
218
219 # Filter to only process available variables
220 valid_vars = [var for var in target_values if var in available_vars]
221
222 if not valid_vars:
223 raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}")
224
225 # Generate plots for each valid target variable
226 for var in valid_vars:
227 print(f"\nPlotting variable: {var}")
228
229 # Filter matched labels for current variable
230 var_labels = matched_labels[matched_labels['variable'] == var].copy()
231 var_labels = var_labels.drop_duplicates(subset='sample_id')
232
233 if var_labels.empty:
234 print(f"Warning: No data found for variable '{var}', skipping...")
235 continue
236
237 # Auto-detect color type
238 known_color_type = detect_color_type(var_labels['known_label'])
239 predicted_color_type = detect_color_type(var_labels['predicted_label'])
240
241 print(f" Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}")
242
243 try:
244 # Plot 1: Known labels
245 print(f" Creating known labels plot for {var}...")
246 fig_known = plot_dim_reduced(
247 matrix=embeddings,
248 labels=var_labels['known_label'],
249 method=args.method,
250 color_type=known_color_type
251 )
252
253 output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}"
254 print(f" Saving known labels plot to: {output_path_known.name}")
255 fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')
256
257 # Plot 2: Predicted labels
258 print(f" Creating predicted labels plot for {var}...")
259 fig_predicted = plot_dim_reduced(
260 matrix=embeddings,
261 labels=var_labels['predicted_label'],
262 method=args.method,
263 color_type=predicted_color_type
264 )
265
266 output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}"
267 print(f" Saving predicted labels plot to: {output_path_predicted.name}")
268 fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight')
269
270 print(f" ✓ Successfully created plots for variable '{var}'")
271
272 except Exception as e:
273 print(f" ✗ Error creating plots for variable '{var}': {e}")
274 continue
275
276 print(f"\nDimensionality reduction plots completed for {len(valid_vars)} variable(s)!")
277
220 else: 278 else:
221 # If no target values specified, use all unique variables 279 # check if the color variable exists in matched_labels
222 target_values = matched_labels['variable'].unique().tolist() 280 if args.color not in matched_labels.columns:
223 281 raise ValueError(f"Color variable '{args.color}' not found in matched labels. Available columns: {matched_labels.columns.tolist()}")
224 print(f"Generating {args.method.upper()} plots for {len(target_values)} target variable(s): {', '.join(target_values)}")
225
226 # Check variables
227 available_vars = matched_labels['variable'].unique()
228 missing_vars = [var for var in target_values if var not in available_vars]
229
230 if missing_vars:
231 print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}")
232 print(f"Available variables: {', '.join(available_vars)}")
233
234 # Filter to only process available variables
235 valid_vars = [var for var in target_values if var in available_vars]
236
237 if not valid_vars:
238 raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}")
239
240 # Generate plots for each valid target variable
241 for var in valid_vars:
242 print(f"\nPlotting variable: {var}")
243
244 # Filter matched labels for current variable
245 var_labels = matched_labels[matched_labels['variable'] == var].copy()
246 var_labels = var_labels.drop_duplicates(subset='sample_id')
247
248 if var_labels.empty:
249 print(f"Warning: No data found for variable '{var}', skipping...")
250 continue
251 282
252 # Auto-detect color type 283 # Auto-detect color type
253 known_color_type = detect_color_type(var_labels['known_label']) 284 color_type = detect_color_type(matched_labels[args.color])
254 predicted_color_type = detect_color_type(var_labels['predicted_label']) 285
255 286 print(f" Auto-detected color type: {color_type}")
256 print(f" Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}") 287
257 288 # Plot: Specified color column
258 try: 289 print(f" Creating plot for {args.color}...")
259 # Plot 1: Known labels 290 fig = plot_dim_reduced(
260 print(f" Creating known labels plot for {var}...") 291 matrix=embeddings,
261 fig_known = plot_dim_reduced( 292 labels=matched_labels[args.color],
262 matrix=embeddings, 293 method=args.method,
263 labels=var_labels['known_label'], 294 color_type=color_type
264 method=args.method, 295 )
265 color_type=known_color_type 296
266 ) 297 output_path = output_dir / f"{output_name_base}_{args.color}.{args.format}"
267 298 print(f" Saving plot to: {output_path.name}")
268 output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}" 299 fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
269 print(f" Saving known labels plot to: {output_path_known.name}") 300
270 fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight') 301 print(f" ✓ Successfully created plot for variable '{args.color}'")
271 302
272 # Plot 2: Predicted labels 303
273 print(f" Creating predicted labels plot for {var}...") 304 def generate_km_plots(survival_data, labels, args, output_dir, output_name_base):
274 fig_predicted = plot_dim_reduced(
275 matrix=embeddings,
276 labels=var_labels['predicted_label'],
277 method=args.method,
278 color_type=predicted_color_type
279 )
280
281 output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}"
282 print(f" Saving predicted labels plot to: {output_path_predicted.name}")
283 fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight')
284
285 print(f" ✓ Successfully created plots for variable '{var}'")
286
287 except Exception as e:
288 print(f" ✗ Error creating plots for variable '{var}': {e}")
289 continue
290
291 print(f"\nDimensionality reduction plots completed for {len(valid_vars)} variable(s)!")
292
293
294 def generate_km_plots(survival_data, label_data, args, output_dir, output_name_base):
295 """Generate Kaplan-Meier plots""" 305 """Generate Kaplan-Meier plots"""
306
307 # Check if this is the specific format with sample_id, known_label, predicted_label
308 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
309 is_flexynesis_format = all(col in labels.columns for col in required_cols)
310
311 if not is_flexynesis_format:
312 raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.")
313
296 print("Generating Kaplan-Meier curves of risk subtypes...") 314 print("Generating Kaplan-Meier curves of risk subtypes...")
297 315
298 # Reset index and rename the index column to sample_id
299 survival_data = survival_data.reset_index()
300 if survival_data.columns[0] != 'sample_id': 316 if survival_data.columns[0] != 'sample_id':
301 survival_data = survival_data.rename(columns={survival_data.columns[0]: 'sample_id'}) 317 survival_data = survival_data.rename(columns={survival_data.columns[0]: 'sample_id'})
302 318
303 # Convert survival event column to binary (0/1) based on event_value
304 # Check if the event column exists 319 # Check if the event column exists
305 if args.surv_event_var not in survival_data.columns: 320 if args.surv_event_var not in survival_data.columns:
306 raise ValueError(f"Column '{args.surv_event_var}' not found in survival data") 321 raise ValueError(f"Column '{args.surv_event_var}' not found in survival data")
307 322
308 # Convert to string for comparison to handle mixed types 323 labels = labels[(labels['variable'] == args.surv_event_var)]
309 survival_data[args.surv_event_var] = survival_data[args.surv_event_var].astype(str)
310 event_value_str = str(args.event_value)
311
312 # Create binary event column (1 if matches event_value, 0 otherwise)
313 survival_data[f'{args.surv_event_var}_binary'] = (
314 survival_data[args.surv_event_var] == event_value_str
315 ).astype(int)
316
317 # Filter for survival category and class_label == '1:DECEASED'
318 label_data['class_label'] = label_data['class_label'].astype(str)
319
320 label_data = label_data[(label_data['variable'] == args.surv_event_var) & (label_data['class_label'] == event_value_str)]
321
322 # check survival data
323 for col in [args.surv_time_var, args.surv_event_var]:
324 if col not in survival_data.columns:
325 raise ValueError(f"Column '{col}' not found in survival data")
326 324
327 # Merge survival data with labels 325 # Merge survival data with labels
328 df_deceased = pd.merge(survival_data, label_data, on='sample_id', how='inner') 326 df_deceased = pd.merge(survival_data, labels, on='sample_id', how='inner')
327 df_deceased = df_deceased.dropna(subset=[args.surv_time_var, args.surv_event_var])
329 328
330 if df_deceased.empty: 329 if df_deceased.empty:
331 raise ValueError("No matching samples found after merging survival and label data.") 330 raise ValueError("No matching samples found after merging survival and label data.")
332 331
333 # Get risk scores 332 # Get risk scores
334 risk_scores = df_deceased['probability'].values 333 risk_scores = df_deceased['predicted_label'].values
335 334
336 # Compute groups (e.g., median split) 335 # Compute groups (e.g., median split)
337 quantiles = np.quantile(risk_scores, [0.5]) 336 quantiles = np.quantile(risk_scores, [0.5])
338 groups = np.digitize(risk_scores, quantiles) 337 groups = np.digitize(risk_scores, quantiles)
339 group_labels = ['low_risk' if g == 0 else 'high_risk' for g in groups] 338 group_labels = ['low_risk' if g == 0 else 'high_risk' for g in groups]
340 339
341 fig_known = plot_kaplan_meier_curves( 340 fig_known = plot_kaplan_meier_curves(
342 durations=df_deceased[args.surv_time_var], 341 durations=df_deceased[args.surv_time_var],
343 events=df_deceased[f'{args.surv_event_var}_binary'], 342 events=df_deceased[args.surv_event_var],
344 categorical_variable=group_labels 343 categorical_variable=group_labels
345 ) 344 )
346 345
347 output_path_known = output_dir / f"{output_name_base}_km_risk_subtypes.{args.format}" 346 output_path_known = output_dir / f"{output_name_base}_km_risk_subtypes.{args.format}"
348 print(f"Saving Kaplan-Meier plot to: {output_path_known.absolute()}") 347 print(f"Saving Kaplan-Meier plot to: {output_path_known.absolute()}")
349 fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight') 348 fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')
350 349
351 print("Kaplan-Meier plot saved successfully!") 350 print("Kaplan-Meier plot saved successfully!")
352 351
353 352
354 def generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base): 353 def generate_cox_plots(important_features, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base):
355 """Generate Cox proportional hazards plots""" 354 """Generate Cox proportional hazards plots"""
356 print("Generating Cox proportional hazards analysis...") 355 print("Generating Cox proportional hazards analysis...")
357 356
357 # Check if this is the specific format with target_variable, importance
358 required_cols = ['target_variable', 'layer', 'importance']
359 is_flexynesis_format = all(col in important_features.columns for col in required_cols)
360
361 if not is_flexynesis_format:
362 raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid important_features file with the required columns, {required_cols}.")
363
358 # Parse clinical variables 364 # Parse clinical variables
359 clinical_vars = [var.strip() for var in args.clinical_variables.split(',')] 365 clinical_vars = []
366 if args.clinical_variables:
367 clinical_vars = [var.strip() for var in args.clinical_variables.split(',')]
360 368
361 # Validate that survival variables are included 369 # Validate that survival variables are included
362 required_vars = [args.surv_time_var, args.surv_event_var] 370 required_vars = [args.surv_time_var, args.surv_event_var]
363 for var in required_vars: 371 for var in required_vars:
364 if var not in clinical_vars: 372 if var not in clinical_vars:
380 df_clin = pd.concat([df_clin_train, df_clin_test], axis=0) 388 df_clin = pd.concat([df_clin_train, df_clin_test], axis=0)
381 389
382 # Get top survival markers 390 # Get top survival markers
383 print(f"Extracting top {args.top_features} important features for {args.surv_event_var}...") 391 print(f"Extracting top {args.top_features} important features for {args.surv_event_var}...")
384 try: 392 try:
385 imp = get_important_features(model, 393 print(f"Loading {args.top_features} important features from: {args.important_features}")
386 var=args.surv_event_var, 394 imp_features = load_labels(args.important_features)
387 top=args.top_features 395 imp_features = imp_features[imp_features['target_variable'] == args.surv_event_var]
388 )['name'].unique().tolist() 396 if args.layer not in imp_features['layer'].unique():
397 print(f"Available class labels: {imp_features['layer'].unique()}")
398 raise ValueError(f"Class label '{args.layer}' not found in important features data: {args.important_features}")
399 imp_features = imp_features[imp_features['layer'] == args.layer]
400 if imp_features.empty:
401 raise ValueError(f"No important features found for target variable '{args.surv_event_var}' in {args.important_features}")
402 imp_features = imp_features.sort_values(by='importance', ascending=False)
403
404 if len(imp_features) < args.top_features:
405 raise ValueError(f"Requested top {args.top_features} features, but only {len(imp_features)} available in {args.important_features}")
406
407 imp = imp_features['name'].unique().tolist()[0:args.top_features]
408
389 print(f"Top features: {', '.join(imp)}") 409 print(f"Top features: {', '.join(imp)}")
390 except Exception as e: 410 except Exception as e:
391 raise ValueError(f"Error getting important features: {e}") 411 raise ValueError(f"Error getting important features: {e}")
392 412
393 # Extract feature data from omics datasets 413 # Extract feature data from omics datasets
415 final_samples = len(df) 435 final_samples = len(df)
416 print(f"Removed {initial_samples - final_samples} samples without survival data") 436 print(f"Removed {initial_samples - final_samples} samples without survival data")
417 437
418 if df.empty: 438 if df.empty:
419 raise ValueError("No samples remain after filtering for survival data") 439 raise ValueError("No samples remain after filtering for survival data")
420
421 # Convert survival event column to binary (0/1) based on event_value
422 # Convert to string for comparison to handle mixed types
423 df[args.surv_event_var] = df[args.surv_event_var].astype(str)
424 event_value_str = str(args.event_value)
425
426 df[f'{args.surv_event_var}'] = (
427 df[args.surv_event_var] == event_value_str
428 ).astype(int)
429 440
430 # Build Cox model 441 # Build Cox model
431 print(f"Building Cox model with time variable: {args.surv_time_var}, event variable: {args.surv_event_var}") 442 print(f"Building Cox model with time variable: {args.surv_time_var}, event variable: {args.surv_event_var}")
432 try: 443 try:
433 coxm = build_cox_model(df, 444 coxm = build_cox_model(df,
457 468
458 def generate_plot_scatter(labels, args, output_dir, output_name_base): 469 def generate_plot_scatter(labels, args, output_dir, output_name_base):
459 """Generate scatter plot of known vs predicted labels""" 470 """Generate scatter plot of known vs predicted labels"""
460 print("Generating scatter plots of known vs predicted labels...") 471 print("Generating scatter plots of known vs predicted labels...")
461 472
473 # Check if this is the specific format with sample_id, known_label, predicted_label
474 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
475 is_flexynesis_format = all(col in labels.columns for col in required_cols)
476
477 if is_flexynesis_format:
478 # Parse target values from comma-separated string
479 if args.target_value:
480 target_values = [val.strip() for val in args.target_value.split(',')]
481 else:
482 # If no target values specified, use all unique variables
483 target_values = labels['variable'].unique().tolist()
484
485 print(f"Processing target values: {target_values}")
486
487 successful_plots = 0
488 skipped_plots = 0
489
490 for target_value in target_values:
491 print(f"\nProcessing target value: '{target_value}'")
492
493 # Filter labels for the current target value
494 target_labels = labels[labels['variable'] == target_value]
495
496 if target_labels.empty:
497 print(f" Warning: No data found for target value '{target_value}' - skipping")
498 skipped_plots += 1
499 continue
500
501 # Check if labels are numeric and convert
502 true_values = pd.to_numeric(target_labels['known_label'], errors='coerce')
503 predicted_values = pd.to_numeric(target_labels['predicted_label'], errors='coerce')
504
505 if true_values.isna().all() or predicted_values.isna().all():
506 print(f"No valid numeric values found for known or predicted labels in '{target_value}'")
507 skipped_plots += 1
508 continue
509
510 try:
511 print(f" Generating scatter plot for '{target_value}'...")
512 fig = plot_scatter(true_values, predicted_values)
513
514 # Create output filename with target value
515 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
516 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
517
518 output_path = output_dir / output_filename
519 print(f" Saving scatter plot to: {output_path.absolute()}")
520 fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
521
522 successful_plots += 1
523 print(f" Scatter plot for '{target_value}' generated successfully!")
524
525 except Exception as e:
526 print(f" Error generating plot for '{target_value}': {str(e)}")
527 skipped_plots += 1
528
529 # Summary
530 print(" Summary:")
531 print(f" Successfully generated: {successful_plots} plots")
532 print(f" Skipped: {skipped_plots} plots")
533
534 if successful_plots == 0:
535 raise ValueError("No scatter plots could be generated. Check your data and target values.")
536
537 print("Scatter plot generation completed!")
538
539 if not is_flexynesis_format:
540 print("Labels are not in flexynesis format (Custom labels)")
541
542 if not args.true_label or not args.predicted_label:
543 raise ValueError("For custom labels, please specify --true_label and --predicted_label arguments.")
544
545 # Check if labels are numeric and convert
546 true_values = pd.to_numeric(labels[args.true_label], errors='coerce')
547 predicted_values = pd.to_numeric(labels[args.predicted_label], errors='coerce')
548
549 if true_values.isna().all() or predicted_values.isna().all():
550 print("No valid numeric values found for known or predicted labels")
551
552 try:
553 print(" Generating scatter plot...")
554 fig = plot_scatter(true_values, predicted_values)
555
556 # Create output filename with target value
557 output_filename = f"{output_name_base}.{args.format}"
558
559 output_path = output_dir / output_filename
560 print(f" Saving scatter plot to: {output_path.absolute()}")
561 fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
562
563 except Exception as e:
564 print(f" Error generating plot: {str(e)}")
565
566 print("Scatter plot generation completed!")
567
568
569 def generate_label_concordance_heatmap(labels, args, output_dir, output_name_base):
570 """Generate label concordance heatmap"""
571 print("Generating label concordance heatmaps...")
572
573 # Check if this is the specific format with sample_id, known_label, predicted_label
574 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
575 is_flexynesis_format = all(col in labels.columns for col in required_cols)
576
577 if is_flexynesis_format:
578 # Parse target values from comma-separated string
579 if args.target_value:
580 target_values = [val.strip() for val in args.target_value.split(',')]
581 else:
582 # If no target values specified, use all unique variables
583 target_values = labels['variable'].unique().tolist()
584
585 print(f"Processing target values: {target_values}")
586
587 for target_value in target_values:
588 print(f"\nProcessing target value: '{target_value}'")
589
590 # Filter labels for the current target value
591 target_labels = labels[labels['variable'] == target_value]
592
593 if target_labels.empty:
594 print(f" Warning: No data found for target value '{target_value}' - skipping")
595 continue
596
597 true_values = target_labels['known_label'].tolist()
598 predicted_values = target_labels['predicted_label'].tolist()
599
600 try:
601 print(f" Generating heatmap for '{target_value}'...")
602 fig = plot_label_concordance_heatmap(true_values, predicted_values)
603 plt.close(fig)
604
605 # Create output filename with target value
606 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
607 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
608
609 output_path = output_dir / output_filename
610 print(f" Saving heatmap to: {output_path.absolute()}")
611 fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')
612
613 except Exception as e:
614 print(f" Error generating heatmap for '{target_value}': {str(e)}")
615 continue
616
617 print("Label concordance heatmap generated successfully!")
618
619 if not is_flexynesis_format:
620 print("Labels are not in flexynesis format (Custom labels)")
621
622 if not args.true_label or not args.predicted_label:
623 raise ValueError("For custom labels, please specify --true_label and --predicted_label arguments.")
624
625 true_values = labels[args.true_label].tolist()
626 predicted_values = labels[args.predicted_label].tolist()
627
628 try:
629 print(" Generating heatmap for...")
630 fig = plot_label_concordance_heatmap(true_values, predicted_values)
631 plt.close(fig)
632
633 # Create output filename with target value
634 output_filename = f"{output_name_base}.{args.format}"
635
636 output_path = output_dir / output_filename
637 print(f" Saving heatmap to: {output_path.absolute()}")
638 fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')
639
640 except Exception as e:
641 print(f" Error generating heatmap': {str(e)}")
642
643 print("Label concordance heatmap generated successfully!")
644
645
646 def generate_pr_curves(labels, args, output_dir, output_name_base):
647 """Generate precision-recall curves"""
648 print("Generating precision-recall curves...")
649
650 # Check if this is the specific format with sample_id, known_label, predicted_label
651 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
652 is_flexynesis_format = all(col in labels.columns for col in required_cols)
653
654 if not is_flexynesis_format:
655 raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.")
656
462 # Parse target values from comma-separated string 657 # Parse target values from comma-separated string
463 if args.target_value: 658 if args.target_value:
464 target_values = [val.strip() for val in args.target_value.split(',')] 659 target_values = [val.strip() for val in args.target_value.split(',')]
465 else: 660 else:
466 # If no target values specified, use all unique variables 661 # If no target values specified, use all unique variables
467 target_values = labels['variable'].unique().tolist() 662 target_values = labels['variable'].unique().tolist()
468 663
469 print(f"Processing target values: {target_values}") 664 print(f"Processing target values: {target_values}")
470 665
471 successful_plots = 0
472 skipped_plots = 0
473
474 for target_value in target_values: 666 for target_value in target_values:
475 print(f"\nProcessing target value: '{target_value}'") 667 print(f"\nProcessing target value: '{target_value}'")
476 668
477 # Filter labels for the current target value 669 # Filter labels for the current target value
478 target_labels = labels[labels['variable'] == target_value] 670 target_labels = labels[labels['variable'] == target_value]
479 671
480 if target_labels.empty: 672 # Check if this is a regression problem (no class probabilities)
481 print(f" Warning: No data found for target value '{target_value}' - skipping") 673 prob_columns = target_labels['class_label'].unique()
482 skipped_plots += 1 674 non_na_probs = target_labels['probability'].notna().sum()
483 continue 675
484 676 print(f" Class labels found: {list(prob_columns)}")
485 # Check if labels are numeric and convert 677 print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")
486 true_values = pd.to_numeric(target_labels['known_label'], errors='coerce') 678
487 predicted_values = pd.to_numeric(target_labels['predicted_label'], errors='coerce') 679 # If most probabilities are NaN, this is likely a regression problem
488 680 if non_na_probs < len(target_labels) * 0.1: # Less than 10% valid probabilities
489 if true_values.isna().all() or predicted_values.isna().all(): 681 print(" Detected regression problem - precision-recall curves not applicable")
490 print(f"No valid numeric values found for known or predicted labels in '{target_value}'") 682 print(f" Skipping '{target_value}' (use regression evaluation metrics instead)")
491 skipped_plots += 1 683 continue
684
685 # Debug: Check data quality
686 total_rows = len(target_labels)
687 missing_labels = target_labels['known_label'].isna().sum()
688 missing_probs = target_labels['probability'].isna().sum()
689 unique_samples = target_labels['sample_id'].nunique()
690 unique_classes = target_labels['class_label'].nunique()
691
692 print(f" Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
693 print(f" Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")
694
695 if missing_labels > 0:
696 print(f" Warning: Found {missing_labels} missing known_label values")
697 missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
698 print(f" Sample IDs with missing known_label: {list(missing_samples)}")
699
700 # Remove rows with missing known_label
701 target_labels = target_labels.dropna(subset=['known_label'])
702 if target_labels.empty:
703 print(f" Error: No valid known_label data remaining for '{target_value}' - skipping")
704 continue
705
706 # 1. Pivot to wide format
707 prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability')
708
709 print(f" After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes")
710 print(f" Class columns: {list(prob_df.columns)}")
711
712 # Check for NaN values in probability data
713 nan_counts = prob_df.isna().sum()
714 if nan_counts.any():
715 print(f" NaN counts per class: {dict(nan_counts)}")
716 print(f" Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}")
717
718 # Drop only rows where ALL probabilities are NaN
719 all_nan_rows = prob_df.isna().all(axis=1)
720 if all_nan_rows.any():
721 print(f" Dropping {all_nan_rows.sum()} samples with all NaN probabilities")
722 prob_df = prob_df[~all_nan_rows]
723
724 remaining_nans = prob_df.isna().sum().sum()
725 if remaining_nans > 0:
726 print(f" Warning: {remaining_nans} individual NaN values remain - filling with 0")
727 prob_df = prob_df.fillna(0)
728
729 if prob_df.empty:
730 print(f" Error: No valid probability data remaining for '{target_value}' - skipping")
731 continue
732
733 # 2. Get true labels
734 true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id')
735
736 # 3. Align indices - only keep samples that exist in both datasets
737 common_indices = prob_df.index.intersection(true_labels_df.index)
738 if len(common_indices) == 0:
739 print(f" Error: No common sample_ids between probability and true label data for '{target_value}' - skipping")
740 continue
741
742 print(f" Found {len(common_indices)} samples with both probability and true label data")
743
744 # Filter both datasets to common indices
745 prob_df_aligned = prob_df.loc[common_indices]
746 y_true = true_labels_df.loc[common_indices]['known_label']
747
748 # 4. Final check for NaN values
749 if y_true.isna().any():
750 print(f" Error: True labels still contain NaN after alignment for '{target_value}' - skipping")
751 continue
752
753 if prob_df_aligned.isna().any().any():
754 print(f" Error: Probability data still contains NaN after alignment for '{target_value}' - skipping")
755 continue
756
757 # 5. Convert categorical labels to integer labels
758 # Create a mapping from class names to integers
759 class_names = list(prob_df_aligned.columns)
760 class_to_int = {class_name: i for i, class_name in enumerate(class_names)}
761
762 print(f" Class mapping: {class_to_int}")
763
764 # Convert true labels to integers
765 y_true_np = y_true.map(class_to_int).to_numpy()
766 y_probs_np = prob_df_aligned.to_numpy()
767
768 print(f" Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}")
769 print(f" Unique true labels (integers): {set(y_true_np)}")
770 print(f" Class labels (columns): {class_names}")
771 print(f" Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}")
772
773 # Check for any unmapped labels (will be NaN)
774 if pd.isna(y_true_np).any():
775 print(" Error: Some true labels could not be mapped to class columns")
776 unmapped_labels = set(y_true[y_true.map(class_to_int).isna()])
777 print(f" Unmapped labels: {unmapped_labels}")
778 print(f" Available classes: {class_names}")
492 continue 779 continue
493 780
494 try: 781 try:
495 print(f" Generating scatter plot for '{target_value}'...") 782 print(f" Generating precision-recall curve for '{target_value}'...")
496 fig = plot_scatter(true_values, predicted_values) 783 fig = plot_pr_curves(y_true_np, y_probs_np)
497 784
498 # Create output filename with target value 785 # Create output filename with target value
499 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') 786 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
500 if len(target_values) > 1: 787 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
501 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
502 else:
503 output_filename = f"{output_name_base}.{args.format}"
504 788
505 output_path = output_dir / output_filename 789 output_path = output_dir / output_filename
506 print(f" Saving scatter plot to: {output_path.absolute()}") 790 print(f" Saving precision-recall curve to: {output_path.absolute()}")
507 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') 791 fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
508 792
509 successful_plots += 1
510 print(f" Scatter plot for '{target_value}' generated successfully!")
511
512 except Exception as e: 793 except Exception as e:
513 print(f" Error generating plot for '{target_value}': {str(e)}") 794 print(f" Error generating precision-recall curve for '{target_value}': {str(e)}")
514 skipped_plots += 1 795 print(f" Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}")
515 796 print(f" Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}")
516 # Summary 797 continue
517 print(" Summary:") 798
518 print(f" Successfully generated: {successful_plots} plots") 799 print("Precision-recall curves generated successfully!")
519 print(f" Skipped: {skipped_plots} plots") 800
520 801
521 if successful_plots == 0: 802 def generate_roc_curves(labels, args, output_dir, output_name_base):
522 raise ValueError("No scatter plots could be generated. Check your data and target values.") 803 """Generate ROC curves"""
523 804 print("Generating ROC curves...")
524 print("Scatter plot generation completed!") 805
525 806 # Check if this is the specific format with sample_id, known_label, predicted_label
526 807 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
527 def generate_label_concordance_heatmap(labels, args, output_dir, output_name_base): 808 is_flexynesis_format = all(col in labels.columns for col in required_cols)
528 """Generate label concordance heatmap""" 809
529 print("Generating label concordance heatmaps...") 810 if not is_flexynesis_format:
811 raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.")
530 812
531 # Parse target values from comma-separated string 813 # Parse target values from comma-separated string
532 if args.target_value: 814 if args.target_value:
533 target_values = [val.strip() for val in args.target_value.split(',')] 815 target_values = [val.strip() for val in args.target_value.split(',')]
534 else: 816 else:
541 print(f"\nProcessing target value: '{target_value}'") 823 print(f"\nProcessing target value: '{target_value}'")
542 824
543 # Filter labels for the current target value 825 # Filter labels for the current target value
544 target_labels = labels[labels['variable'] == target_value] 826 target_labels = labels[labels['variable'] == target_value]
545 827
546 if target_labels.empty:
547 print(f" Warning: No data found for target value '{target_value}' - skipping")
548 continue
549
550 true_values = target_labels['known_label'].tolist()
551 predicted_values = target_labels['predicted_label'].tolist()
552
553 try:
554 print(f" Generating heatmap for '{target_value}'...")
555 fig = plot_label_concordance_heatmap(true_values, predicted_values)
556 plt.close(fig)
557
558 # Create output filename with target value
559 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
560 if len(target_values) > 1:
561 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
562 else:
563 output_filename = f"{output_name_base}.{args.format}"
564
565 output_path = output_dir / output_filename
566 print(f" Saving heatmap to: {output_path.absolute()}")
567 fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')
568
569 except Exception as e:
570 print(f" Error generating heatmap for '{target_value}': {str(e)}")
571 continue
572
573 print("Label concordance heatmap generated successfully!")
574
575
576 def generate_pr_curves(labels, args, output_dir, output_name_base):
577 """Generate precision-recall curves"""
578 print("Generating precision-recall curves...")
579
580 # Parse target values from comma-separated string
581 if args.target_value:
582 target_values = [val.strip() for val in args.target_value.split(',')]
583 else:
584 # If no target values specified, use all unique variables
585 target_values = labels['variable'].unique().tolist()
586
587 print(f"Processing target values: {target_values}")
588
589 for target_value in target_values:
590 print(f"\nProcessing target value: '{target_value}'")
591
592 # Filter labels for the current target value
593 target_labels = labels[labels['variable'] == target_value]
594
595 # Check if this is a regression problem (no class probabilities) 828 # Check if this is a regression problem (no class probabilities)
596 prob_columns = target_labels['class_label'].unique() 829 prob_columns = target_labels['class_label'].unique()
597 non_na_probs = target_labels['probability'].notna().sum() 830 non_na_probs = target_labels['probability'].notna().sum()
598 831
599 print(f" Class labels found: {list(prob_columns)}") 832 print(f" Class labels found: {list(prob_columns)}")
600 print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}") 833 print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")
601 834
602 # If most probabilities are NaN, this is likely a regression problem 835 # If most probabilities are NaN, this is likely a regression problem
603 if non_na_probs < len(target_labels) * 0.1: # Less than 10% valid probabilities 836 if non_na_probs < len(target_labels) * 0.1: # Less than 10% valid probabilities
604 print(" Detected regression problem - precision-recall curves not applicable") 837 print(" Detected regression problem - ROC curves not applicable")
605 print(f" Skipping '{target_value}' (use regression evaluation metrics instead)") 838 print(f" Skipping '{target_value}' (use regression evaluation metrics instead)")
606 continue 839 continue
607 840
608 # Debug: Check data quality 841 # Debug: Check data quality
609 total_rows = len(target_labels) 842 total_rows = len(target_labels)
700 print(f" Unmapped labels: {unmapped_labels}") 933 print(f" Unmapped labels: {unmapped_labels}")
701 print(f" Available classes: {class_names}") 934 print(f" Available classes: {class_names}")
702 continue 935 continue
703 936
704 try: 937 try:
705 print(f" Generating precision-recall curve for '{target_value}'...") 938 print(f" Generating ROC curve for '{target_value}'...")
706 fig = plot_pr_curves(y_true_np, y_probs_np) 939 fig = plot_roc_curves(y_true_np, y_probs_np)
707 940
708 # Create output filename with target value 941 # Create output filename with target value
709 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') 942 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
710 if len(target_values) > 1: 943 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
711 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
712 else:
713 output_filename = f"{output_name_base}.{args.format}"
714
715 output_path = output_dir / output_filename
716 print(f" Saving precision-recall curve to: {output_path.absolute()}")
717 fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
718
719 except Exception as e:
720 print(f" Error generating precision-recall curve for '{target_value}': {str(e)}")
721 print(f" Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}")
722 print(f" Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}")
723 continue
724
725 print("Precision-recall curves generated successfully!")
726
727
728 def generate_roc_curves(labels, args, output_dir, output_name_base):
729 """Generate ROC curves"""
730 print("Generating ROC curves...")
731
732 # Parse target values from comma-separated string
733 if args.target_value:
734 target_values = [val.strip() for val in args.target_value.split(',')]
735 else:
736 # If no target values specified, use all unique variables
737 target_values = labels['variable'].unique().tolist()
738
739 print(f"Processing target values: {target_values}")
740
741 for target_value in target_values:
742 print(f"\nProcessing target value: '{target_value}'")
743
744 # Filter labels for the current target value
745 target_labels = labels[labels['variable'] == target_value]
746
747 # Check if this is a regression problem (no class probabilities)
748 prob_columns = target_labels['class_label'].unique()
749 non_na_probs = target_labels['probability'].notna().sum()
750
751 print(f" Class labels found: {list(prob_columns)}")
752 print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")
753
754 # If most probabilities are NaN, this is likely a regression problem
755 if non_na_probs < len(target_labels) * 0.1: # Less than 10% valid probabilities
756 print(" Detected regression problem - ROC curves not applicable")
757 print(f" Skipping '{target_value}' (use regression evaluation metrics instead)")
758 continue
759
760 # Debug: Check data quality
761 total_rows = len(target_labels)
762 missing_labels = target_labels['known_label'].isna().sum()
763 missing_probs = target_labels['probability'].isna().sum()
764 unique_samples = target_labels['sample_id'].nunique()
765 unique_classes = target_labels['class_label'].nunique()
766
767 print(f" Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
768 print(f" Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")
769
770 if missing_labels > 0:
771 print(f" Warning: Found {missing_labels} missing known_label values")
772 missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
773 print(f" Sample IDs with missing known_label: {list(missing_samples)}")
774
775 # Remove rows with missing known_label
776 target_labels = target_labels.dropna(subset=['known_label'])
777 if target_labels.empty:
778 print(f" Error: No valid known_label data remaining for '{target_value}' - skipping")
779 continue
780
781 # 1. Pivot to wide format
782 prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability')
783
784 print(f" After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes")
785 print(f" Class columns: {list(prob_df.columns)}")
786
787 # Check for NaN values in probability data
788 nan_counts = prob_df.isna().sum()
789 if nan_counts.any():
790 print(f" NaN counts per class: {dict(nan_counts)}")
791 print(f" Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}")
792
793 # Drop only rows where ALL probabilities are NaN
794 all_nan_rows = prob_df.isna().all(axis=1)
795 if all_nan_rows.any():
796 print(f" Dropping {all_nan_rows.sum()} samples with all NaN probabilities")
797 prob_df = prob_df[~all_nan_rows]
798
799 remaining_nans = prob_df.isna().sum().sum()
800 if remaining_nans > 0:
801 print(f" Warning: {remaining_nans} individual NaN values remain - filling with 0")
802 prob_df = prob_df.fillna(0)
803
804 if prob_df.empty:
805 print(f" Error: No valid probability data remaining for '{target_value}' - skipping")
806 continue
807
808 # 2. Get true labels
809 true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id')
810
811 # 3. Align indices - only keep samples that exist in both datasets
812 common_indices = prob_df.index.intersection(true_labels_df.index)
813 if len(common_indices) == 0:
814 print(f" Error: No common sample_ids between probability and true label data for '{target_value}' - skipping")
815 continue
816
817 print(f" Found {len(common_indices)} samples with both probability and true label data")
818
819 # Filter both datasets to common indices
820 prob_df_aligned = prob_df.loc[common_indices]
821 y_true = true_labels_df.loc[common_indices]['known_label']
822
823 # 4. Final check for NaN values
824 if y_true.isna().any():
825 print(f" Error: True labels still contain NaN after alignment for '{target_value}' - skipping")
826 continue
827
828 if prob_df_aligned.isna().any().any():
829 print(f" Error: Probability data still contains NaN after alignment for '{target_value}' - skipping")
830 continue
831
832 # 5. Convert categorical labels to integer labels
833 # Create a mapping from class names to integers
834 class_names = list(prob_df_aligned.columns)
835 class_to_int = {class_name: i for i, class_name in enumerate(class_names)}
836
837 print(f" Class mapping: {class_to_int}")
838
839 # Convert true labels to integers
840 y_true_np = y_true.map(class_to_int).to_numpy()
841 y_probs_np = prob_df_aligned.to_numpy()
842
843 print(f" Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}")
844 print(f" Unique true labels (integers): {set(y_true_np)}")
845 print(f" Class labels (columns): {class_names}")
846 print(f" Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}")
847
848 # Check for any unmapped labels (will be NaN)
849 if pd.isna(y_true_np).any():
850 print(" Error: Some true labels could not be mapped to class columns")
851 unmapped_labels = set(y_true[y_true.map(class_to_int).isna()])
852 print(f" Unmapped labels: {unmapped_labels}")
853 print(f" Available classes: {class_names}")
854 continue
855
856 try:
857 print(f" Generating ROC curve for '{target_value}'...")
858 fig = plot_roc_curves(y_true_np, y_probs_np)
859
860 # Create output filename with target value
861 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
862 if len(target_values) > 1:
863 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
864 else:
865 output_filename = f"{output_name_base}.{args.format}"
866 944
867 output_path = output_dir / output_filename 945 output_path = output_dir / output_filename
868 print(f" Saving ROC curve to: {output_path.absolute()}") 946 print(f" Saving ROC curve to: {output_path.absolute()}")
869 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') 947 fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
870 948
877 print("ROC curves generated successfully!") 955 print("ROC curves generated successfully!")
878 956
879 957
880 def generate_box_plots(labels, args, output_dir, output_name_base): 958 def generate_box_plots(labels, args, output_dir, output_name_base):
881 """Generate box plots for model predictions""" 959 """Generate box plots for model predictions"""
960
961 # Check if this is the specific format with sample_id, known_label, predicted_label
962 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
963 is_flexynesis_format = all(col in labels.columns for col in required_cols)
964
965 if not is_flexynesis_format:
966 raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.")
882 967
883 print("Generating box plots...") 968 print("Generating box plots...")
884 969
885 # Parse target values from comma-separated string 970 # Parse target values from comma-separated string
886 if args.target_value: 971 if args.target_value:
991 # Arguments for dimensionality reduction 1076 # Arguments for dimensionality reduction
992 parser.add_argument("--embeddings", type=str, 1077 parser.add_argument("--embeddings", type=str,
993 help="Path to input data embeddings file (CSV or tabular format). Required for dimred plots.") 1078 help="Path to input data embeddings file (CSV or tabular format). Required for dimred plots.")
994 parser.add_argument("--method", type=str, default='pca', choices=['pca', 'umap'], 1079 parser.add_argument("--method", type=str, default='pca', choices=['pca', 'umap'],
995 help="Transformation method ('pca' or 'umap'). Default is 'pca'. Used for dimred plots.") 1080 help="Transformation method ('pca' or 'umap'). Default is 'pca'. Used for dimred plots.")
1081 parser.add_argument("--color", type=str, default=None,
1082 help="User-defined color for the plot.")
996 1083
997 # Arguments for Kaplan-Meier 1084 # Arguments for Kaplan-Meier
998 parser.add_argument("--survival_data", type=str, 1085 parser.add_argument("--survival_data", type=str,
999 help="Path to survival data file with columns: duration and event. Required for kaplan_meier plots.") 1086 help="Path to survival data file with columns: duration and event. Required for kaplan_meier plots.")
1000 parser.add_argument("--surv_time_var", type=str, required=False, 1087 parser.add_argument("--surv_time_var", type=str, required=False,
1001 help="Column name for survival time") 1088 help="Column name for survival time")
1002 parser.add_argument("--surv_event_var", type=str, required=False, 1089 parser.add_argument("--surv_event_var", type=str, required=False,
1003 help="Column name for survival event") 1090 help="Column name for survival event")
1004 parser.add_argument("--event_value", type=str, required=False,
1005 help="Value in event column that represents an event (e.g., 'DECEASED')")
1006 1091
1007 # Arguments for Cox analysis 1092 # Arguments for Cox analysis
1008 parser.add_argument("--model", type=str, 1093 parser.add_argument("--important_features", type=str,
1009 help="Path to trained flexynesis model (pickle file). Required for cox plots.") 1094 help="Path to calculated feature importance file. Required for cox plots.")
1010 parser.add_argument("--clinical_train", type=str, 1095 parser.add_argument("--clinical_train", type=str,
1011 help="Path to training dataset (pickle file). Required for cox plots.") 1096 help="Path to training dataset (pickle file). Required for cox plots.")
1012 parser.add_argument("--clinical_test", type=str, 1097 parser.add_argument("--clinical_test", type=str,
1013 help="Path to test dataset (pickle file). Required for cox plots.") 1098 help="Path to test dataset (pickle file). Required for cox plots.")
1014 parser.add_argument("--omics_train", type=str, default=None, 1099 parser.add_argument("--omics_train", type=str, default=None,
1023 help="If True, performs K-fold cross-validation and returns average C-index. Default is False") 1108 help="If True, performs K-fold cross-validation and returns average C-index. Default is False")
1024 parser.add_argument("--n_splits", type=int, default=5, 1109 parser.add_argument("--n_splits", type=int, default=5,
1025 help="Number of folds for cross-validation. Default is 5") 1110 help="Number of folds for cross-validation. Default is 5")
1026 parser.add_argument("--random_state", type=int, default=42, 1111 parser.add_argument("--random_state", type=int, default=42,
1027 help="Random seed for reproducibility. Default is 42") 1112 help="Random seed for reproducibility. Default is 42")
1113 parser.add_argument("--layer", type=str, default=None,
1114 help="Class label for filtering important features.")
1028 1115
1029 # Arguments for dimred, scatter plot, heatmap, PR curves, ROC curves, and box plots 1116 # Arguments for dimred, scatter plot, heatmap, PR curves, ROC curves, and box plots
1030 parser.add_argument("--target_value", type=str, default=None, 1117 parser.add_argument("--target_value", type=str, default=None,
1031 help="Target value for scatter plot.") 1118 help="Target value for scatter plot.")
1032 1119
1120 # Arguments for scatter plots and concordance heatmaps
1121 parser.add_argument("--true_label", type=str, default=None,
1122 help="Column name for true labels in scatter plots and concordance heatmaps.")
1123 parser.add_argument("--predicted_label", type=str, default=None,
1124 help="Column name for predicted labels in scatter plots and concordance heatmaps.")
1033 # Common arguments 1125 # Common arguments
1034 parser.add_argument("--output_dir", type=str, default='output', 1126 parser.add_argument("--output_dir", type=str, default='output',
1035 help="Output directory. Default is 'output'") 1127 help="Output directory. Default is 'output'")
1036 parser.add_argument("--output_name", type=str, default=None, 1128 parser.add_argument("--output_name", type=str, default=None,
1037 help="Output filename base") 1129 help="Output filename base")
1071 raise ValueError("--method is required for dimensionality reduction plots") 1163 raise ValueError("--method is required for dimensionality reduction plots")
1072 if not args.surv_time_var: 1164 if not args.surv_time_var:
1073 raise ValueError("--surv_time_var is required for Kaplan-Meier plots") 1165 raise ValueError("--surv_time_var is required for Kaplan-Meier plots")
1074 if not args.surv_event_var: 1166 if not args.surv_event_var:
1075 raise ValueError("--surv_event_var is required for Kaplan-Meier plots") 1167 raise ValueError("--surv_event_var is required for Kaplan-Meier plots")
1076 if not args.event_value:
1077 raise ValueError("--event_value is required for Kaplan-Meier plots")
1078 1168
1079 if args.plot_type in ['cox']: 1169 if args.plot_type in ['cox']:
1080 if not args.model: 1170 if not args.important_features:
1081 raise ValueError("--model is required when plot_type is 'cox'") 1171 raise ValueError("--important_features is required when plot_type is 'cox'")
1082 if not os.path.isfile(args.model): 1172 if not os.path.isfile(args.important_features):
1083 raise FileNotFoundError(f"Model file not found: {args.model}") 1173 raise FileNotFoundError(f"Important features file not found: {args.important_features}")
1084 if not args.clinical_train: 1174 if not args.clinical_train:
1085 raise ValueError("--clinical_train is required when plot_type is 'cox'") 1175 raise ValueError("--clinical_train is required when plot_type is 'cox'")
1086 if not os.path.isfile(args.clinical_train): 1176 if not os.path.isfile(args.clinical_train):
1087 raise FileNotFoundError(f"Training dataset file not found: {args.clinical_train}") 1177 raise FileNotFoundError(f"Training dataset file not found: {args.clinical_train}")
1088 if not args.clinical_test: 1178 if not args.clinical_test:
1100 if not args.surv_time_var: 1190 if not args.surv_time_var:
1101 raise ValueError("--surv_time_var is required for Cox plots") 1191 raise ValueError("--surv_time_var is required for Cox plots")
1102 if not args.surv_event_var: 1192 if not args.surv_event_var:
1103 raise ValueError("--surv_event_var is required for Cox plots") 1193 raise ValueError("--surv_event_var is required for Cox plots")
1104 if not args.clinical_variables: 1194 if not args.clinical_variables:
1105 raise ValueError("--clinical_variables is required for Cox plots") 1195 print("--clinical_variables is not set for Cox plots")
1106 if not isinstance(args.top_features, int) or args.top_features <= 0: 1196 if not isinstance(args.top_features, int) or args.top_features <= 0:
1107 raise ValueError("--top_features must be a positive integer") 1197 raise ValueError("--top_features must be a positive integer")
1108 if not args.event_value:
1109 raise ValueError("--event_value is required for Kaplan-Meier plots")
1110 if not args.crossval: 1198 if not args.crossval:
1111 args.crossval = False 1199 args.crossval = False
1112 if not isinstance(args.n_splits, int) or args.n_splits <= 0: 1200 if not isinstance(args.n_splits, int) or args.n_splits <= 0:
1113 raise ValueError("--n_splits must be a positive integer") 1201 raise ValueError("--n_splits must be a positive integer")
1114 if not isinstance(args.random_state, int): 1202 if not isinstance(args.random_state, int):
1115 raise ValueError("--random_state must be an integer") 1203 raise ValueError("--random_state must be an integer")
1204 if not args.layer:
1205 print("--layer is not specified, using all classes from labels")
1116 1206
1117 if args.plot_type in ['scatter']: 1207 if args.plot_type in ['scatter']:
1118 if not args.labels: 1208 if not args.labels:
1119 raise ValueError("--labels is required for scatter plots") 1209 raise ValueError("--labels is required for scatter plots")
1120 if not args.target_value: 1210 if not args.target_value:
1172 output_name_base = f"{embeddings_name}_{args.method}" 1262 output_name_base = f"{embeddings_name}_{args.method}"
1173 elif args.plot_type == 'kaplan_meier': 1263 elif args.plot_type == 'kaplan_meier':
1174 survival_name = Path(args.survival_data).stem 1264 survival_name = Path(args.survival_data).stem
1175 output_name_base = f"{survival_name}_km" 1265 output_name_base = f"{survival_name}_km"
1176 elif args.plot_type == 'cox': 1266 elif args.plot_type == 'cox':
1177 model_name = Path(args.model).stem 1267 model_name = Path(args.important_features).stem
1178 output_name_base = f"{model_name}_cox" 1268 output_name_base = f"{model_name}_cox"
1179 elif args.plot_type == 'scatter': 1269 elif args.plot_type == 'scatter':
1180 labels_name = Path(args.labels).stem 1270 labels_name = Path(args.labels).stem
1181 output_name_base = f"{labels_name}_scatter" 1271 output_name_base = f"{labels_name}_scatter"
1182 elif args.plot_type == 'concordance_heatmap': 1272 elif args.plot_type == 'concordance_heatmap':
1194 1284
1195 # Generate plots based on type 1285 # Generate plots based on type
1196 if args.plot_type in ['dimred']: 1286 if args.plot_type in ['dimred']:
1197 # Load labels 1287 # Load labels
1198 print(f"Loading labels from: {args.labels}") 1288 print(f"Loading labels from: {args.labels}")
1199 label_data = load_labels(args.labels) 1289 labels = load_labels(args.labels)
1200 # Load embeddings data 1290 # Load embeddings data
1201 print(f"Loading embeddings from: {args.embeddings}") 1291 print(f"Loading embeddings from: {args.embeddings}")
1202 embeddings, sample_names = load_embeddings(args.embeddings) 1292 embeddings, sample_names = load_embeddings(args.embeddings)
1203 print(f"embeddings shape: {embeddings.shape}") 1293 print(f"embeddings shape: {embeddings.shape}")
1204 1294
1205 # Match samples to embeddings 1295 # Match samples to embeddings
1206 matched_labels = match_samples_to_embeddings(sample_names, label_data) 1296 matched_labels = match_samples_to_embeddings(sample_names, labels)
1207 print(f"Successfully matched {len(matched_labels)} samples for dimensionality reduction") 1297 print(f"Successfully matched {len(matched_labels)} samples for dimensionality reduction")
1208 1298 print(f"Matched labels shape: {matched_labels.shape}")
1299 print(f"Columns in matched labels: {matched_labels.columns.tolist()}")
1209 generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base) 1300 generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base)
1210 1301
1211 elif args.plot_type in ['kaplan_meier']: 1302 elif args.plot_type in ['kaplan_meier']:
1212 # Load labels 1303 # Load labels
1213 print(f"Loading labels from: {args.labels}") 1304 print(f"Loading labels from: {args.labels}")
1214 label_data = load_labels(args.labels) 1305 labels = load_labels(args.labels)
1215 # Load survival data 1306 # Load survival data
1216 print(f"Loading survival data from: {args.survival_data}") 1307 print(f"Loading survival data from: {args.survival_data}")
1217 survival_data = load_survival_data(args.survival_data) 1308 survival_data = load_labels(args.survival_data)
1218 print(f"Survival data shape: {survival_data.shape}") 1309 print(f"Survival data shape: {survival_data.shape}")
1219 1310
1220 generate_km_plots(survival_data, label_data, args, output_dir, output_name_base) 1311 generate_km_plots(survival_data, labels, args, output_dir, output_name_base)
1221 1312
1222 elif args.plot_type in ['cox']: 1313 elif args.plot_type in ['cox']:
1223 # Load model and datasets 1314 # Load important_features and datasets
1224 print(f"Loading model from: {args.model}") 1315 print(f"Loading important features from: {args.important_features}")
1225 model = load_model(args.model) 1316 important_features = load_labels(args.important_features)
1226 print(f"Loading training dataset from: {args.clinical_train}") 1317 print(f"Loading training dataset from: {args.clinical_train}")
1227 clinical_train = load_omics(args.clinical_train) 1318 clinical_train = load_omics(args.clinical_train)
1228 print(f"Loading test dataset from: {args.clinical_test}") 1319 print(f"Loading test dataset from: {args.clinical_test}")
1229 clinical_test = load_omics(args.clinical_test) 1320 clinical_test = load_omics(args.clinical_test)
1230 print(f"Loading training omics dataset from: {args.omics_train}") 1321 print(f"Loading training omics dataset from: {args.omics_train}")
1231 omics_train = load_omics(args.omics_train) 1322 omics_train = load_omics(args.omics_train)
1232 print(f"Loading test omics dataset from: {args.omics_test}") 1323 print(f"Loading test omics dataset from: {args.omics_test}")
1233 omics_test = load_omics(args.omics_test) 1324 omics_test = load_omics(args.omics_test)
1234 1325
1235 generate_cox_plots(model, clinical_train, clinical_test, omics_test, omics_train, args, output_dir, output_name_base) 1326 generate_cox_plots(important_features, clinical_train, clinical_test, omics_test, omics_train, args, output_dir, output_name_base)
1236 1327
1237 elif args.plot_type in ['scatter']: 1328 elif args.plot_type in ['scatter']:
1238 # Load labels 1329 # Load labels
1239 print(f"Loading labels from: {args.labels}") 1330 print(f"Loading labels from: {args.labels}")
1240 label_data = load_labels(args.labels) 1331 labels = load_labels(args.labels)
1241 1332
1242 generate_plot_scatter(label_data, args, output_dir, output_name_base) 1333 generate_plot_scatter(labels, args, output_dir, output_name_base)
1243 1334
1244 elif args.plot_type in ['concordance_heatmap']: 1335 elif args.plot_type in ['concordance_heatmap']:
1245 # Load labels 1336 # Load labels
1246 print(f"Loading labels from: {args.labels}") 1337 print(f"Loading labels from: {args.labels}")
1247 label_data = load_labels(args.labels) 1338 labels = load_labels(args.labels)
1248 1339
1249 generate_label_concordance_heatmap(label_data, args, output_dir, output_name_base) 1340 generate_label_concordance_heatmap(labels, args, output_dir, output_name_base)
1250 1341
1251 elif args.plot_type in ['pr_curve']: 1342 elif args.plot_type in ['pr_curve']:
1252 # Load labels 1343 # Load labels
1253 print(f"Loading labels from: {args.labels}") 1344 print(f"Loading labels from: {args.labels}")
1254 label_data = load_labels(args.labels) 1345 labels = load_labels(args.labels)
1255 1346
1256 generate_pr_curves(label_data, args, output_dir, output_name_base) 1347 generate_pr_curves(labels, args, output_dir, output_name_base)
1257 1348
1258 elif args.plot_type in ['roc_curve']: 1349 elif args.plot_type in ['roc_curve']:
1259 # Load labels 1350 # Load labels
1260 print(f"Loading labels from: {args.labels}") 1351 print(f"Loading labels from: {args.labels}")
1261 label_data = load_labels(args.labels) 1352 labels = load_labels(args.labels)
1262 1353
1263 generate_roc_curves(label_data, args, output_dir, output_name_base) 1354 generate_roc_curves(labels, args, output_dir, output_name_base)
1264 1355
1265 elif args.plot_type in ['box_plot']: 1356 elif args.plot_type in ['box_plot']:
1266 # Load labels 1357 # Load labels
1267 print(f"Loading labels from: {args.labels}") 1358 print(f"Loading labels from: {args.labels}")
1268 label_data = load_labels(args.labels) 1359 labels = load_labels(args.labels)
1269 1360
1270 generate_box_plots(label_data, args, output_dir, output_name_base) 1361 generate_box_plots(labels, args, output_dir, output_name_base)
1271 1362
1272 print("All plots generated successfully!") 1363 print("All plots generated successfully!")
1273 1364
1274 except (FileNotFoundError, ValueError, pd.errors.ParserError) as e: 1365 except (FileNotFoundError, ValueError, pd.errors.ParserError) as e:
1275 print(f"Error: {e}") 1366 print(f"Error: {e}")