comparison: main.py @ 3:bc28236f407b (draft)
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/bioimaging commit e08711c242a340a1671dfca35f52d3724086e968
author:   bgruening
date:     Wed, 26 Feb 2025 10:27:28 +0000
parents:  9f9ae2ac7820
children: 2b61d8fcfa52
comparing 2:0c0de5546fe1 with 3:bc28236f407b
--- main.py	2:0c0de5546fe1
+++ main.py	3:bc28236f407b
@@ -5,7 +5,8 @@
 import argparse
 
 import imageio
 import numpy as np
 import torch
+import torch.nn.functional as F
 
 
@@ -12,19 +13,64 @@
-def find_dim_order(user_in_shape, input_image):
-    """
-    Find the correct order of input image's
-    shape. For a few models, the order of input size
-    mentioned in the RDF.yaml file is reversed compared
-    to the input image's original size. If it is reversed,
-    transpose the image to find correct order of image's
-    dimensions.
-    """
-    image_shape = list(input_image.shape)
-    # reverse the input shape provided from RDF.yaml file
-    correct_order = user_in_shape.split(",")[::-1]
-    # remove 1s from the original dimensions
-    correct_order = [int(i) for i in correct_order if i != "1"]
-    if (correct_order[0] == image_shape[-1]) and (correct_order != image_shape):
-        input_image = torch.tensor(input_image.transpose())
-    return input_image, correct_order
+def dynamic_resize(image: torch.Tensor, target_shape: tuple):
+    """
+    Resize an input tensor dynamically to the target shape.
+
+    Parameters:
+    - image: Input tensor with shape (C, D1, D2, ..., DN) (any number of spatial dims)
+    - target_shape: Tuple specifying the target shape (C', D1', D2', ..., DN')
+
+    Returns:
+    - Resized tensor with shape target_shape.
+    """
+    # Extract input shape
+    input_shape = image.shape
+    num_dims = len(input_shape)  # includes channel and spatial dimensions
+
+    # Ensure the target shape matches the number of dimensions
+    if len(target_shape) != num_dims:
+        raise ValueError(
+            f"Target shape {target_shape} must match input dimensions {num_dims}"
+        )
+
+    # Extract target channel count and spatial sizes
+    target_channels = target_shape[0]  # first element is the target channel count
+    target_spatial_size = target_shape[1:]  # remaining elements are spatial dims
+
+    # Add a batch dimension (N=1) for resizing
+    image = image.unsqueeze(0)
+
+    # Choose an interpolation mode that matches the batched tensor's rank
+    if num_dims == 4:
+        interp_mode = "trilinear"  # batched input is 5D: (N, C, D, H, W)
+    elif num_dims == 3:
+        interp_mode = "bilinear"  # batched input is 4D: (N, C, H, W)
+    elif num_dims == 2:
+        interp_mode = "linear"  # batched input is 3D: (N, C, W)
+    else:
+        interp_mode = "nearest"
+
+    # Resize spatial dimensions; "nearest" does not accept align_corners
+    align = False if interp_mode != "nearest" else None
+    image = F.interpolate(
+        image, size=target_spatial_size, mode=interp_mode, align_corners=align
+    )
+
+    # Adjust channels if necessary
+    current_channels = image.shape[1]
+
+    if target_channels > current_channels:
+        # Expand channels by repeating existing ones
+        expand_factor = target_channels // current_channels
+        remainder = target_channels % current_channels
+        image = image.repeat(1, expand_factor, *[1] * (num_dims - 1))
+
+        if remainder > 0:
+            # Reuse the first few channels to reach the target count
+            extra_channels = image[:, :remainder, ...]
+            image = torch.cat([image, extra_channels], dim=1)
+
+    elif target_channels < current_channels:
+        # Reduce channels by slicing off the extras
+        image = image[:, :target_channels, ...]
+    return image.squeeze(0)  # remove the batch dimension before returning
 
 
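Note: dynamic_resize replaces find_dim_order. Instead of transposing and slicing the image, it interpolates the spatial dimensions to the target size and repeats or slices channels to match the channel count declared in the RDF.yaml file. A minimal sketch of its behaviour, assuming main.py is importable from the working directory:

    import torch

    from main import dynamic_resize

    # A (C, H, W) = (3, 100, 100) image resized to (1, 256, 256):
    # spatial dims are bilinearly interpolated, channels sliced from 3 to 1.
    image = torch.rand(3, 100, 100)
    resized = dynamic_resize(image, (1, 256, 256))
    print(resized.shape)  # torch.Size([1, 256, 256])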
@@ -31,16 +77,26 @@
 if __name__ == "__main__":
     arg_parser = argparse.ArgumentParser()
-    arg_parser.add_argument("-im", "--imaging_model", required=True, help="Input BioImage model")
-    arg_parser.add_argument("-ii", "--image_file", required=True, help="Input image file")
-    arg_parser.add_argument("-is", "--image_size", required=True, help="Input image file's size")
+    arg_parser.add_argument(
+        "-im", "--imaging_model", required=True, help="Input BioImage model"
+    )
+    arg_parser.add_argument(
+        "-ii", "--image_file", required=True, help="Input image file"
+    )
+    arg_parser.add_argument(
+        "-is", "--image_size", required=True, help="Input image file's size"
+    )
+    arg_parser.add_argument(
+        "-ia", "--image_axes", required=True, help="Input image file's axes"
+    )
 
     # get argument values
     args = vars(arg_parser.parse_args())
     model_path = args["imaging_model"]
     input_image_path = args["image_file"]
+    input_size = args["image_size"]
 
     # load all embedded images in TIF file
     test_data = imageio.v3.imread(input_image_path, index="...")
+    test_data = test_data.astype(np.float32)
     test_data = np.squeeze(test_data)
-    test_data = test_data.astype(np.float32)
 
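Note: the new -ia/--image_axes flag carries the axes string from the model's RDF.yaml; its length determines how many dimensions the input tensor is expanded to before prediction. A hypothetical invocation, with placeholder file names and values:

    python main.py -im model.pt -ii input.tif -is 1,512,512 -ia bcyx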
@@ -47,4 +103,26 @@
-    # assess the correct dimensions of TIF input image
-    input_image_shape = args["image_size"]
-    im_test_data, shape_vals = find_dim_order(input_image_shape, test_data)
+    # build the target shape: reverse the RDF.yaml size and drop the 1s
+    target_image_dim = input_size.split(",")[::-1]
+    target_image_dim = [int(i) for i in target_image_dim if i != "1"]
+    target_image_dim = tuple(target_image_dim)
+
+    exp_test_data = torch.tensor(test_data)
+    # check if image dimensions are reversed
+    reversed_order = list(reversed(range(exp_test_data.dim())))
+    exp_test_data_T = exp_test_data.permute(*reversed_order)
+    if exp_test_data_T.shape == target_image_dim:
+        exp_test_data = exp_test_data_T
+    if exp_test_data.shape != target_image_dim:
+        for i in range(len(target_image_dim) - exp_test_data.dim()):
+            exp_test_data = exp_test_data.unsqueeze(i)
+        try:
+            exp_test_data = dynamic_resize(exp_test_data, target_image_dim)
+        except Exception as e:
+            raise RuntimeError(f"Error during resizing: {e}") from e
+
+    current_dimension = len(exp_test_data.shape)
+    input_axes = args["image_axes"]
+    target_dimension = len(input_axes)
+    # expand input image based on the number of target dimensions
+    for i in range(target_dimension - current_dimension):
+        exp_test_data = torch.unsqueeze(exp_test_data, i)
 
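Note: the reversal check keeps the intent of the old find_dim_order: the size string from RDF.yaml is reversed and stripped of 1s, and the image is only permuted when its transposed shape already matches the target. A small sketch of that comparison, with made-up shapes:

    import torch

    # RDF.yaml reports "1,256,128"; reversed and stripped of 1s -> (128, 256)
    target_image_dim = tuple(
        int(i) for i in "1,256,128".split(",")[::-1] if i != "1"
    )

    image = torch.rand(256, 128)  # stored with axes swapped vs. the target
    reversed_order = list(reversed(range(image.dim())))
    if image.permute(*reversed_order).shape == target_image_dim:
        image = image.permute(*reversed_order)
    print(image.shape)  # torch.Size([128, 256])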
@@ -51,27 +129,8 @@
     # load model
     model = torch.load(model_path)
     model.eval()
-
-    # find the number of dimensions required by the model
-    target_dimension = 0
-    for param in model.named_parameters():
-        target_dimension = len(param[1].shape)
-        break
-    current_dimension = len(list(im_test_data.shape))
-
-    # update the dimensions of input image if the required image by
-    # the model is smaller
-    slices = tuple(slice(0, s_val) for s_val in shape_vals)
-
-    # apply the slices to the reshaped_input
-    im_test_data = im_test_data[slices]
-    exp_test_data = torch.tensor(im_test_data)
-
-    # expand input image's dimensions
-    for i in range(target_dimension - current_dimension):
-        exp_test_data = torch.unsqueeze(exp_test_data, i)
 
     # make prediction
     pred_data = model(exp_test_data)
     pred_data_output = pred_data.detach().numpy()
 
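Note: the prediction still runs with autograd enabled, which is why .detach() is needed before converting to NumPy. A possible refinement, not part of this commit, is to wrap inference in torch.no_grad(); a sketch, assuming exp_test_data is prepared as above and model.pt is a placeholder path:

    import torch

    # Recent PyTorch releases default torch.load to weights_only=True, so a
    # fully pickled nn.Module may need weights_only=False to deserialize.
    model = torch.load("model.pt", weights_only=False)
    model.eval()
    with torch.no_grad():  # skip autograd bookkeeping during inference
        pred_data = model(exp_test_data)
    pred_data_output = pred_data.numpy()  # no .detach() needed under no_grad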