diff crop_image.py @ 1:457514bb6750 draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/crop_image/ commit 52a95105291e38f3410e347ed3b60d6acd6d5daa
author imgteam
date Fri, 09 Jan 2026 14:54:49 +0000
parents f8bfa85cac4c
children
line wrap: on
line diff
--- a/crop_image.py	Fri Jun 06 12:46:50 2025 +0000
+++ b/crop_image.py	Fri Jan 09 14:54:49 2026 +0000
@@ -1,8 +1,13 @@
 import argparse
 import os
 
+import dask.array as da
+import giatools
+import giatools.image
 import numpy as np
-from giatools.image import Image
+
+# Fail early if an optional backend is not available
+giatools.require_backend('omezarr')
 
 
 def crop_image(
@@ -12,21 +17,39 @@
     output_dir: str,
     skip_labels: frozenset[int],
 ):
-    image = Image.read(image_filepath)
-    labelmap = Image.read(labelmap_filepath)
+    axes = giatools.default_normalized_axes
+    image = giatools.Image.read(image_filepath, normalize_axes=axes)
+    labelmap = giatools.Image.read(labelmap_filepath, normalize_axes=axes)
+
+    # Establish compatibility of multi-channel/frame/etc. images with single-channel/frame/etc. label maps
+    original_labelmap_shape = labelmap.shape
+    for image_s, labelmap_s, (axis_idx, axis) in zip(image.shape, labelmap.shape, enumerate(axes)):
+        if image_s > 1 and labelmap_s == 1 and axis not in 'YX':
+            target_shape = list(labelmap.shape)
+            target_shape[axis_idx] = image_s
 
-    if image.axes != labelmap.axes:
-        raise ValueError(f'Axes mismatch between image ({image.axes}) and label map ({labelmap.axes}).')
+            # Broadcast the labelmap data to the target shape without copying
+            if hasattr(labelmap.data, 'compute'):
+                labelmap.data = da.broadcast_to(labelmap.data, target_shape)  # `data` is Dask array
+            else:
+                labelmap.data = np.broadcast_to(labelmap.data, target_shape, subok=True)  # `data` is NumPy array
 
-    if image.data.shape != labelmap.data.shape:
-        raise ValueError(f'Shape mismatch between image ({image.data.shape}) and label map ({labelmap.data.shape}).')
+    # Validate that the shapes of the images are compatible
+    if image.shape != labelmap.shape:
+        labelmap_shape_str = str(original_labelmap_shape)
+        if labelmap.shape != original_labelmap_shape:
+            labelmap_shape_str = f'{labelmap_shape_str}, broadcasted to {labelmap.shape}'
+        raise ValueError(
+            f'Shape mismatch between image {image.shape} and label map {labelmap_shape_str}, with {axes} axes.',
+        )
 
-    for label in np.unique(labelmap.data):
+    # Extract the image crops
+    for label in giatools.image._unique(labelmap.data):
         if label in skip_labels:
             continue
         roi_mask = (labelmap.data == label)
         roi = crop_image_to_mask(image.data, roi_mask)
-        roi_image = Image(roi, image.axes).normalize_axes_like(image.original_axes)
+        roi_image = giatools.Image(roi, image.axes).normalize_axes_like(image.original_axes)
         roi_image.write(os.path.join(output_dir, f'{label}.{output_ext}'))
 
 
@@ -42,8 +65,18 @@
     for dim in range(data.ndim):
         mask1d = mask.any(axis=tuple(i for i in range(mask.ndim) if i != dim))
         mask1d_indices = np.where(mask1d)[0]
+
+        # Convert `mask1d_indices` to a NumPy array if it is a Dask array
+        if hasattr(mask1d_indices, 'compute'):
+            mask1d_indices = mask1d_indices.compute()
+
         mask1d_indices_cvxhull = np.arange(min(mask1d_indices), max(mask1d_indices) + 1)
-        data = data.take(axis=dim, indices=mask1d_indices_cvxhull)
+
+        # Crop the `data` to the minimal bounding box
+        if hasattr(data, 'compute'):
+            data = da.take(data, axis=dim, indices=mask1d_indices_cvxhull)  # `data` is a Dask array
+        else:
+            data = data.take(axis=dim, indices=mask1d_indices_cvxhull)  # `data` is a NumPy array
 
     return data