# HG changeset patch
# User bgruening
# Date 1740565648 0
# Node ID bc28236f407b52cd63b0225ada651e1bd98f3c28
# Parent 0c0de5546fe19b4e70130114d866488a9d751842
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/bioimaging commit e08711c242a340a1671dfca35f52d3724086e968
diff -r 0c0de5546fe1 -r bc28236f407b bioimage_inference.xml
--- a/bioimage_inference.xml Tue Oct 15 12:57:33 2024 +0000
+++ b/bioimage_inference.xml Wed Feb 26 10:27:28 2025 +0000
@@ -2,7 +2,7 @@
[XML markup stripped during extraction; the surviving run-together text "with PyTorch" / "2.4.1" is the tool description and version token, and the hunk bumps the version-suffix token from 0 to 1.]
@@ -30,12 +30,18 @@
--imaging_model '$input_imaging_model'
--image_file '$input_image_file'
--image_size '$input_image_input_size'
+ --image_axes '$input_image_input_axes'
]]>
[XML markup stripped during extraction; the hunk replaces one input definition line with six, evidently adding the parameter bound to $input_image_input_axes used in the command block above.]
@@ -46,15 +52,97 @@
[XML markup stripped during extraction; per the hunk header, 15 lines are replaced by 97, consistent with expanded parameter/output/test definitions (the test-data files below also changed).]
@@ -64,13 +152,14 @@
The tool takes a BioImage.IO model and an image (TIF or PNG) to be analyzed. The model predicts the result of the analysis, and the predicted image becomes available as a TIF file in the Galaxy history.
**Input files**
- - BioImage.IO model: Add one of the model from Galaxy file uploader by choosing a "remote" file at "ML Models/bioimaging-models"
- - Image to be analyzed: Provide an image as TIF/PNG file
- - Provide the necessary input size for the model. This information can be found in the RDF file of each model (RDF file > config > test_information > inputs > size)
+ - BioImage.IO model: add one of the models from the Galaxy file uploader by choosing a "remote" file at "ML Models/bioimaging-models"
+ - Image to be analyzed: provide an image as a TIF/PNG file
+ - Provide the necessary input size for the model. This information can be found in the RDF file of each model (RDF file > config > test_information > inputs > size)
+ - Provide the axes of the input image. This information can also be found in the RDF file of each model (RDF file > inputs > axes). For example, the axes value is 'bczyx' for the 3D U-Net Arabidopsis Lateral Root Primordia model (see the sketch after this section)
**Output files**
- - Predicted image: Predicted image using the BioImage.IO model
- - Predicted image matrix: Predicted image matrix in original dimensions
+ - Predicted image: image predicted by the BioImage.IO model
+ - Predicted image matrix: predicted image matrix in its original dimensions
]]>
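
For orientation before the script changes below: the size and axes values are plain strings that main.py parses itself. A minimal sketch of that parsing, assuming a hypothetical size string and the 'bczyx' axes example from the help text:

# Sketch (not part of the commit) of how main.py interprets the two strings.
input_size = "1,1,128,128,128"  # hypothetical RDF-style size string
input_axes = "bczyx"            # axes example quoted in the help text

# main.py reverses the size string and drops singleton dimensions
target_image_dim = tuple(int(i) for i in input_size.split(",")[::-1] if i != "1")
print(target_image_dim)  # (128, 128, 128)
print(len(input_axes))   # 5 -> the rank the input is expanded to before prediction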
diff -r 0c0de5546fe1 -r bc28236f407b main.py
--- a/main.py Tue Oct 15 12:57:33 2024 +0000
+++ b/main.py Wed Feb 26 10:27:28 2025 +0000
@@ -7,70 +7,128 @@
import imageio
import numpy as np
import torch
+import torch.nn.functional as F
-def find_dim_order(user_in_shape, input_image):
+def dynamic_resize(image: torch.Tensor, target_shape: tuple):
     """
-    Find the correct order of input image's
-    shape. For a few models, the order of input size
-    mentioned in the RDF.yaml file is reversed compared
-    to the input image's original size. If it is reversed,
-    transpose the image to find correct order of image's
-    dimensions.
+    Resize an input tensor dynamically to the target shape.
+
+    Parameters:
+    - image: Input tensor with shape (C, D1, D2, ..., DN) (any number of spatial dims)
+    - target_shape: Tuple specifying the target shape (C', D1', D2', ..., DN')
+
+    Returns:
+    - Resized tensor with target shape target_shape.
     """
-    image_shape = list(input_image.shape)
-    # reverse the input shape provided from RDF.yaml file
-    correct_order = user_in_shape.split(",")[::-1]
-    # remove 1s from the original dimensions
-    correct_order = [int(i) for i in correct_order if i != "1"]
-    if (correct_order[0] == image_shape[-1]) and (correct_order != image_shape):
-        input_image = torch.tensor(input_image.transpose())
-    return input_image, correct_order
+    # Extract input shape
+    input_shape = image.shape
+    num_dims = len(input_shape)  # Includes channels and spatial dimensions
+
+    # Ensure target shape matches the number of dimensions
+    if len(target_shape) != num_dims:
+        raise ValueError(
+            f"Target shape {target_shape} must match input dimensions {num_dims}"
+        )
+
+    # Extract target channels and spatial sizes
+    target_channels = target_shape[0]  # First element is the target channel count
+    target_spatial_size = target_shape[1:]  # Remaining elements are spatial dimensions
+
+    # Add batch dim (N=1) for resizing
+    image = image.unsqueeze(0)
+
+    # Choose the interpolation mode that matches the number of spatial
+    # dimensions (after unsqueezing, the tensor is (N, C, spatial...))
+    if num_dims == 4:
+        interp_mode = "trilinear"  # three spatial dims
+    elif num_dims == 3:
+        interp_mode = "bilinear"  # two spatial dims
+    elif num_dims == 2:
+        interp_mode = "linear"  # one spatial dim
+    else:
+        interp_mode = "nearest"  # fallback
+
+    # Resize spatial dimensions dynamically; "nearest" does not accept
+    # the align_corners argument
+    if interp_mode == "nearest":
+        image = F.interpolate(image, size=target_spatial_size, mode=interp_mode)
+    else:
+        image = F.interpolate(
+            image, size=target_spatial_size, mode=interp_mode, align_corners=False
+        )
+
+    # Adjust channels if necessary
+    current_channels = image.shape[1]
+
+    if target_channels > current_channels:
+        # Expand channels by repeating existing ones
+        expand_factor = target_channels // current_channels
+        remainder = target_channels % current_channels
+        image = image.repeat(1, expand_factor, *[1] * (num_dims - 1))
+
+        if remainder > 0:
+            extra_channels = image[
+                :, :remainder, ...
+            ]  # Take the first few channels to match target
+            image = torch.cat([image, extra_channels], dim=1)
+
+    elif target_channels < current_channels:
+        # Reduce channels by slicing off the extras
+        image = image[:, :target_channels, ...]  # Simply slice to reduce channels
+    return image.squeeze(0)  # Remove batch dimension before returning
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
-    arg_parser.add_argument("-im", "--imaging_model", required=True, help="Input BioImage model")
-    arg_parser.add_argument("-ii", "--image_file", required=True, help="Input image file")
-    arg_parser.add_argument("-is", "--image_size", required=True, help="Input image file's size")
+    arg_parser.add_argument(
+        "-im", "--imaging_model", required=True, help="Input BioImage model"
+    )
+    arg_parser.add_argument(
+        "-ii", "--image_file", required=True, help="Input image file"
+    )
+    arg_parser.add_argument(
+        "-is", "--image_size", required=True, help="Input image file's size"
+    )
+    arg_parser.add_argument(
+        "-ia", "--image_axes", required=True, help="Input image file's axes"
+    )
    # get argument values
    args = vars(arg_parser.parse_args())
    model_path = args["imaging_model"]
    input_image_path = args["image_file"]
+    input_size = args["image_size"]
    # load all embedded images in TIF file
    test_data = imageio.v3.imread(input_image_path, index="...")
+    test_data = test_data.astype(np.float32)
    test_data = np.squeeze(test_data)
-    test_data = test_data.astype(np.float32)
+
+    # reverse the size string from the RDF file and drop singleton dims
+    target_image_dim = input_size.split(",")[::-1]
+    target_image_dim = [int(i) for i in target_image_dim if i != "1"]
+    target_image_dim = tuple(target_image_dim)
-    # assess the correct dimensions of TIF input image
-    input_image_shape = args["image_size"]
-    im_test_data, shape_vals = find_dim_order(input_image_shape, test_data)
+    exp_test_data = torch.tensor(test_data)
+    # check if image dimensions are reversed
+    reversed_order = list(reversed(range(exp_test_data.dim())))
+    exp_test_data_T = exp_test_data.permute(*reversed_order)
+    if exp_test_data_T.shape == target_image_dim:
+        exp_test_data = exp_test_data_T
+    if exp_test_data.shape != target_image_dim:
+        for i in range(len(target_image_dim) - exp_test_data.dim()):
+            exp_test_data = exp_test_data.unsqueeze(i)
+        try:
+            exp_test_data = dynamic_resize(exp_test_data, target_image_dim)
+        except Exception as e:
+            raise RuntimeError(f"Error during resizing: {e}") from e
+
+    current_dimension = len(exp_test_data.shape)
+    input_axes = args["image_axes"]
+    target_dimension = len(input_axes)
+    # expand input image based on the number of target dimensions
+    for i in range(target_dimension - current_dimension):
+        exp_test_data = torch.unsqueeze(exp_test_data, i)
    # load model
    model = torch.load(model_path)
    model.eval()
-    # find the number of dimensions required by the model
-    target_dimension = 0
-    for param in model.named_parameters():
-        target_dimension = len(param[1].shape)
-        break
-    current_dimension = len(list(im_test_data.shape))
-
-    # update the dimensions of input image if the required image by
-    # the model is smaller
-    slices = tuple(slice(0, s_val) for s_val in shape_vals)
-
-    # apply the slices to the reshaped_input
-    im_test_data = im_test_data[slices]
-    exp_test_data = torch.tensor(im_test_data)
-
-    # expand input image's dimensions
-    for i in range(target_dimension - current_dimension):
-        exp_test_data = torch.unsqueeze(exp_test_data, i)
-
    # make prediction
    pred_data = model(exp_test_data)
    pred_data_output = pred_data.detach().numpy()
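
A minimal smoke test for the new resizing path, assuming dynamic_resize can be imported from main.py; the shapes are made up, not taken from any model:

import torch

from main import dynamic_resize  # assumes main.py is importable as a module

# Upsample a single-channel 2D image and expand it to three channels.
image = torch.rand(1, 128, 128)                 # (C, H, W)
resized = dynamic_resize(image, (3, 256, 256))  # bilinear path, channel repeat
assert resized.shape == (3, 256, 256)

# Channel reduction simply slices off the extra channels.
small = dynamic_resize(torch.rand(4, 64, 64), (2, 32, 32))
assert small.shape == (2, 32, 32)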
diff -r 0c0de5546fe1 -r bc28236f407b test-data/output_nucleisegboundarymodel.tif
Binary file test-data/output_nucleisegboundarymodel.tif has changed
diff -r 0c0de5546fe1 -r bc28236f407b test-data/output_nucleisegboundarymodel_matrix.npy
Binary file test-data/output_nucleisegboundarymodel_matrix.npy has changed