Mercurial > repos > bgruening > flexynesis
comparison flexynesis_plot.py @ 6:33816f44fc7d draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flexynesis commit 6b520305ec30e6dc37eba92c67a5368cea0fc5ad
author | bgruening |
---|---|
date | Wed, 23 Jul 2025 07:49:41 +0000 |
parents | 466b593fd87e |
children |
comparison
equal
deleted
inserted
replaced
5:466b593fd87e | 6:33816f44fc7d |
---|---|
9 | 9 |
10 import matplotlib.pyplot as plt | 10 import matplotlib.pyplot as plt |
11 import numpy as np | 11 import numpy as np |
12 import pandas as pd | 12 import pandas as pd |
13 import seaborn as sns | 13 import seaborn as sns |
14 import torch | |
15 from flexynesis import ( | 14 from flexynesis import ( |
16 build_cox_model, | 15 build_cox_model, |
17 get_important_features, | |
18 plot_dim_reduced, | 16 plot_dim_reduced, |
19 plot_hazard_ratios, | 17 plot_hazard_ratios, |
20 plot_kaplan_meier_curves, | 18 plot_kaplan_meier_curves, |
21 plot_pr_curves, | 19 plot_pr_curves, |
22 plot_roc_curves, | 20 plot_roc_curves, |
53 if file_ext == '.csv': | 51 if file_ext == '.csv': |
54 df = pd.read_csv(labels_input) | 52 df = pd.read_csv(labels_input) |
55 elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']: | 53 elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']: |
56 df = pd.read_csv(labels_input, sep='\t') | 54 df = pd.read_csv(labels_input, sep='\t') |
57 | 55 |
58 # Check if this is the specific format with sample_id, known_label, predicted_label | 56 print(f"available columns: {df.columns.tolist()}") |
59 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] | 57 return df |
60 if all(col in df.columns for col in required_cols): | |
61 return df | |
62 else: | |
63 raise ValueError(f"Labels file {labels_input} does not contain required columns: {required_cols}") | |
64 | 58 |
65 except Exception as e: | 59 except Exception as e: |
66 raise ValueError(f"Error loading labels from {labels_input}: {e}") from e | 60 raise ValueError(f"Error loading labels from {labels_input}: {e}") from e |
67 | |
68 | |
69 def load_survival_data(survival_path): | |
70 """Load survival data from a file. First column should be sample_id""" | |
71 try: | |
72 # Determine file extension | |
73 file_ext = Path(survival_path).suffix.lower() | |
74 | |
75 if file_ext == '.csv': | |
76 df = pd.read_csv(survival_path, index_col=0) | |
77 elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']: | |
78 df = pd.read_csv(survival_path, sep='\t', index_col=0) | |
79 else: | |
80 raise ValueError(f"Unsupported file extension: {file_ext}") | |
81 return df | |
82 | |
83 except Exception as e: | |
84 raise ValueError(f"Error loading survival data from {survival_path}: {e}") from e | |
85 | 61 |
86 | 62 |
87 def load_omics(omics_path): | 63 def load_omics(omics_path): |
88 """Load omics data from a file. First column should be features""" | 64 """Load omics data from a file. First column should be features""" |
89 try: | 65 try: |
100 | 76 |
101 except Exception as e: | 77 except Exception as e: |
102 raise ValueError(f"Error loading omics data from {omics_path}: {e}") from e | 78 raise ValueError(f"Error loading omics data from {omics_path}: {e}") from e |
103 | 79 |
104 | 80 |
105 def load_model(model_path): | 81 def match_samples_to_embeddings(sample_names, labels): |
106 """Load flexynesis model from pickle file""" | |
107 try: | |
108 with open(model_path, 'rb') as f: | |
109 model = torch.load(f, weights_only=False) | |
110 return model | |
111 except Exception as e: | |
112 raise ValueError(f"Error loading model from {model_path}: {e}") from e | |
113 | |
114 | |
115 def match_samples_to_embeddings(sample_names, label_data): | |
116 """Filter label data to match sample names in the embeddings""" | 82 """Filter label data to match sample names in the embeddings""" |
117 df_matched = label_data[label_data['sample_id'].isin(sample_names)] | 83 # Create a DataFrame from sample_names to preserve order |
84 sample_df = pd.DataFrame({'sample_names': sample_names}) | |
85 | |
86 # left_join | |
87 first_column = labels.columns[0] | |
88 df_matched = sample_df.merge(labels, left_on='sample_names', right_on=first_column, how='left') | |
89 | |
90 # remove sample_names to keep the initial structure | |
91 df_matched = df_matched.drop('sample_names', axis=1) | |
118 return df_matched | 92 return df_matched |
119 | 93 |
120 | 94 |
121 def detect_color_type(labels_series): | 95 def detect_color_type(labels_series): |
122 """Auto-detect whether target variables should be treated as categorical or numerical""" | 96 """Auto-detect whether target variables should be treated as categorical or numerical""" |
212 | 186 |
213 | 187 |
214 def generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base): | 188 def generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base): |
215 """Generate dimensionality reduction plots""" | 189 """Generate dimensionality reduction plots""" |
216 | 190 |
217 # Parse target values from comma-separated string | 191 # Check if this is the specific format with sample_id, known_label, predicted_label |
218 if args.target_value: | 192 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] |
219 target_values = [val.strip() for val in args.target_value.split(',')] | 193 is_flexynesis_format = all(col in matched_labels.columns for col in required_cols) |
194 | |
195 if not args.color: | |
196 if is_flexynesis_format: | |
197 print("Detected flexynesis labels format") | |
198 print(f"Generating {args.method.upper()} plots for known and predicted labels...") | |
199 else: | |
200 print("Labels are not in flexynesis format (Custom labels), please specify a color variable with --color") | |
201 | |
202 # Parse target values from comma-separated string | |
203 if args.target_value: | |
204 target_values = [val.strip() for val in args.target_value.split(',')] | |
205 else: | |
206 # If no target values specified, use all unique variables | |
207 target_values = matched_labels['variable'].unique().tolist() | |
208 | |
209 print(f"Generating {args.method.upper()} plots for {len(target_values)} target variable(s): {', '.join(target_values)}") | |
210 | |
211 # Check variables | |
212 available_vars = matched_labels['variable'].unique() | |
213 missing_vars = [var for var in target_values if var not in available_vars] | |
214 | |
215 if missing_vars: | |
216 print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}") | |
217 print(f"Available variables: {', '.join(available_vars)}") | |
218 | |
219 # Filter to only process available variables | |
220 valid_vars = [var for var in target_values if var in available_vars] | |
221 | |
222 if not valid_vars: | |
223 raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}") | |
224 | |
225 # Generate plots for each valid target variable | |
226 for var in valid_vars: | |
227 print(f"\nPlotting variable: {var}") | |
228 | |
229 # Filter matched labels for current variable | |
230 var_labels = matched_labels[matched_labels['variable'] == var].copy() | |
231 var_labels = var_labels.drop_duplicates(subset='sample_id') | |
232 | |
233 if var_labels.empty: | |
234 print(f"Warning: No data found for variable '{var}', skipping...") | |
235 continue | |
236 | |
237 # Auto-detect color type | |
238 known_color_type = detect_color_type(var_labels['known_label']) | |
239 predicted_color_type = detect_color_type(var_labels['predicted_label']) | |
240 | |
241 print(f" Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}") | |
242 | |
243 try: | |
244 # Plot 1: Known labels | |
245 print(f" Creating known labels plot for {var}...") | |
246 fig_known = plot_dim_reduced( | |
247 matrix=embeddings, | |
248 labels=var_labels['known_label'], | |
249 method=args.method, | |
250 color_type=known_color_type | |
251 ) | |
252 | |
253 output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}" | |
254 print(f" Saving known labels plot to: {output_path_known.name}") | |
255 fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight') | |
256 | |
257 # Plot 2: Predicted labels | |
258 print(f" Creating predicted labels plot for {var}...") | |
259 fig_predicted = plot_dim_reduced( | |
260 matrix=embeddings, | |
261 labels=var_labels['predicted_label'], | |
262 method=args.method, | |
263 color_type=predicted_color_type | |
264 ) | |
265 | |
266 output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}" | |
267 print(f" Saving predicted labels plot to: {output_path_predicted.name}") | |
268 fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight') | |
269 | |
270 print(f" ✓ Successfully created plots for variable '{var}'") | |
271 | |
272 except Exception as e: | |
273 print(f" ✗ Error creating plots for variable '{var}': {e}") | |
274 continue | |
275 | |
276 print(f"\nDimensionality reduction plots completed for {len(valid_vars)} variable(s)!") | |
277 | |
220 else: | 278 else: |
221 # If no target values specified, use all unique variables | 279 # check if the color variable exists in matched_labels |
222 target_values = matched_labels['variable'].unique().tolist() | 280 if args.color not in matched_labels.columns: |
223 | 281 raise ValueError(f"Color variable '{args.color}' not found in matched labels. Available columns: {matched_labels.columns.tolist()}") |
224 print(f"Generating {args.method.upper()} plots for {len(target_values)} target variable(s): {', '.join(target_values)}") | |
225 | |
226 # Check variables | |
227 available_vars = matched_labels['variable'].unique() | |
228 missing_vars = [var for var in target_values if var not in available_vars] | |
229 | |
230 if missing_vars: | |
231 print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}") | |
232 print(f"Available variables: {', '.join(available_vars)}") | |
233 | |
234 # Filter to only process available variables | |
235 valid_vars = [var for var in target_values if var in available_vars] | |
236 | |
237 if not valid_vars: | |
238 raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}") | |
239 | |
240 # Generate plots for each valid target variable | |
241 for var in valid_vars: | |
242 print(f"\nPlotting variable: {var}") | |
243 | |
244 # Filter matched labels for current variable | |
245 var_labels = matched_labels[matched_labels['variable'] == var].copy() | |
246 var_labels = var_labels.drop_duplicates(subset='sample_id') | |
247 | |
248 if var_labels.empty: | |
249 print(f"Warning: No data found for variable '{var}', skipping...") | |
250 continue | |
251 | 282 |
252 # Auto-detect color type | 283 # Auto-detect color type |
253 known_color_type = detect_color_type(var_labels['known_label']) | 284 color_type = detect_color_type(matched_labels[args.color]) |
254 predicted_color_type = detect_color_type(var_labels['predicted_label']) | 285 |
255 | 286 print(f" Auto-detected color type: {color_type}") |
256 print(f" Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}") | 287 |
257 | 288 # Plot: Specified color column |
258 try: | 289 print(f" Creating plot for {args.color}...") |
259 # Plot 1: Known labels | 290 fig = plot_dim_reduced( |
260 print(f" Creating known labels plot for {var}...") | 291 matrix=embeddings, |
261 fig_known = plot_dim_reduced( | 292 labels=matched_labels[args.color], |
262 matrix=embeddings, | 293 method=args.method, |
263 labels=var_labels['known_label'], | 294 color_type=color_type |
264 method=args.method, | 295 ) |
265 color_type=known_color_type | 296 |
266 ) | 297 output_path = output_dir / f"{output_name_base}_{args.color}.{args.format}" |
267 | 298 print(f" Saving plot to: {output_path.name}") |
268 output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}" | 299 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') |
269 print(f" Saving known labels plot to: {output_path_known.name}") | 300 |
270 fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight') | 301 print(f" ✓ Successfully created plot for variable '{args.color}'") |
271 | 302 |
272 # Plot 2: Predicted labels | 303 |
273 print(f" Creating predicted labels plot for {var}...") | 304 def generate_km_plots(survival_data, labels, args, output_dir, output_name_base): |
274 fig_predicted = plot_dim_reduced( | |
275 matrix=embeddings, | |
276 labels=var_labels['predicted_label'], | |
277 method=args.method, | |
278 color_type=predicted_color_type | |
279 ) | |
280 | |
281 output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}" | |
282 print(f" Saving predicted labels plot to: {output_path_predicted.name}") | |
283 fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight') | |
284 | |
285 print(f" ✓ Successfully created plots for variable '{var}'") | |
286 | |
287 except Exception as e: | |
288 print(f" ✗ Error creating plots for variable '{var}': {e}") | |
289 continue | |
290 | |
291 print(f"\nDimensionality reduction plots completed for {len(valid_vars)} variable(s)!") | |
292 | |
293 | |
294 def generate_km_plots(survival_data, label_data, args, output_dir, output_name_base): | |
295 """Generate Kaplan-Meier plots""" | 305 """Generate Kaplan-Meier plots""" |
306 | |
307 # Check if this is the specific format with sample_id, known_label, predicted_label | |
308 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] | |
309 is_flexynesis_format = all(col in labels.columns for col in required_cols) | |
310 | |
311 if not is_flexynesis_format: | |
312 raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.") | |
313 | |
296 print("Generating Kaplan-Meier curves of risk subtypes...") | 314 print("Generating Kaplan-Meier curves of risk subtypes...") |
297 | 315 |
298 # Reset index and rename the index column to sample_id | |
299 survival_data = survival_data.reset_index() | |
300 if survival_data.columns[0] != 'sample_id': | 316 if survival_data.columns[0] != 'sample_id': |
301 survival_data = survival_data.rename(columns={survival_data.columns[0]: 'sample_id'}) | 317 survival_data = survival_data.rename(columns={survival_data.columns[0]: 'sample_id'}) |
302 | 318 |
303 # Convert survival event column to binary (0/1) based on event_value | |
304 # Check if the event column exists | 319 # Check if the event column exists |
305 if args.surv_event_var not in survival_data.columns: | 320 if args.surv_event_var not in survival_data.columns: |
306 raise ValueError(f"Column '{args.surv_event_var}' not found in survival data") | 321 raise ValueError(f"Column '{args.surv_event_var}' not found in survival data") |
307 | 322 |
308 # Convert to string for comparison to handle mixed types | 323 labels = labels[(labels['variable'] == args.surv_event_var)] |
309 survival_data[args.surv_event_var] = survival_data[args.surv_event_var].astype(str) | |
310 event_value_str = str(args.event_value) | |
311 | |
312 # Create binary event column (1 if matches event_value, 0 otherwise) | |
313 survival_data[f'{args.surv_event_var}_binary'] = ( | |
314 survival_data[args.surv_event_var] == event_value_str | |
315 ).astype(int) | |
316 | |
317 # Filter for survival category and class_label == '1:DECEASED' | |
318 label_data['class_label'] = label_data['class_label'].astype(str) | |
319 | |
320 label_data = label_data[(label_data['variable'] == args.surv_event_var) & (label_data['class_label'] == event_value_str)] | |
321 | |
322 # check survival data | |
323 for col in [args.surv_time_var, args.surv_event_var]: | |
324 if col not in survival_data.columns: | |
325 raise ValueError(f"Column '{col}' not found in survival data") | |
326 | 324 |
327 # Merge survival data with labels | 325 # Merge survival data with labels |
328 df_deceased = pd.merge(survival_data, label_data, on='sample_id', how='inner') | 326 df_deceased = pd.merge(survival_data, labels, on='sample_id', how='inner') |
327 df_deceased = df_deceased.dropna(subset=[args.surv_time_var, args.surv_event_var]) | |
329 | 328 |
330 if df_deceased.empty: | 329 if df_deceased.empty: |
331 raise ValueError("No matching samples found after merging survival and label data.") | 330 raise ValueError("No matching samples found after merging survival and label data.") |
332 | 331 |
333 # Get risk scores | 332 # Get risk scores |
334 risk_scores = df_deceased['probability'].values | 333 risk_scores = df_deceased['predicted_label'].values |
335 | 334 |
336 # Compute groups (e.g., median split) | 335 # Compute groups (e.g., median split) |
337 quantiles = np.quantile(risk_scores, [0.5]) | 336 quantiles = np.quantile(risk_scores, [0.5]) |
338 groups = np.digitize(risk_scores, quantiles) | 337 groups = np.digitize(risk_scores, quantiles) |
339 group_labels = ['low_risk' if g == 0 else 'high_risk' for g in groups] | 338 group_labels = ['low_risk' if g == 0 else 'high_risk' for g in groups] |
340 | 339 |
341 fig_known = plot_kaplan_meier_curves( | 340 fig_known = plot_kaplan_meier_curves( |
342 durations=df_deceased[args.surv_time_var], | 341 durations=df_deceased[args.surv_time_var], |
343 events=df_deceased[f'{args.surv_event_var}_binary'], | 342 events=df_deceased[args.surv_event_var], |
344 categorical_variable=group_labels | 343 categorical_variable=group_labels |
345 ) | 344 ) |
346 | 345 |
347 output_path_known = output_dir / f"{output_name_base}_km_risk_subtypes.{args.format}" | 346 output_path_known = output_dir / f"{output_name_base}_km_risk_subtypes.{args.format}" |
348 print(f"Saving Kaplan-Meier plot to: {output_path_known.absolute()}") | 347 print(f"Saving Kaplan-Meier plot to: {output_path_known.absolute()}") |
349 fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight') | 348 fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight') |
350 | 349 |
351 print("Kaplan-Meier plot saved successfully!") | 350 print("Kaplan-Meier plot saved successfully!") |
352 | 351 |
353 | 352 |
354 def generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base): | 353 def generate_cox_plots(important_features, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base): |
355 """Generate Cox proportional hazards plots""" | 354 """Generate Cox proportional hazards plots""" |
356 print("Generating Cox proportional hazards analysis...") | 355 print("Generating Cox proportional hazards analysis...") |
357 | 356 |
357 # Check if this is the specific format with target_variable, importance | |
358 required_cols = ['target_variable', 'layer', 'importance'] | |
359 is_flexynesis_format = all(col in important_features.columns for col in required_cols) | |
360 | |
361 if not is_flexynesis_format: | |
362 raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid important_features file with the required columns, {required_cols}.") | |
363 | |
358 # Parse clinical variables | 364 # Parse clinical variables |
359 clinical_vars = [var.strip() for var in args.clinical_variables.split(',')] | 365 clinical_vars = [] |
366 if args.clinical_variables: | |
367 clinical_vars = [var.strip() for var in args.clinical_variables.split(',')] | |
360 | 368 |
361 # Validate that survival variables are included | 369 # Validate that survival variables are included |
362 required_vars = [args.surv_time_var, args.surv_event_var] | 370 required_vars = [args.surv_time_var, args.surv_event_var] |
363 for var in required_vars: | 371 for var in required_vars: |
364 if var not in clinical_vars: | 372 if var not in clinical_vars: |
380 df_clin = pd.concat([df_clin_train, df_clin_test], axis=0) | 388 df_clin = pd.concat([df_clin_train, df_clin_test], axis=0) |
381 | 389 |
382 # Get top survival markers | 390 # Get top survival markers |
383 print(f"Extracting top {args.top_features} important features for {args.surv_event_var}...") | 391 print(f"Extracting top {args.top_features} important features for {args.surv_event_var}...") |
384 try: | 392 try: |
385 imp = get_important_features(model, | 393 print(f"Loading {args.top_features} important features from: {args.important_features}") |
386 var=args.surv_event_var, | 394 imp_features = load_labels(args.important_features) |
387 top=args.top_features | 395 imp_features = imp_features[imp_features['target_variable'] == args.surv_event_var] |
388 )['name'].unique().tolist() | 396 if args.layer not in imp_features['layer'].unique(): |
397 print(f"Available class labels: {imp_features['layer'].unique()}") | |
398 raise ValueError(f"Class label '{args.layer}' not found in important features data: {args.important_features}") | |
399 imp_features = imp_features[imp_features['layer'] == args.layer] | |
400 if imp_features.empty: | |
401 raise ValueError(f"No important features found for target variable '{args.surv_event_var}' in {args.important_features}") | |
402 imp_features = imp_features.sort_values(by='importance', ascending=False) | |
403 | |
404 if len(imp_features) < args.top_features: | |
405 raise ValueError(f"Requested top {args.top_features} features, but only {len(imp_features)} available in {args.important_features}") | |
406 | |
407 imp = imp_features['name'].unique().tolist()[0:args.top_features] | |
408 | |
389 print(f"Top features: {', '.join(imp)}") | 409 print(f"Top features: {', '.join(imp)}") |
390 except Exception as e: | 410 except Exception as e: |
391 raise ValueError(f"Error getting important features: {e}") | 411 raise ValueError(f"Error getting important features: {e}") |
392 | 412 |
393 # Extract feature data from omics datasets | 413 # Extract feature data from omics datasets |
415 final_samples = len(df) | 435 final_samples = len(df) |
416 print(f"Removed {initial_samples - final_samples} samples without survival data") | 436 print(f"Removed {initial_samples - final_samples} samples without survival data") |
417 | 437 |
418 if df.empty: | 438 if df.empty: |
419 raise ValueError("No samples remain after filtering for survival data") | 439 raise ValueError("No samples remain after filtering for survival data") |
420 | |
421 # Convert survival event column to binary (0/1) based on event_value | |
422 # Convert to string for comparison to handle mixed types | |
423 df[args.surv_event_var] = df[args.surv_event_var].astype(str) | |
424 event_value_str = str(args.event_value) | |
425 | |
426 df[f'{args.surv_event_var}'] = ( | |
427 df[args.surv_event_var] == event_value_str | |
428 ).astype(int) | |
429 | 440 |
430 # Build Cox model | 441 # Build Cox model |
431 print(f"Building Cox model with time variable: {args.surv_time_var}, event variable: {args.surv_event_var}") | 442 print(f"Building Cox model with time variable: {args.surv_time_var}, event variable: {args.surv_event_var}") |
432 try: | 443 try: |
433 coxm = build_cox_model(df, | 444 coxm = build_cox_model(df, |
457 | 468 |
458 def generate_plot_scatter(labels, args, output_dir, output_name_base): | 469 def generate_plot_scatter(labels, args, output_dir, output_name_base): |
459 """Generate scatter plot of known vs predicted labels""" | 470 """Generate scatter plot of known vs predicted labels""" |
460 print("Generating scatter plots of known vs predicted labels...") | 471 print("Generating scatter plots of known vs predicted labels...") |
461 | 472 |
473 # Check if this is the specific format with sample_id, known_label, predicted_label | |
474 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] | |
475 is_flexynesis_format = all(col in labels.columns for col in required_cols) | |
476 | |
477 if is_flexynesis_format: | |
478 # Parse target values from comma-separated string | |
479 if args.target_value: | |
480 target_values = [val.strip() for val in args.target_value.split(',')] | |
481 else: | |
482 # If no target values specified, use all unique variables | |
483 target_values = labels['variable'].unique().tolist() | |
484 | |
485 print(f"Processing target values: {target_values}") | |
486 | |
487 successful_plots = 0 | |
488 skipped_plots = 0 | |
489 | |
490 for target_value in target_values: | |
491 print(f"\nProcessing target value: '{target_value}'") | |
492 | |
493 # Filter labels for the current target value | |
494 target_labels = labels[labels['variable'] == target_value] | |
495 | |
496 if target_labels.empty: | |
497 print(f" Warning: No data found for target value '{target_value}' - skipping") | |
498 skipped_plots += 1 | |
499 continue | |
500 | |
501 # Check if labels are numeric and convert | |
502 true_values = pd.to_numeric(target_labels['known_label'], errors='coerce') | |
503 predicted_values = pd.to_numeric(target_labels['predicted_label'], errors='coerce') | |
504 | |
505 if true_values.isna().all() or predicted_values.isna().all(): | |
506 print(f"No valid numeric values found for known or predicted labels in '{target_value}'") | |
507 skipped_plots += 1 | |
508 continue | |
509 | |
510 try: | |
511 print(f" Generating scatter plot for '{target_value}'...") | |
512 fig = plot_scatter(true_values, predicted_values) | |
513 | |
514 # Create output filename with target value | |
515 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') | |
516 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" | |
517 | |
518 output_path = output_dir / output_filename | |
519 print(f" Saving scatter plot to: {output_path.absolute()}") | |
520 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') | |
521 | |
522 successful_plots += 1 | |
523 print(f" Scatter plot for '{target_value}' generated successfully!") | |
524 | |
525 except Exception as e: | |
526 print(f" Error generating plot for '{target_value}': {str(e)}") | |
527 skipped_plots += 1 | |
528 | |
529 # Summary | |
530 print(" Summary:") | |
531 print(f" Successfully generated: {successful_plots} plots") | |
532 print(f" Skipped: {skipped_plots} plots") | |
533 | |
534 if successful_plots == 0: | |
535 raise ValueError("No scatter plots could be generated. Check your data and target values.") | |
536 | |
537 print("Scatter plot generation completed!") | |
538 | |
539 if not is_flexynesis_format: | |
540 print("Labels are not in flexynesis format (Custom labels)") | |
541 | |
542 if not args.true_label or not args.predicted_label: | |
543 raise ValueError("For custom labels, please specify --true_label and --predicted_label arguments.") | |
544 | |
545 # Check if labels are numeric and convert | |
546 true_values = pd.to_numeric(labels[args.true_label], errors='coerce') | |
547 predicted_values = pd.to_numeric(labels[args.predicted_label], errors='coerce') | |
548 | |
549 if true_values.isna().all() or predicted_values.isna().all(): | |
550 print("No valid numeric values found for known or predicted labels") | |
551 | |
552 try: | |
553 print(" Generating scatter plot...") | |
554 fig = plot_scatter(true_values, predicted_values) | |
555 | |
556 # Create output filename with target value | |
557 output_filename = f"{output_name_base}.{args.format}" | |
558 | |
559 output_path = output_dir / output_filename | |
560 print(f" Saving scatter plot to: {output_path.absolute()}") | |
561 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') | |
562 | |
563 except Exception as e: | |
564 print(f" Error generating plot: {str(e)}") | |
565 | |
566 print("Scatter plot generation completed!") | |
567 | |
568 | |
569 def generate_label_concordance_heatmap(labels, args, output_dir, output_name_base): | |
570 """Generate label concordance heatmap""" | |
571 print("Generating label concordance heatmaps...") | |
572 | |
573 # Check if this is the specific format with sample_id, known_label, predicted_label | |
574 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] | |
575 is_flexynesis_format = all(col in labels.columns for col in required_cols) | |
576 | |
577 if is_flexynesis_format: | |
578 # Parse target values from comma-separated string | |
579 if args.target_value: | |
580 target_values = [val.strip() for val in args.target_value.split(',')] | |
581 else: | |
582 # If no target values specified, use all unique variables | |
583 target_values = labels['variable'].unique().tolist() | |
584 | |
585 print(f"Processing target values: {target_values}") | |
586 | |
587 for target_value in target_values: | |
588 print(f"\nProcessing target value: '{target_value}'") | |
589 | |
590 # Filter labels for the current target value | |
591 target_labels = labels[labels['variable'] == target_value] | |
592 | |
593 if target_labels.empty: | |
594 print(f" Warning: No data found for target value '{target_value}' - skipping") | |
595 continue | |
596 | |
597 true_values = target_labels['known_label'].tolist() | |
598 predicted_values = target_labels['predicted_label'].tolist() | |
599 | |
600 try: | |
601 print(f" Generating heatmap for '{target_value}'...") | |
602 fig = plot_label_concordance_heatmap(true_values, predicted_values) | |
603 plt.close(fig) | |
604 | |
605 # Create output filename with target value | |
606 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') | |
607 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" | |
608 | |
609 output_path = output_dir / output_filename | |
610 print(f" Saving heatmap to: {output_path.absolute()}") | |
611 fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight') | |
612 | |
613 except Exception as e: | |
614 print(f" Error generating heatmap for '{target_value}': {str(e)}") | |
615 continue | |
616 | |
617 print("Label concordance heatmap generated successfully!") | |
618 | |
619 if not is_flexynesis_format: | |
620 print("Labels are not in flexynesis format (Custom labels)") | |
621 | |
622 if not args.true_label or not args.predicted_label: | |
623 raise ValueError("For custom labels, please specify --true_label and --predicted_label arguments.") | |
624 | |
625 true_values = labels[args.true_label].tolist() | |
626 predicted_values = labels[args.predicted_label].tolist() | |
627 | |
628 try: | |
629 print(" Generating heatmap for...") | |
630 fig = plot_label_concordance_heatmap(true_values, predicted_values) | |
631 plt.close(fig) | |
632 | |
633 # Create output filename with target value | |
634 output_filename = f"{output_name_base}.{args.format}" | |
635 | |
636 output_path = output_dir / output_filename | |
637 print(f" Saving heatmap to: {output_path.absolute()}") | |
638 fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight') | |
639 | |
640 except Exception as e: | |
641 print(f" Error generating heatmap': {str(e)}") | |
642 | |
643 print("Label concordance heatmap generated successfully!") | |
644 | |
645 | |
646 def generate_pr_curves(labels, args, output_dir, output_name_base): | |
647 """Generate precision-recall curves""" | |
648 print("Generating precision-recall curves...") | |
649 | |
650 # Check if this is the specific format with sample_id, known_label, predicted_label | |
651 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] | |
652 is_flexynesis_format = all(col in labels.columns for col in required_cols) | |
653 | |
654 if not is_flexynesis_format: | |
655 raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.") | |
656 | |
462 # Parse target values from comma-separated string | 657 # Parse target values from comma-separated string |
463 if args.target_value: | 658 if args.target_value: |
464 target_values = [val.strip() for val in args.target_value.split(',')] | 659 target_values = [val.strip() for val in args.target_value.split(',')] |
465 else: | 660 else: |
466 # If no target values specified, use all unique variables | 661 # If no target values specified, use all unique variables |
467 target_values = labels['variable'].unique().tolist() | 662 target_values = labels['variable'].unique().tolist() |
468 | 663 |
469 print(f"Processing target values: {target_values}") | 664 print(f"Processing target values: {target_values}") |
470 | 665 |
471 successful_plots = 0 | |
472 skipped_plots = 0 | |
473 | |
474 for target_value in target_values: | 666 for target_value in target_values: |
475 print(f"\nProcessing target value: '{target_value}'") | 667 print(f"\nProcessing target value: '{target_value}'") |
476 | 668 |
477 # Filter labels for the current target value | 669 # Filter labels for the current target value |
478 target_labels = labels[labels['variable'] == target_value] | 670 target_labels = labels[labels['variable'] == target_value] |
479 | 671 |
480 if target_labels.empty: | 672 # Check if this is a regression problem (no class probabilities) |
481 print(f" Warning: No data found for target value '{target_value}' - skipping") | 673 prob_columns = target_labels['class_label'].unique() |
482 skipped_plots += 1 | 674 non_na_probs = target_labels['probability'].notna().sum() |
483 continue | 675 |
484 | 676 print(f" Class labels found: {list(prob_columns)}") |
485 # Check if labels are numeric and convert | 677 print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}") |
486 true_values = pd.to_numeric(target_labels['known_label'], errors='coerce') | 678 |
487 predicted_values = pd.to_numeric(target_labels['predicted_label'], errors='coerce') | 679 # If most probabilities are NaN, this is likely a regression problem |
488 | 680 if non_na_probs < len(target_labels) * 0.1: # Less than 10% valid probabilities |
489 if true_values.isna().all() or predicted_values.isna().all(): | 681 print(" Detected regression problem - precision-recall curves not applicable") |
490 print(f"No valid numeric values found for known or predicted labels in '{target_value}'") | 682 print(f" Skipping '{target_value}' (use regression evaluation metrics instead)") |
491 skipped_plots += 1 | 683 continue |
684 | |
685 # Debug: Check data quality | |
686 total_rows = len(target_labels) | |
687 missing_labels = target_labels['known_label'].isna().sum() | |
688 missing_probs = target_labels['probability'].isna().sum() | |
689 unique_samples = target_labels['sample_id'].nunique() | |
690 unique_classes = target_labels['class_label'].nunique() | |
691 | |
692 print(f" Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes") | |
693 print(f" Missing data: {missing_labels} missing known_label, {missing_probs} missing probability") | |
694 | |
695 if missing_labels > 0: | |
696 print(f" Warning: Found {missing_labels} missing known_label values") | |
697 missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5] | |
698 print(f" Sample IDs with missing known_label: {list(missing_samples)}") | |
699 | |
700 # Remove rows with missing known_label | |
701 target_labels = target_labels.dropna(subset=['known_label']) | |
702 if target_labels.empty: | |
703 print(f" Error: No valid known_label data remaining for '{target_value}' - skipping") | |
704 continue | |
705 | |
706 # 1. Pivot to wide format | |
707 prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability') | |
708 | |
709 print(f" After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes") | |
710 print(f" Class columns: {list(prob_df.columns)}") | |
711 | |
712 # Check for NaN values in probability data | |
713 nan_counts = prob_df.isna().sum() | |
714 if nan_counts.any(): | |
715 print(f" NaN counts per class: {dict(nan_counts)}") | |
716 print(f" Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}") | |
717 | |
718 # Drop only rows where ALL probabilities are NaN | |
719 all_nan_rows = prob_df.isna().all(axis=1) | |
720 if all_nan_rows.any(): | |
721 print(f" Dropping {all_nan_rows.sum()} samples with all NaN probabilities") | |
722 prob_df = prob_df[~all_nan_rows] | |
723 | |
724 remaining_nans = prob_df.isna().sum().sum() | |
725 if remaining_nans > 0: | |
726 print(f" Warning: {remaining_nans} individual NaN values remain - filling with 0") | |
727 prob_df = prob_df.fillna(0) | |
728 | |
729 if prob_df.empty: | |
730 print(f" Error: No valid probability data remaining for '{target_value}' - skipping") | |
731 continue | |
732 | |
733 # 2. Get true labels | |
734 true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id') | |
735 | |
736 # 3. Align indices - only keep samples that exist in both datasets | |
737 common_indices = prob_df.index.intersection(true_labels_df.index) | |
738 if len(common_indices) == 0: | |
739 print(f" Error: No common sample_ids between probability and true label data for '{target_value}' - skipping") | |
740 continue | |
741 | |
742 print(f" Found {len(common_indices)} samples with both probability and true label data") | |
743 | |
744 # Filter both datasets to common indices | |
745 prob_df_aligned = prob_df.loc[common_indices] | |
746 y_true = true_labels_df.loc[common_indices]['known_label'] | |
747 | |
748 # 4. Final check for NaN values | |
749 if y_true.isna().any(): | |
750 print(f" Error: True labels still contain NaN after alignment for '{target_value}' - skipping") | |
751 continue | |
752 | |
753 if prob_df_aligned.isna().any().any(): | |
754 print(f" Error: Probability data still contains NaN after alignment for '{target_value}' - skipping") | |
755 continue | |
756 | |
757 # 5. Convert categorical labels to integer labels | |
758 # Create a mapping from class names to integers | |
759 class_names = list(prob_df_aligned.columns) | |
760 class_to_int = {class_name: i for i, class_name in enumerate(class_names)} | |
761 | |
762 print(f" Class mapping: {class_to_int}") | |
763 | |
764 # Convert true labels to integers | |
765 y_true_np = y_true.map(class_to_int).to_numpy() | |
766 y_probs_np = prob_df_aligned.to_numpy() | |
767 | |
768 print(f" Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}") | |
769 print(f" Unique true labels (integers): {set(y_true_np)}") | |
770 print(f" Class labels (columns): {class_names}") | |
771 print(f" Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}") | |
772 | |
773 # Check for any unmapped labels (will be NaN) | |
774 if pd.isna(y_true_np).any(): | |
775 print(" Error: Some true labels could not be mapped to class columns") | |
776 unmapped_labels = set(y_true[y_true.map(class_to_int).isna()]) | |
777 print(f" Unmapped labels: {unmapped_labels}") | |
778 print(f" Available classes: {class_names}") | |
492 continue | 779 continue |
493 | 780 |
494 try: | 781 try: |
495 print(f" Generating scatter plot for '{target_value}'...") | 782 print(f" Generating precision-recall curve for '{target_value}'...") |
496 fig = plot_scatter(true_values, predicted_values) | 783 fig = plot_pr_curves(y_true_np, y_probs_np) |
497 | 784 |
498 # Create output filename with target value | 785 # Create output filename with target value |
499 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') | 786 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') |
500 if len(target_values) > 1: | 787 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" |
501 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" | |
502 else: | |
503 output_filename = f"{output_name_base}.{args.format}" | |
504 | 788 |
505 output_path = output_dir / output_filename | 789 output_path = output_dir / output_filename |
506 print(f" Saving scatter plot to: {output_path.absolute()}") | 790 print(f" Saving precision-recall curve to: {output_path.absolute()}") |
507 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') | 791 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') |
508 | 792 |
509 successful_plots += 1 | |
510 print(f" Scatter plot for '{target_value}' generated successfully!") | |
511 | |
512 except Exception as e: | 793 except Exception as e: |
513 print(f" Error generating plot for '{target_value}': {str(e)}") | 794 print(f" Error generating precision-recall curve for '{target_value}': {str(e)}") |
514 skipped_plots += 1 | 795 print(f" Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}") |
515 | 796 print(f" Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}") |
516 # Summary | 797 continue |
517 print(" Summary:") | 798 |
518 print(f" Successfully generated: {successful_plots} plots") | 799 print("Precision-recall curves generated successfully!") |
519 print(f" Skipped: {skipped_plots} plots") | 800 |
520 | 801 |
521 if successful_plots == 0: | 802 def generate_roc_curves(labels, args, output_dir, output_name_base): |
522 raise ValueError("No scatter plots could be generated. Check your data and target values.") | 803 """Generate ROC curves""" |
523 | 804 print("Generating ROC curves...") |
524 print("Scatter plot generation completed!") | 805 |
525 | 806 # Check if this is the specific format with sample_id, known_label, predicted_label |
526 | 807 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] |
527 def generate_label_concordance_heatmap(labels, args, output_dir, output_name_base): | 808 is_flexynesis_format = all(col in labels.columns for col in required_cols) |
528 """Generate label concordance heatmap""" | 809 |
529 print("Generating label concordance heatmaps...") | 810 if not is_flexynesis_format: |
811 raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.") | |
530 | 812 |
531 # Parse target values from comma-separated string | 813 # Parse target values from comma-separated string |
532 if args.target_value: | 814 if args.target_value: |
533 target_values = [val.strip() for val in args.target_value.split(',')] | 815 target_values = [val.strip() for val in args.target_value.split(',')] |
534 else: | 816 else: |
541 print(f"\nProcessing target value: '{target_value}'") | 823 print(f"\nProcessing target value: '{target_value}'") |
542 | 824 |
543 # Filter labels for the current target value | 825 # Filter labels for the current target value |
544 target_labels = labels[labels['variable'] == target_value] | 826 target_labels = labels[labels['variable'] == target_value] |
545 | 827 |
546 if target_labels.empty: | |
547 print(f" Warning: No data found for target value '{target_value}' - skipping") | |
548 continue | |
549 | |
550 true_values = target_labels['known_label'].tolist() | |
551 predicted_values = target_labels['predicted_label'].tolist() | |
552 | |
553 try: | |
554 print(f" Generating heatmap for '{target_value}'...") | |
555 fig = plot_label_concordance_heatmap(true_values, predicted_values) | |
556 plt.close(fig) | |
557 | |
558 # Create output filename with target value | |
559 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') | |
560 if len(target_values) > 1: | |
561 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" | |
562 else: | |
563 output_filename = f"{output_name_base}.{args.format}" | |
564 | |
565 output_path = output_dir / output_filename | |
566 print(f" Saving heatmap to: {output_path.absolute()}") | |
567 fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight') | |
568 | |
569 except Exception as e: | |
570 print(f" Error generating heatmap for '{target_value}': {str(e)}") | |
571 continue | |
572 | |
573 print("Label concordance heatmap generated successfully!") | |
574 | |
575 | |
576 def generate_pr_curves(labels, args, output_dir, output_name_base): | |
577 """Generate precision-recall curves""" | |
578 print("Generating precision-recall curves...") | |
579 | |
580 # Parse target values from comma-separated string | |
581 if args.target_value: | |
582 target_values = [val.strip() for val in args.target_value.split(',')] | |
583 else: | |
584 # If no target values specified, use all unique variables | |
585 target_values = labels['variable'].unique().tolist() | |
586 | |
587 print(f"Processing target values: {target_values}") | |
588 | |
589 for target_value in target_values: | |
590 print(f"\nProcessing target value: '{target_value}'") | |
591 | |
592 # Filter labels for the current target value | |
593 target_labels = labels[labels['variable'] == target_value] | |
594 | |
595 # Check if this is a regression problem (no class probabilities) | 828 # Check if this is a regression problem (no class probabilities) |
596 prob_columns = target_labels['class_label'].unique() | 829 prob_columns = target_labels['class_label'].unique() |
597 non_na_probs = target_labels['probability'].notna().sum() | 830 non_na_probs = target_labels['probability'].notna().sum() |
598 | 831 |
599 print(f" Class labels found: {list(prob_columns)}") | 832 print(f" Class labels found: {list(prob_columns)}") |
600 print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}") | 833 print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}") |
601 | 834 |
602 # If most probabilities are NaN, this is likely a regression problem | 835 # If most probabilities are NaN, this is likely a regression problem |
603 if non_na_probs < len(target_labels) * 0.1: # Less than 10% valid probabilities | 836 if non_na_probs < len(target_labels) * 0.1: # Less than 10% valid probabilities |
604 print(" Detected regression problem - precision-recall curves not applicable") | 837 print(" Detected regression problem - ROC curves not applicable") |
605 print(f" Skipping '{target_value}' (use regression evaluation metrics instead)") | 838 print(f" Skipping '{target_value}' (use regression evaluation metrics instead)") |
606 continue | 839 continue |
607 | 840 |
608 # Debug: Check data quality | 841 # Debug: Check data quality |
609 total_rows = len(target_labels) | 842 total_rows = len(target_labels) |
700 print(f" Unmapped labels: {unmapped_labels}") | 933 print(f" Unmapped labels: {unmapped_labels}") |
701 print(f" Available classes: {class_names}") | 934 print(f" Available classes: {class_names}") |
702 continue | 935 continue |
703 | 936 |
704 try: | 937 try: |
705 print(f" Generating precision-recall curve for '{target_value}'...") | 938 print(f" Generating ROC curve for '{target_value}'...") |
706 fig = plot_pr_curves(y_true_np, y_probs_np) | 939 fig = plot_roc_curves(y_true_np, y_probs_np) |
707 | 940 |
708 # Create output filename with target value | 941 # Create output filename with target value |
709 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') | 942 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') |
710 if len(target_values) > 1: | 943 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" |
711 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" | |
712 else: | |
713 output_filename = f"{output_name_base}.{args.format}" | |
714 | |
715 output_path = output_dir / output_filename | |
716 print(f" Saving precision-recall curve to: {output_path.absolute()}") | |
717 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') | |
718 | |
719 except Exception as e: | |
720 print(f" Error generating precision-recall curve for '{target_value}': {str(e)}") | |
721 print(f" Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}") | |
722 print(f" Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}") | |
723 continue | |
724 | |
725 print("Precision-recall curves generated successfully!") | |
726 | |
727 | |
728 def generate_roc_curves(labels, args, output_dir, output_name_base): | |
729 """Generate ROC curves""" | |
730 print("Generating ROC curves...") | |
731 | |
732 # Parse target values from comma-separated string | |
733 if args.target_value: | |
734 target_values = [val.strip() for val in args.target_value.split(',')] | |
735 else: | |
736 # If no target values specified, use all unique variables | |
737 target_values = labels['variable'].unique().tolist() | |
738 | |
739 print(f"Processing target values: {target_values}") | |
740 | |
741 for target_value in target_values: | |
742 print(f"\nProcessing target value: '{target_value}'") | |
743 | |
744 # Filter labels for the current target value | |
745 target_labels = labels[labels['variable'] == target_value] | |
746 | |
747 # Check if this is a regression problem (no class probabilities) | |
748 prob_columns = target_labels['class_label'].unique() | |
749 non_na_probs = target_labels['probability'].notna().sum() | |
750 | |
751 print(f" Class labels found: {list(prob_columns)}") | |
752 print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}") | |
753 | |
754 # If most probabilities are NaN, this is likely a regression problem | |
755 if non_na_probs < len(target_labels) * 0.1: # Less than 10% valid probabilities | |
756 print(" Detected regression problem - ROC curves not applicable") | |
757 print(f" Skipping '{target_value}' (use regression evaluation metrics instead)") | |
758 continue | |
759 | |
760 # Debug: Check data quality | |
761 total_rows = len(target_labels) | |
762 missing_labels = target_labels['known_label'].isna().sum() | |
763 missing_probs = target_labels['probability'].isna().sum() | |
764 unique_samples = target_labels['sample_id'].nunique() | |
765 unique_classes = target_labels['class_label'].nunique() | |
766 | |
767 print(f" Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes") | |
768 print(f" Missing data: {missing_labels} missing known_label, {missing_probs} missing probability") | |
769 | |
770 if missing_labels > 0: | |
771 print(f" Warning: Found {missing_labels} missing known_label values") | |
772 missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5] | |
773 print(f" Sample IDs with missing known_label: {list(missing_samples)}") | |
774 | |
775 # Remove rows with missing known_label | |
776 target_labels = target_labels.dropna(subset=['known_label']) | |
777 if target_labels.empty: | |
778 print(f" Error: No valid known_label data remaining for '{target_value}' - skipping") | |
779 continue | |
780 | |
781 # 1. Pivot to wide format | |
782 prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability') | |
783 | |
784 print(f" After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes") | |
785 print(f" Class columns: {list(prob_df.columns)}") | |
786 | |
787 # Check for NaN values in probability data | |
788 nan_counts = prob_df.isna().sum() | |
789 if nan_counts.any(): | |
790 print(f" NaN counts per class: {dict(nan_counts)}") | |
791 print(f" Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}") | |
792 | |
793 # Drop only rows where ALL probabilities are NaN | |
794 all_nan_rows = prob_df.isna().all(axis=1) | |
795 if all_nan_rows.any(): | |
796 print(f" Dropping {all_nan_rows.sum()} samples with all NaN probabilities") | |
797 prob_df = prob_df[~all_nan_rows] | |
798 | |
799 remaining_nans = prob_df.isna().sum().sum() | |
800 if remaining_nans > 0: | |
801 print(f" Warning: {remaining_nans} individual NaN values remain - filling with 0") | |
802 prob_df = prob_df.fillna(0) | |
803 | |
804 if prob_df.empty: | |
805 print(f" Error: No valid probability data remaining for '{target_value}' - skipping") | |
806 continue | |
807 | |
808 # 2. Get true labels | |
809 true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id') | |
810 | |
811 # 3. Align indices - only keep samples that exist in both datasets | |
812 common_indices = prob_df.index.intersection(true_labels_df.index) | |
813 if len(common_indices) == 0: | |
814 print(f" Error: No common sample_ids between probability and true label data for '{target_value}' - skipping") | |
815 continue | |
816 | |
817 print(f" Found {len(common_indices)} samples with both probability and true label data") | |
818 | |
819 # Filter both datasets to common indices | |
820 prob_df_aligned = prob_df.loc[common_indices] | |
821 y_true = true_labels_df.loc[common_indices]['known_label'] | |
822 | |
823 # 4. Final check for NaN values | |
824 if y_true.isna().any(): | |
825 print(f" Error: True labels still contain NaN after alignment for '{target_value}' - skipping") | |
826 continue | |
827 | |
828 if prob_df_aligned.isna().any().any(): | |
829 print(f" Error: Probability data still contains NaN after alignment for '{target_value}' - skipping") | |
830 continue | |
831 | |
832 # 5. Convert categorical labels to integer labels | |
833 # Create a mapping from class names to integers | |
834 class_names = list(prob_df_aligned.columns) | |
835 class_to_int = {class_name: i for i, class_name in enumerate(class_names)} | |
836 | |
837 print(f" Class mapping: {class_to_int}") | |
838 | |
839 # Convert true labels to integers | |
840 y_true_np = y_true.map(class_to_int).to_numpy() | |
841 y_probs_np = prob_df_aligned.to_numpy() | |
842 | |
843 print(f" Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}") | |
844 print(f" Unique true labels (integers): {set(y_true_np)}") | |
845 print(f" Class labels (columns): {class_names}") | |
846 print(f" Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}") | |
847 | |
848 # Check for any unmapped labels (will be NaN) | |
849 if pd.isna(y_true_np).any(): | |
850 print(" Error: Some true labels could not be mapped to class columns") | |
851 unmapped_labels = set(y_true[y_true.map(class_to_int).isna()]) | |
852 print(f" Unmapped labels: {unmapped_labels}") | |
853 print(f" Available classes: {class_names}") | |
854 continue | |
855 | |
856 try: | |
857 print(f" Generating ROC curve for '{target_value}'...") | |
858 fig = plot_roc_curves(y_true_np, y_probs_np) | |
859 | |
860 # Create output filename with target value | |
861 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') | |
862 if len(target_values) > 1: | |
863 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" | |
864 else: | |
865 output_filename = f"{output_name_base}.{args.format}" | |
866 | 944 |
867 output_path = output_dir / output_filename | 945 output_path = output_dir / output_filename |
868 print(f" Saving ROC curve to: {output_path.absolute()}") | 946 print(f" Saving ROC curve to: {output_path.absolute()}") |
869 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') | 947 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') |
870 | 948 |
877 print("ROC curves generated successfully!") | 955 print("ROC curves generated successfully!") |
878 | 956 |
879 | 957 |
880 def generate_box_plots(labels, args, output_dir, output_name_base): | 958 def generate_box_plots(labels, args, output_dir, output_name_base): |
881 """Generate box plots for model predictions""" | 959 """Generate box plots for model predictions""" |
960 | |
961 # Check if this is the specific format with sample_id, known_label, predicted_label | |
962 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] | |
963 is_flexynesis_format = all(col in labels.columns for col in required_cols) | |
964 | |
965 if not is_flexynesis_format: | |
966 raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.") | |
882 | 967 |
883 print("Generating box plots...") | 968 print("Generating box plots...") |
884 | 969 |
885 # Parse target values from comma-separated string | 970 # Parse target values from comma-separated string |
886 if args.target_value: | 971 if args.target_value: |
991 # Arguments for dimensionality reduction | 1076 # Arguments for dimensionality reduction |
992 parser.add_argument("--embeddings", type=str, | 1077 parser.add_argument("--embeddings", type=str, |
993 help="Path to input data embeddings file (CSV or tabular format). Required for dimred plots.") | 1078 help="Path to input data embeddings file (CSV or tabular format). Required for dimred plots.") |
994 parser.add_argument("--method", type=str, default='pca', choices=['pca', 'umap'], | 1079 parser.add_argument("--method", type=str, default='pca', choices=['pca', 'umap'], |
995 help="Transformation method ('pca' or 'umap'). Default is 'pca'. Used for dimred plots.") | 1080 help="Transformation method ('pca' or 'umap'). Default is 'pca'. Used for dimred plots.") |
1081 parser.add_argument("--color", type=str, default=None, | |
1082 help="User-defined color for the plot.") | |
996 | 1083 |
997 # Arguments for Kaplan-Meier | 1084 # Arguments for Kaplan-Meier |
998 parser.add_argument("--survival_data", type=str, | 1085 parser.add_argument("--survival_data", type=str, |
999 help="Path to survival data file with columns: duration and event. Required for kaplan_meier plots.") | 1086 help="Path to survival data file with columns: duration and event. Required for kaplan_meier plots.") |
1000 parser.add_argument("--surv_time_var", type=str, required=False, | 1087 parser.add_argument("--surv_time_var", type=str, required=False, |
1001 help="Column name for survival time") | 1088 help="Column name for survival time") |
1002 parser.add_argument("--surv_event_var", type=str, required=False, | 1089 parser.add_argument("--surv_event_var", type=str, required=False, |
1003 help="Column name for survival event") | 1090 help="Column name for survival event") |
1004 parser.add_argument("--event_value", type=str, required=False, | |
1005 help="Value in event column that represents an event (e.g., 'DECEASED')") | |
1006 | 1091 |
1007 # Arguments for Cox analysis | 1092 # Arguments for Cox analysis |
1008 parser.add_argument("--model", type=str, | 1093 parser.add_argument("--important_features", type=str, |
1009 help="Path to trained flexynesis model (pickle file). Required for cox plots.") | 1094 help="Path to calculated feature importance file. Required for cox plots.") |
1010 parser.add_argument("--clinical_train", type=str, | 1095 parser.add_argument("--clinical_train", type=str, |
1011 help="Path to training dataset (pickle file). Required for cox plots.") | 1096 help="Path to training dataset (pickle file). Required for cox plots.") |
1012 parser.add_argument("--clinical_test", type=str, | 1097 parser.add_argument("--clinical_test", type=str, |
1013 help="Path to test dataset (pickle file). Required for cox plots.") | 1098 help="Path to test dataset (pickle file). Required for cox plots.") |
1014 parser.add_argument("--omics_train", type=str, default=None, | 1099 parser.add_argument("--omics_train", type=str, default=None, |
1023 help="If True, performs K-fold cross-validation and returns average C-index. Default is False") | 1108 help="If True, performs K-fold cross-validation and returns average C-index. Default is False") |
1024 parser.add_argument("--n_splits", type=int, default=5, | 1109 parser.add_argument("--n_splits", type=int, default=5, |
1025 help="Number of folds for cross-validation. Default is 5") | 1110 help="Number of folds for cross-validation. Default is 5") |
1026 parser.add_argument("--random_state", type=int, default=42, | 1111 parser.add_argument("--random_state", type=int, default=42, |
1027 help="Random seed for reproducibility. Default is 42") | 1112 help="Random seed for reproducibility. Default is 42") |
1113 parser.add_argument("--layer", type=str, default=None, | |
1114 help="Class label for filtering important features.") | |
1028 | 1115 |
1029 # Arguments for dimred, scatter plot, heatmap, PR curves, ROC curves, and box plots | 1116 # Arguments for dimred, scatter plot, heatmap, PR curves, ROC curves, and box plots |
1030 parser.add_argument("--target_value", type=str, default=None, | 1117 parser.add_argument("--target_value", type=str, default=None, |
1031 help="Target value for scatter plot.") | 1118 help="Target value for scatter plot.") |
1032 | 1119 |
1120 # Arguments for scatter plots and concordance heatmaps | |
1121 parser.add_argument("--true_label", type=str, default=None, | |
1122 help="Column name for true labels in scatter plots and concordance heatmaps.") | |
1123 parser.add_argument("--predicted_label", type=str, default=None, | |
1124 help="Column name for predicted labels in scatter plots and concordance heatmaps.") | |
1033 # Common arguments | 1125 # Common arguments |
1034 parser.add_argument("--output_dir", type=str, default='output', | 1126 parser.add_argument("--output_dir", type=str, default='output', |
1035 help="Output directory. Default is 'output'") | 1127 help="Output directory. Default is 'output'") |
1036 parser.add_argument("--output_name", type=str, default=None, | 1128 parser.add_argument("--output_name", type=str, default=None, |
1037 help="Output filename base") | 1129 help="Output filename base") |
1071 raise ValueError("--method is required for dimensionality reduction plots") | 1163 raise ValueError("--method is required for dimensionality reduction plots") |
1072 if not args.surv_time_var: | 1164 if not args.surv_time_var: |
1073 raise ValueError("--surv_time_var is required for Kaplan-Meier plots") | 1165 raise ValueError("--surv_time_var is required for Kaplan-Meier plots") |
1074 if not args.surv_event_var: | 1166 if not args.surv_event_var: |
1075 raise ValueError("--surv_event_var is required for Kaplan-Meier plots") | 1167 raise ValueError("--surv_event_var is required for Kaplan-Meier plots") |
1076 if not args.event_value: | |
1077 raise ValueError("--event_value is required for Kaplan-Meier plots") | |
1078 | 1168 |
1079 if args.plot_type in ['cox']: | 1169 if args.plot_type in ['cox']: |
1080 if not args.model: | 1170 if not args.important_features: |
1081 raise ValueError("--model is required when plot_type is 'cox'") | 1171 raise ValueError("--important_features is required when plot_type is 'cox'") |
1082 if not os.path.isfile(args.model): | 1172 if not os.path.isfile(args.important_features): |
1083 raise FileNotFoundError(f"Model file not found: {args.model}") | 1173 raise FileNotFoundError(f"Important features file not found: {args.important_features}") |
1084 if not args.clinical_train: | 1174 if not args.clinical_train: |
1085 raise ValueError("--clinical_train is required when plot_type is 'cox'") | 1175 raise ValueError("--clinical_train is required when plot_type is 'cox'") |
1086 if not os.path.isfile(args.clinical_train): | 1176 if not os.path.isfile(args.clinical_train): |
1087 raise FileNotFoundError(f"Training dataset file not found: {args.clinical_train}") | 1177 raise FileNotFoundError(f"Training dataset file not found: {args.clinical_train}") |
1088 if not args.clinical_test: | 1178 if not args.clinical_test: |
1100 if not args.surv_time_var: | 1190 if not args.surv_time_var: |
1101 raise ValueError("--surv_time_var is required for Cox plots") | 1191 raise ValueError("--surv_time_var is required for Cox plots") |
1102 if not args.surv_event_var: | 1192 if not args.surv_event_var: |
1103 raise ValueError("--surv_event_var is required for Cox plots") | 1193 raise ValueError("--surv_event_var is required for Cox plots") |
1104 if not args.clinical_variables: | 1194 if not args.clinical_variables: |
1105 raise ValueError("--clinical_variables is required for Cox plots") | 1195 print("--clinical_variables is not set for Cox plots") |
1106 if not isinstance(args.top_features, int) or args.top_features <= 0: | 1196 if not isinstance(args.top_features, int) or args.top_features <= 0: |
1107 raise ValueError("--top_features must be a positive integer") | 1197 raise ValueError("--top_features must be a positive integer") |
1108 if not args.event_value: | |
1109 raise ValueError("--event_value is required for Kaplan-Meier plots") | |
1110 if not args.crossval: | 1198 if not args.crossval: |
1111 args.crossval = False | 1199 args.crossval = False |
1112 if not isinstance(args.n_splits, int) or args.n_splits <= 0: | 1200 if not isinstance(args.n_splits, int) or args.n_splits <= 0: |
1113 raise ValueError("--n_splits must be a positive integer") | 1201 raise ValueError("--n_splits must be a positive integer") |
1114 if not isinstance(args.random_state, int): | 1202 if not isinstance(args.random_state, int): |
1115 raise ValueError("--random_state must be an integer") | 1203 raise ValueError("--random_state must be an integer") |
1204 if not args.layer: | |
1205 print("--layer is not specified, using all classes from labels") | |
1116 | 1206 |
1117 if args.plot_type in ['scatter']: | 1207 if args.plot_type in ['scatter']: |
1118 if not args.labels: | 1208 if not args.labels: |
1119 raise ValueError("--labels is required for scatter plots") | 1209 raise ValueError("--labels is required for scatter plots") |
1120 if not args.target_value: | 1210 if not args.target_value: |
1172 output_name_base = f"{embeddings_name}_{args.method}" | 1262 output_name_base = f"{embeddings_name}_{args.method}" |
1173 elif args.plot_type == 'kaplan_meier': | 1263 elif args.plot_type == 'kaplan_meier': |
1174 survival_name = Path(args.survival_data).stem | 1264 survival_name = Path(args.survival_data).stem |
1175 output_name_base = f"{survival_name}_km" | 1265 output_name_base = f"{survival_name}_km" |
1176 elif args.plot_type == 'cox': | 1266 elif args.plot_type == 'cox': |
1177 model_name = Path(args.model).stem | 1267 model_name = Path(args.important_features).stem |
1178 output_name_base = f"{model_name}_cox" | 1268 output_name_base = f"{model_name}_cox" |
1179 elif args.plot_type == 'scatter': | 1269 elif args.plot_type == 'scatter': |
1180 labels_name = Path(args.labels).stem | 1270 labels_name = Path(args.labels).stem |
1181 output_name_base = f"{labels_name}_scatter" | 1271 output_name_base = f"{labels_name}_scatter" |
1182 elif args.plot_type == 'concordance_heatmap': | 1272 elif args.plot_type == 'concordance_heatmap': |
1194 | 1284 |
1195 # Generate plots based on type | 1285 # Generate plots based on type |
1196 if args.plot_type in ['dimred']: | 1286 if args.plot_type in ['dimred']: |
1197 # Load labels | 1287 # Load labels |
1198 print(f"Loading labels from: {args.labels}") | 1288 print(f"Loading labels from: {args.labels}") |
1199 label_data = load_labels(args.labels) | 1289 labels = load_labels(args.labels) |
1200 # Load embeddings data | 1290 # Load embeddings data |
1201 print(f"Loading embeddings from: {args.embeddings}") | 1291 print(f"Loading embeddings from: {args.embeddings}") |
1202 embeddings, sample_names = load_embeddings(args.embeddings) | 1292 embeddings, sample_names = load_embeddings(args.embeddings) |
1203 print(f"embeddings shape: {embeddings.shape}") | 1293 print(f"embeddings shape: {embeddings.shape}") |
1204 | 1294 |
1205 # Match samples to embeddings | 1295 # Match samples to embeddings |
1206 matched_labels = match_samples_to_embeddings(sample_names, label_data) | 1296 matched_labels = match_samples_to_embeddings(sample_names, labels) |
1207 print(f"Successfully matched {len(matched_labels)} samples for dimensionality reduction") | 1297 print(f"Successfully matched {len(matched_labels)} samples for dimensionality reduction") |
1208 | 1298 print(f"Matched labels shape: {matched_labels.shape}") |
1299 print(f"Columns in matched labels: {matched_labels.columns.tolist()}") | |
1209 generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base) | 1300 generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base) |
1210 | 1301 |
1211 elif args.plot_type in ['kaplan_meier']: | 1302 elif args.plot_type in ['kaplan_meier']: |
1212 # Load labels | 1303 # Load labels |
1213 print(f"Loading labels from: {args.labels}") | 1304 print(f"Loading labels from: {args.labels}") |
1214 label_data = load_labels(args.labels) | 1305 labels = load_labels(args.labels) |
1215 # Load survival data | 1306 # Load survival data |
1216 print(f"Loading survival data from: {args.survival_data}") | 1307 print(f"Loading survival data from: {args.survival_data}") |
1217 survival_data = load_survival_data(args.survival_data) | 1308 survival_data = load_labels(args.survival_data) |
1218 print(f"Survival data shape: {survival_data.shape}") | 1309 print(f"Survival data shape: {survival_data.shape}") |
1219 | 1310 |
1220 generate_km_plots(survival_data, label_data, args, output_dir, output_name_base) | 1311 generate_km_plots(survival_data, labels, args, output_dir, output_name_base) |
1221 | 1312 |
1222 elif args.plot_type in ['cox']: | 1313 elif args.plot_type in ['cox']: |
1223 # Load model and datasets | 1314 # Load important_features and datasets |
1224 print(f"Loading model from: {args.model}") | 1315 print(f"Loading important features from: {args.important_features}") |
1225 model = load_model(args.model) | 1316 important_features = load_labels(args.important_features) |
1226 print(f"Loading training dataset from: {args.clinical_train}") | 1317 print(f"Loading training dataset from: {args.clinical_train}") |
1227 clinical_train = load_omics(args.clinical_train) | 1318 clinical_train = load_omics(args.clinical_train) |
1228 print(f"Loading test dataset from: {args.clinical_test}") | 1319 print(f"Loading test dataset from: {args.clinical_test}") |
1229 clinical_test = load_omics(args.clinical_test) | 1320 clinical_test = load_omics(args.clinical_test) |
1230 print(f"Loading training omics dataset from: {args.omics_train}") | 1321 print(f"Loading training omics dataset from: {args.omics_train}") |
1231 omics_train = load_omics(args.omics_train) | 1322 omics_train = load_omics(args.omics_train) |
1232 print(f"Loading test omics dataset from: {args.omics_test}") | 1323 print(f"Loading test omics dataset from: {args.omics_test}") |
1233 omics_test = load_omics(args.omics_test) | 1324 omics_test = load_omics(args.omics_test) |
1234 | 1325 |
1235 generate_cox_plots(model, clinical_train, clinical_test, omics_test, omics_train, args, output_dir, output_name_base) | 1326 generate_cox_plots(important_features, clinical_train, clinical_test, omics_test, omics_train, args, output_dir, output_name_base) |
1236 | 1327 |
1237 elif args.plot_type in ['scatter']: | 1328 elif args.plot_type in ['scatter']: |
1238 # Load labels | 1329 # Load labels |
1239 print(f"Loading labels from: {args.labels}") | 1330 print(f"Loading labels from: {args.labels}") |
1240 label_data = load_labels(args.labels) | 1331 labels = load_labels(args.labels) |
1241 | 1332 |
1242 generate_plot_scatter(label_data, args, output_dir, output_name_base) | 1333 generate_plot_scatter(labels, args, output_dir, output_name_base) |
1243 | 1334 |
1244 elif args.plot_type in ['concordance_heatmap']: | 1335 elif args.plot_type in ['concordance_heatmap']: |
1245 # Load labels | 1336 # Load labels |
1246 print(f"Loading labels from: {args.labels}") | 1337 print(f"Loading labels from: {args.labels}") |
1247 label_data = load_labels(args.labels) | 1338 labels = load_labels(args.labels) |
1248 | 1339 |
1249 generate_label_concordance_heatmap(label_data, args, output_dir, output_name_base) | 1340 generate_label_concordance_heatmap(labels, args, output_dir, output_name_base) |
1250 | 1341 |
1251 elif args.plot_type in ['pr_curve']: | 1342 elif args.plot_type in ['pr_curve']: |
1252 # Load labels | 1343 # Load labels |
1253 print(f"Loading labels from: {args.labels}") | 1344 print(f"Loading labels from: {args.labels}") |
1254 label_data = load_labels(args.labels) | 1345 labels = load_labels(args.labels) |
1255 | 1346 |
1256 generate_pr_curves(label_data, args, output_dir, output_name_base) | 1347 generate_pr_curves(labels, args, output_dir, output_name_base) |
1257 | 1348 |
1258 elif args.plot_type in ['roc_curve']: | 1349 elif args.plot_type in ['roc_curve']: |
1259 # Load labels | 1350 # Load labels |
1260 print(f"Loading labels from: {args.labels}") | 1351 print(f"Loading labels from: {args.labels}") |
1261 label_data = load_labels(args.labels) | 1352 labels = load_labels(args.labels) |
1262 | 1353 |
1263 generate_roc_curves(label_data, args, output_dir, output_name_base) | 1354 generate_roc_curves(labels, args, output_dir, output_name_base) |
1264 | 1355 |
1265 elif args.plot_type in ['box_plot']: | 1356 elif args.plot_type in ['box_plot']: |
1266 # Load labels | 1357 # Load labels |
1267 print(f"Loading labels from: {args.labels}") | 1358 print(f"Loading labels from: {args.labels}") |
1268 label_data = load_labels(args.labels) | 1359 labels = load_labels(args.labels) |
1269 | 1360 |
1270 generate_box_plots(label_data, args, output_dir, output_name_base) | 1361 generate_box_plots(labels, args, output_dir, output_name_base) |
1271 | 1362 |
1272 print("All plots generated successfully!") | 1363 print("All plots generated successfully!") |
1273 | 1364 |
1274 except (FileNotFoundError, ValueError, pd.errors.ParserError) as e: | 1365 except (FileNotFoundError, ValueError, pd.errors.ParserError) as e: |
1275 print(f"Error: {e}") | 1366 print(f"Error: {e}") |