comparison main.py @ 3:bc28236f407b draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/bioimaging commit e08711c242a340a1671dfca35f52d3724086e968
author bgruening
date Wed, 26 Feb 2025 10:27:28 +0000
parents 9f9ae2ac7820
children 2b61d8fcfa52
comparison
equal deleted inserted replaced
2:0c0de5546fe1 3:bc28236f407b
5 import argparse 5 import argparse
6 6
7 import imageio 7 import imageio
8 import numpy as np 8 import numpy as np
9 import torch 9 import torch
10 import torch.nn.functional as F
10 11
11 12
12 def find_dim_order(user_in_shape, input_image): 13 def dynamic_resize(image: torch.Tensor, target_shape: tuple):
13 """ 14 """
14 Find the correct order of input image's 15 Resize an input tensor dynamically to the target shape.
15 shape. For a few models, the order of input size 16
16 mentioned in the RDF.yaml file is reversed compared 17 Parameters:
17 to the input image's original size. If it is reversed, 18 - image: Input tensor with shape (C, D1, D2, ..., DN) (any number of spatial dims)
18 transpose the image to find correct order of image's 19 - target_shape: Tuple specifying the target shape (C', D1', D2', ..., DN')
19 dimensions. 20
21 Returns:
22 - Resized tensor with target shape target_shape.
20 """ 23 """
21 image_shape = list(input_image.shape) 24 # Extract input shape
22 # reverse the input shape provided from RDF.yaml file 25 input_shape = image.shape
23 correct_order = user_in_shape.split(",")[::-1] 26 num_dims = len(input_shape) # Includes channels and spatial dimensions
24 # remove 1s from the original dimensions 27
25 correct_order = [int(i) for i in correct_order if i != "1"] 28 # Ensure target shape matches the number of dimensions
26 if (correct_order[0] == image_shape[-1]) and (correct_order != image_shape): 29 if len(target_shape) != num_dims:
27 input_image = torch.tensor(input_image.transpose()) 30 raise ValueError(
28 return input_image, correct_order 31 f"Target shape {target_shape} must match input dimensions {num_dims}"
32 )
33
34 # Extract target channels and spatial sizes
35 target_channels = target_shape[0] # First element is the target channel count
36 target_spatial_size = target_shape[1:] # Remaining elements are spatial dimensions
37
38 # Add batch dim (N=1) for resizing
39 image = image.unsqueeze(0)
40
41 # Choose the best interpolation mode based on dimensionality
42 if num_dims == 4:
43 interp_mode = "trilinear"
44 elif num_dims == 3:
45 interp_mode = "bilinear"
46 elif num_dims == 2:
47 interp_mode = "bicubic"
48 else:
49 interp_mode = "nearest"
50
51 # Resize spatial dimensions dynamically
52 image = F.interpolate(
53 image, size=target_spatial_size, mode=interp_mode, align_corners=False
54 )
55
56 # Adjust channels if necessary
57 current_channels = image.shape[1]
58
59 if target_channels > current_channels:
60 # Expand channels by repeating existing ones
61 expand_factor = target_channels // current_channels
62 remainder = target_channels % current_channels
63 image = image.repeat(1, expand_factor, *[1] * (num_dims - 1))
64
65 if remainder > 0:
66 extra_channels = image[
67 :, :remainder, ...
68 ] # Take the first few channels to match target
69 image = torch.cat([image, extra_channels], dim=1)
70
71 elif target_channels < current_channels:
72 # Reduce channels by averaging adjacent ones
73 image = image[:, :target_channels, ...] # Simply slice to reduce channels
74 return image.squeeze(0) # Remove batch dimension before returning
29 75
30 76
31 if __name__ == "__main__": 77 if __name__ == "__main__":
32 arg_parser = argparse.ArgumentParser() 78 arg_parser = argparse.ArgumentParser()
33 arg_parser.add_argument("-im", "--imaging_model", required=True, help="Input BioImage model") 79 arg_parser.add_argument(
34 arg_parser.add_argument("-ii", "--image_file", required=True, help="Input image file") 80 "-im", "--imaging_model", required=True, help="Input BioImage model"
35 arg_parser.add_argument("-is", "--image_size", required=True, help="Input image file's size") 81 )
82 arg_parser.add_argument(
83 "-ii", "--image_file", required=True, help="Input image file"
84 )
85 arg_parser.add_argument(
86 "-is", "--image_size", required=True, help="Input image file's size"
87 )
88 arg_parser.add_argument(
89 "-ia", "--image_axes", required=True, help="Input image file's axes"
90 )
36 91
37 # get argument values 92 # get argument values
38 args = vars(arg_parser.parse_args()) 93 args = vars(arg_parser.parse_args())
39 model_path = args["imaging_model"] 94 model_path = args["imaging_model"]
40 input_image_path = args["image_file"] 95 input_image_path = args["image_file"]
96 input_size = args["image_size"]
41 97
42 # load all embedded images in TIF file 98 # load all embedded images in TIF file
43 test_data = imageio.v3.imread(input_image_path, index="...") 99 test_data = imageio.v3.imread(input_image_path, index="...")
100 test_data = test_data.astype(np.float32)
44 test_data = np.squeeze(test_data) 101 test_data = np.squeeze(test_data)
45 test_data = test_data.astype(np.float32)
46 102
47 # assess the correct dimensions of TIF input image 103 target_image_dim = input_size.split(",")[::-1]
48 input_image_shape = args["image_size"] 104 target_image_dim = [int(i) for i in target_image_dim if i != "1"]
49 im_test_data, shape_vals = find_dim_order(input_image_shape, test_data) 105 target_image_dim = tuple(target_image_dim)
106
107 exp_test_data = torch.tensor(test_data)
108 # check if image dimensions are reversed
109 reversed_order = list(reversed(range(exp_test_data.dim())))
110 exp_test_data_T = exp_test_data.permute(*reversed_order)
111 if exp_test_data_T.shape == target_image_dim:
112 exp_test_data = exp_test_data_T
113 if exp_test_data.shape != target_image_dim:
114 for i in range(len(target_image_dim) - exp_test_data.dim()):
115 exp_test_data = exp_test_data.unsqueeze(i)
116 try:
117 exp_test_data = dynamic_resize(exp_test_data, target_image_dim)
118 except Exception as e:
119 raise RuntimeError(f"Error during resizing: {e}") from e
120
121 current_dimension = len(exp_test_data.shape)
122 input_axes = args["image_axes"]
123 target_dimension = len(input_axes)
124 # expand input image based on the number of target dimensions
125 for i in range(target_dimension - current_dimension):
126 exp_test_data = torch.unsqueeze(exp_test_data, i)
50 127
51 # load model 128 # load model
52 model = torch.load(model_path) 129 model = torch.load(model_path)
53 model.eval() 130 model.eval()
54
55 # find the number of dimensions required by the model
56 target_dimension = 0
57 for param in model.named_parameters():
58 target_dimension = len(param[1].shape)
59 break
60 current_dimension = len(list(im_test_data.shape))
61
62 # update the dimensions of input image if the required image by
63 # the model is smaller
64 slices = tuple(slice(0, s_val) for s_val in shape_vals)
65
66 # apply the slices to the reshaped_input
67 im_test_data = im_test_data[slices]
68 exp_test_data = torch.tensor(im_test_data)
69
70 # expand input image's dimensions
71 for i in range(target_dimension - current_dimension):
72 exp_test_data = torch.unsqueeze(exp_test_data, i)
73 131
74 # make prediction 132 # make prediction
75 pred_data = model(exp_test_data) 133 pred_data = model(exp_test_data)
76 pred_data_output = pred_data.detach().numpy() 134 pred_data_output = pred_data.detach().numpy()
77 135