Mercurial > repos > imgteam > crop_image
diff crop_image.py @ 1:457514bb6750 draft default tip
planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/crop_image/ commit 52a95105291e38f3410e347ed3b60d6acd6d5daa
| author | imgteam |
|---|---|
| date | Fri, 09 Jan 2026 14:54:49 +0000 |
| parents | f8bfa85cac4c |
| children |
line wrap: on
line diff
--- a/crop_image.py Fri Jun 06 12:46:50 2025 +0000 +++ b/crop_image.py Fri Jan 09 14:54:49 2026 +0000 @@ -1,8 +1,13 @@ import argparse import os +import dask.array as da +import giatools +import giatools.image import numpy as np -from giatools.image import Image + +# Fail early if an optional backend is not available +giatools.require_backend('omezarr') def crop_image( @@ -12,21 +17,39 @@ output_dir: str, skip_labels: frozenset[int], ): - image = Image.read(image_filepath) - labelmap = Image.read(labelmap_filepath) + axes = giatools.default_normalized_axes + image = giatools.Image.read(image_filepath, normalize_axes=axes) + labelmap = giatools.Image.read(labelmap_filepath, normalize_axes=axes) + + # Establish compatibility of multi-channel/frame/etc. images with single-channel/frame/etc. label maps + original_labelmap_shape = labelmap.shape + for image_s, labelmap_s, (axis_idx, axis) in zip(image.shape, labelmap.shape, enumerate(axes)): + if image_s > 1 and labelmap_s == 1 and axis not in 'YX': + target_shape = list(labelmap.shape) + target_shape[axis_idx] = image_s - if image.axes != labelmap.axes: - raise ValueError(f'Axes mismatch between image ({image.axes}) and label map ({labelmap.axes}).') + # Broadcast the labelmap data to the target shape without copying + if hasattr(labelmap.data, 'compute'): + labelmap.data = da.broadcast_to(labelmap.data, target_shape) # `data` is Dask array + else: + labelmap.data = np.broadcast_to(labelmap.data, target_shape, subok=True) # `data` is NumPy array - if image.data.shape != labelmap.data.shape: - raise ValueError(f'Shape mismatch between image ({image.data.shape}) and label map ({labelmap.data.shape}).') + # Validate that the shapes of the images are compatible + if image.shape != labelmap.shape: + labelmap_shape_str = str(original_labelmap_shape) + if labelmap.shape != original_labelmap_shape: + labelmap_shape_str = f'{labelmap_shape_str}, broadcasted to {labelmap.shape}' + raise ValueError( + f'Shape mismatch between image {image.shape} and label map {labelmap_shape_str}, with {axes} axes.', + ) - for label in np.unique(labelmap.data): + # Extract the image crops + for label in giatools.image._unique(labelmap.data): if label in skip_labels: continue roi_mask = (labelmap.data == label) roi = crop_image_to_mask(image.data, roi_mask) - roi_image = Image(roi, image.axes).normalize_axes_like(image.original_axes) + roi_image = giatools.Image(roi, image.axes).normalize_axes_like(image.original_axes) roi_image.write(os.path.join(output_dir, f'{label}.{output_ext}')) @@ -42,8 +65,18 @@ for dim in range(data.ndim): mask1d = mask.any(axis=tuple(i for i in range(mask.ndim) if i != dim)) mask1d_indices = np.where(mask1d)[0] + + # Convert `mask1d_indices` to a NumPy array if it is a Dask array + if hasattr(mask1d_indices, 'compute'): + mask1d_indices = mask1d_indices.compute() + mask1d_indices_cvxhull = np.arange(min(mask1d_indices), max(mask1d_indices) + 1) - data = data.take(axis=dim, indices=mask1d_indices_cvxhull) + + # Crop the `data` to the minimal bounding box + if hasattr(data, 'compute'): + data = da.take(data, axis=dim, indices=mask1d_indices_cvxhull) # `data` is a Dask array + else: + data = data.take(axis=dim, indices=mask1d_indices_cvxhull) # `data` is a NumPy array return data
