comparison crop_image.py @ 1:457514bb6750 draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/crop_image/ commit 52a95105291e38f3410e347ed3b60d6acd6d5daa
author imgteam
date Fri, 09 Jan 2026 14:54:49 +0000
parents f8bfa85cac4c
children
comparison
equal deleted inserted replaced
0:f8bfa85cac4c 1:457514bb6750
1 import argparse 1 import argparse
2 import os 2 import os
3 3
4 import dask.array as da
5 import giatools
6 import giatools.image
4 import numpy as np 7 import numpy as np
5 from giatools.image import Image 8
9 # Fail early if an optional backend is not available
10 giatools.require_backend('omezarr')
6 11
7 12
8 def crop_image( 13 def crop_image(
9 image_filepath: str, 14 image_filepath: str,
10 labelmap_filepath: str, 15 labelmap_filepath: str,
11 output_ext: str, 16 output_ext: str,
12 output_dir: str, 17 output_dir: str,
13 skip_labels: frozenset[int], 18 skip_labels: frozenset[int],
14 ): 19 ):
15 image = Image.read(image_filepath) 20 axes = giatools.default_normalized_axes
16 labelmap = Image.read(labelmap_filepath) 21 image = giatools.Image.read(image_filepath, normalize_axes=axes)
22 labelmap = giatools.Image.read(labelmap_filepath, normalize_axes=axes)
17 23
18 if image.axes != labelmap.axes: 24 # Establish compatibility of multi-channel/frame/etc. images with single-channel/frame/etc. label maps
19 raise ValueError(f'Axes mismatch between image ({image.axes}) and label map ({labelmap.axes}).') 25 original_labelmap_shape = labelmap.shape
26 for image_s, labelmap_s, (axis_idx, axis) in zip(image.shape, labelmap.shape, enumerate(axes)):
27 if image_s > 1 and labelmap_s == 1 and axis not in 'YX':
28 target_shape = list(labelmap.shape)
29 target_shape[axis_idx] = image_s
20 30
21 if image.data.shape != labelmap.data.shape: 31 # Broadcast the labelmap data to the target shape without copying
22 raise ValueError(f'Shape mismatch between image ({image.data.shape}) and label map ({labelmap.data.shape}).') 32 if hasattr(labelmap.data, 'compute'):
33 labelmap.data = da.broadcast_to(labelmap.data, target_shape) # `data` is Dask array
34 else:
35 labelmap.data = np.broadcast_to(labelmap.data, target_shape, subok=True) # `data` is NumPy array
23 36
24 for label in np.unique(labelmap.data): 37 # Validate that the shapes of the images are compatible
38 if image.shape != labelmap.shape:
39 labelmap_shape_str = str(original_labelmap_shape)
40 if labelmap.shape != original_labelmap_shape:
41 labelmap_shape_str = f'{labelmap_shape_str}, broadcasted to {labelmap.shape}'
42 raise ValueError(
43 f'Shape mismatch between image {image.shape} and label map {labelmap_shape_str}, with {axes} axes.',
44 )
45
46 # Extract the image crops
47 for label in giatools.image._unique(labelmap.data):
25 if label in skip_labels: 48 if label in skip_labels:
26 continue 49 continue
27 roi_mask = (labelmap.data == label) 50 roi_mask = (labelmap.data == label)
28 roi = crop_image_to_mask(image.data, roi_mask) 51 roi = crop_image_to_mask(image.data, roi_mask)
29 roi_image = Image(roi, image.axes).normalize_axes_like(image.original_axes) 52 roi_image = giatools.Image(roi, image.axes).normalize_axes_like(image.original_axes)
30 roi_image.write(os.path.join(output_dir, f'{label}.{output_ext}')) 53 roi_image.write(os.path.join(output_dir, f'{label}.{output_ext}'))
31 54
32 55
33 def crop_image_to_mask(data: np.ndarray, mask: np.ndarray) -> np.ndarray: 56 def crop_image_to_mask(data: np.ndarray, mask: np.ndarray) -> np.ndarray:
34 """ 57 """
40 63
41 # Crop `data` to the convex hull of the mask in each dimension 64 # Crop `data` to the convex hull of the mask in each dimension
42 for dim in range(data.ndim): 65 for dim in range(data.ndim):
43 mask1d = mask.any(axis=tuple(i for i in range(mask.ndim) if i != dim)) 66 mask1d = mask.any(axis=tuple(i for i in range(mask.ndim) if i != dim))
44 mask1d_indices = np.where(mask1d)[0] 67 mask1d_indices = np.where(mask1d)[0]
68
69 # Convert `mask1d_indices` to a NumPy array if it is a Dask array
70 if hasattr(mask1d_indices, 'compute'):
71 mask1d_indices = mask1d_indices.compute()
72
45 mask1d_indices_cvxhull = np.arange(min(mask1d_indices), max(mask1d_indices) + 1) 73 mask1d_indices_cvxhull = np.arange(min(mask1d_indices), max(mask1d_indices) + 1)
46 data = data.take(axis=dim, indices=mask1d_indices_cvxhull) 74
75 # Crop the `data` to the minimal bounding box
76 if hasattr(data, 'compute'):
77 data = da.take(data, axis=dim, indices=mask1d_indices_cvxhull) # `data` is a Dask array
78 else:
79 data = data.take(axis=dim, indices=mask1d_indices_cvxhull) # `data` is a NumPy array
47 80
48 return data 81 return data
49 82
50 83
51 if __name__ == "__main__": 84 if __name__ == "__main__":