Mercurial > repos > imgteam > crop_image
comparison crop_image.py @ 1:457514bb6750 draft default tip
planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/crop_image/ commit 52a95105291e38f3410e347ed3b60d6acd6d5daa
| author | imgteam |
|---|---|
| date | Fri, 09 Jan 2026 14:54:49 +0000 |
| parents | f8bfa85cac4c |
| children | |
comparison
equal
deleted
inserted
replaced
| 0:f8bfa85cac4c | 1:457514bb6750 |
|---|---|
| 1 import argparse | 1 import argparse |
| 2 import os | 2 import os |
| 3 | 3 |
| 4 import dask.array as da | |
| 5 import giatools | |
| 6 import giatools.image | |
| 4 import numpy as np | 7 import numpy as np |
| 5 from giatools.image import Image | 8 |
| 9 # Fail early if an optional backend is not available | |
| 10 giatools.require_backend('omezarr') | |
| 6 | 11 |
| 7 | 12 |
| 8 def crop_image( | 13 def crop_image( |
| 9 image_filepath: str, | 14 image_filepath: str, |
| 10 labelmap_filepath: str, | 15 labelmap_filepath: str, |
| 11 output_ext: str, | 16 output_ext: str, |
| 12 output_dir: str, | 17 output_dir: str, |
| 13 skip_labels: frozenset[int], | 18 skip_labels: frozenset[int], |
| 14 ): | 19 ): |
| 15 image = Image.read(image_filepath) | 20 axes = giatools.default_normalized_axes |
| 16 labelmap = Image.read(labelmap_filepath) | 21 image = giatools.Image.read(image_filepath, normalize_axes=axes) |
| 22 labelmap = giatools.Image.read(labelmap_filepath, normalize_axes=axes) | |
| 17 | 23 |
| 18 if image.axes != labelmap.axes: | 24 # Establish compatibility of multi-channel/frame/etc. images with single-channel/frame/etc. label maps |
| 19 raise ValueError(f'Axes mismatch between image ({image.axes}) and label map ({labelmap.axes}).') | 25 original_labelmap_shape = labelmap.shape |
| 26 for image_s, labelmap_s, (axis_idx, axis) in zip(image.shape, labelmap.shape, enumerate(axes)): | |
| 27 if image_s > 1 and labelmap_s == 1 and axis not in 'YX': | |
| 28 target_shape = list(labelmap.shape) | |
| 29 target_shape[axis_idx] = image_s | |
| 20 | 30 |
| 21 if image.data.shape != labelmap.data.shape: | 31 # Broadcast the labelmap data to the target shape without copying |
| 22 raise ValueError(f'Shape mismatch between image ({image.data.shape}) and label map ({labelmap.data.shape}).') | 32 if hasattr(labelmap.data, 'compute'): |
| 33 labelmap.data = da.broadcast_to(labelmap.data, target_shape) # `data` is Dask array | |
| 34 else: | |
| 35 labelmap.data = np.broadcast_to(labelmap.data, target_shape, subok=True) # `data` is NumPy array | |
| 23 | 36 |
| 24 for label in np.unique(labelmap.data): | 37 # Validate that the shapes of the images are compatible |
| 38 if image.shape != labelmap.shape: | |
| 39 labelmap_shape_str = str(original_labelmap_shape) | |
| 40 if labelmap.shape != original_labelmap_shape: | |
| 41 labelmap_shape_str = f'{labelmap_shape_str}, broadcasted to {labelmap.shape}' | |
| 42 raise ValueError( | |
| 43 f'Shape mismatch between image {image.shape} and label map {labelmap_shape_str}, with {axes} axes.', | |
| 44 ) | |
| 45 | |
| 46 # Extract the image crops | |
| 47 for label in giatools.image._unique(labelmap.data): | |
| 25 if label in skip_labels: | 48 if label in skip_labels: |
| 26 continue | 49 continue |
| 27 roi_mask = (labelmap.data == label) | 50 roi_mask = (labelmap.data == label) |
| 28 roi = crop_image_to_mask(image.data, roi_mask) | 51 roi = crop_image_to_mask(image.data, roi_mask) |
| 29 roi_image = Image(roi, image.axes).normalize_axes_like(image.original_axes) | 52 roi_image = giatools.Image(roi, image.axes).normalize_axes_like(image.original_axes) |
| 30 roi_image.write(os.path.join(output_dir, f'{label}.{output_ext}')) | 53 roi_image.write(os.path.join(output_dir, f'{label}.{output_ext}')) |
| 31 | 54 |
| 32 | 55 |
| 33 def crop_image_to_mask(data: np.ndarray, mask: np.ndarray) -> np.ndarray: | 56 def crop_image_to_mask(data: np.ndarray, mask: np.ndarray) -> np.ndarray: |
| 34 """ | 57 """ |
| 40 | 63 |
| 41 # Crop `data` to the convex hull of the mask in each dimension | 64 # Crop `data` to the convex hull of the mask in each dimension |
| 42 for dim in range(data.ndim): | 65 for dim in range(data.ndim): |
| 43 mask1d = mask.any(axis=tuple(i for i in range(mask.ndim) if i != dim)) | 66 mask1d = mask.any(axis=tuple(i for i in range(mask.ndim) if i != dim)) |
| 44 mask1d_indices = np.where(mask1d)[0] | 67 mask1d_indices = np.where(mask1d)[0] |
| 68 | |
| 69 # Convert `mask1d_indices` to a NumPy array if it is a Dask array | |
| 70 if hasattr(mask1d_indices, 'compute'): | |
| 71 mask1d_indices = mask1d_indices.compute() | |
| 72 | |
| 45 mask1d_indices_cvxhull = np.arange(min(mask1d_indices), max(mask1d_indices) + 1) | 73 mask1d_indices_cvxhull = np.arange(min(mask1d_indices), max(mask1d_indices) + 1) |
| 46 data = data.take(axis=dim, indices=mask1d_indices_cvxhull) | 74 |
| 75 # Crop the `data` to the minimal bounding box | |
| 76 if hasattr(data, 'compute'): | |
| 77 data = da.take(data, axis=dim, indices=mask1d_indices_cvxhull) # `data` is a Dask array | |
| 78 else: | |
| 79 data = data.take(axis=dim, indices=mask1d_indices_cvxhull) # `data` is a NumPy array | |
| 47 | 80 |
| 48 return data | 81 return data |
| 49 | 82 |
| 50 | 83 |
| 51 if __name__ == "__main__": | 84 if __name__ == "__main__": |
