changeset 0:50a3903370a6 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/rna_tools/rnaformer commit 81ae3a5a824f820d2ee4212dcacb334b3dabd683
author bgruening
date Sun, 26 Jan 2025 08:45:36 +0000
parents
children
files infer_rnaformer.xml macros.xml test-data/fasta_input1.fa test-data/fasta_input_false1.fa test-data/rna_2d_pred_FASTA.txt test-data/rna_2d_pred_out.txt test-data/rna_2d_pred_text.txt
diffstat 7 files changed, 269 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/infer_rnaformer.xml	Sun Jan 26 08:45:36 2025 +0000
@@ -0,0 +1,229 @@
+<tool id="infer_rnaformer" name="@EXECUTABLE@" version="@TOOL_VERSION@+galaxy1" profile="22.05">
+    <description>Predict the secondary structure of an RNA with RNAformer</description>
+    <macros>
+        <import>macros.xml</import>
+    </macros>
+    <expand macro="requirements">
+        <requirement type="package" version="1.83">biopython</requirement>
+    </expand>
+    <command detect_errors="exit_code"><![CDATA[
+    mkdir -p './model' &&
+    wget -O './model/RNAformer_32M_state_dict_intra_family_finetuned.pth' 'https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_state_dict_intra_family_finetuned.pth'
+    &&
+    wget -O './model/RNAformer_32M_config_intra_family_finetuned.yml' 'https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_config_intra_family_finetuned.yml'
+    &&
+    python '$script_file' > '$output'
+]]></command>
+<configfiles>
+        <configfile name="script_file"><![CDATA[import RNAformer
+import os
+import argparse
+import torch
+import urllib.request
+import logging
+from collections import defaultdict
+import torch.cuda
+import loralib as lora
+from RNAformer.model.RNAformer import RiboFormer
+from RNAformer.utils.configuration import Config
+
+from Bio import SeqIO
+
+import logging
+import sys
+
+def is_valid_rna_sequence(sequence):
+    """Check if the sequence contains only RNA bases."""
+    valid_bases = {'A', 'C', 'G', 'U', 'N'}  # Include 'N' if unknown bases are allowed
+    return all(base in valid_bases for base in sequence.upper())
+
+config_file_path = 'model/RNAformer_32M_config_intra_family_finetuned.yml'
+model_file_path = 'model/RNAformer_32M_state_dict_intra_family_finetuned.pth'
+
+config = Config(config_file=config_file_path)
+config.RNAformer.cycling = 6
+model = RiboFormer(config.RNAformer)
+state_dict_file = model_file_path
+
+#if str($input_type.input_type) == 'True'
+fasta_path = '$input_type.fasta_input'
+sequences = [str(record.seq) for record in SeqIO.parse(fasta_path, 'fasta')]
+#else
+sequence_string = "$input_type.rna_input_string"
+sequences = [seq.strip() for seq in sequence_string.split(',')]
+#end if
+
+for seq in sequences:
+    if not is_valid_rna_sequence(seq):
+        print(f"Invalid RNA sequence detected: {seq}. Please ensure only RNA sequences are used as input.", file=sys.stderr)
+        sys.exit(1)
+
+lora_config = {
+    "r": config.r,
+    "lora_alpha": config.lora_alpha,
+    "lora_dropout": config.lora_dropout,
+}
+
+with torch.no_grad():
+    for name, module in model.named_modules():
+        if any(replace_key in name for replace_key in config.replace_layer):
+            parent = model.get_submodule(".".join(name.split(".")[:-1]))
+            target_name = name.split(".")[-1]
+            target = model.get_submodule(name)
+            if isinstance(target, torch.nn.Linear) and "qkv" in name:
+                new_module = lora.MergedLinear(target.in_features, target.out_features,
+                                            bias=target.bias is not None,
+                                            enable_lora=[True, True, True], **lora_config)
+                new_module.weight.copy_(target.weight)
+                if target.bias is not None:
+                    new_module.bias.copy_(target.bias)
+            elif isinstance(target, torch.nn.Linear):
+                new_module = lora.Linear(target.in_features, target.out_features,
+                                        bias=target.bias is not None, **lora_config)
+                new_module.weight.copy_(target.weight)
+                if target.bias is not None:
+                    new_module.bias.copy_(target.bias)
+
+            elif isinstance(target, torch.nn.Conv2d):
+                kernel_size = target.kernel_size[0]
+                new_module = lora.Conv2d(target.in_channels, target.out_channels, kernel_size,
+                                        padding=(kernel_size - 1) // 2, bias=target.bias is not None,
+                                        **lora_config)
+
+                new_module.conv.weight.copy_(target.weight)
+                if target.bias is not None:
+                    new_module.conv.bias.copy_(target.bias)
+            setattr(parent, target_name, new_module)
+
+state_dict = torch.load(state_dict_file, map_location=torch.device('cpu'))
+
+model.load_state_dict(state_dict, strict=True)
+model_name = state_dict_file.split(".pth")[0]
+
+if torch.cuda.is_available():
+        model = model.cuda()
+
+        # check GPU can do bf16
+        if torch.cuda.is_bf16_supported():
+            model = model.bfloat16()
+        else:
+            model = model.half()
+
+model.eval()
+predicted_structures = []
+
+orig_seq = ""
+
+for sequence in sequences:
+    with torch.no_grad():
+        device = "cpu"
+
+        seq_vocab = ['A', 'C', 'G', 'U', 'N']
+        seq_stoi = dict(zip(seq_vocab, range(len(seq_vocab))))
+
+        pdb_sample = 1
+
+        length = len(sequence)
+        src_seq = torch.LongTensor(list(map(seq_stoi.get, sequence)))
+
+        orig_seq = sequence
+
+        sample = {}
+        sample['src_seq'] = src_seq.clone()
+        sample['length'] = torch.LongTensor([length])[0]
+        sample['pdb_sample'] = torch.LongTensor([pdb_sample])[0]
+
+        sequence = sample['src_seq'].unsqueeze(0).to(device)
+        src_len = torch.LongTensor([sequence.shape[-1]]).to(device)
+        pdb_sample = torch.FloatTensor([[1]]).to(device)
+
+        logits, pair_mask = model(sequence, src_len, pdb_sample)
+
+        pred_mat = torch.sigmoid(logits[0, :, :, -1]) > 0.5
+
+        pos_id = torch.where(pred_mat == True)
+        pos1_id = pos_id[0].cpu().tolist()
+        pos2_id = pos_id[1].cpu().tolist()
+        predicted_structure = f"Pairing index 1: {pos1_id} \nPairing index 2: {pos2_id}"
+        pairs = [[a, b] for a, b in zip(pos1_id, pos2_id)]
+
+        seqlen = len(sample['src_seq'])
+        dot_bracket =['.'] * seqlen
+        pk_count = 0
+        pk_list = []
+        for i in range(len(pos1_id)):
+            open_index = pos1_id[i]
+            close_index = pos2_id[i]
+
+            if 0 <= open_index < len(dot_bracket) and 0 <= close_index < len(dot_bracket):
+                if dot_bracket[open_index] == '.' and dot_bracket[close_index] == '.':
+                    dot_bracket[open_index] = '('
+                    dot_bracket[close_index] = ')'
+                else:
+                    ## pseudoknots or multiplets present in structure- cannot represent with dot-bracket
+                    pk_count += 1
+                    pk_list.append(pairs[i])
+        dot_bracket_str_pred = ''.join(dot_bracket)
+
+        print(f"{orig_seq}")
+        print(f"Base pairs: {str(pairs)}")
+        print(f"Predicted Structure: {dot_bracket_str_pred}")
+        if pk_count > 0:
+            print(f"NOTE: {pk_count} pseudoknots and/or multiplets present in predicted structure excluded from dot-bracket notation: {pk_list}")
+
+
+
+]]></configfile>
+    </configfiles>
+    <inputs>
+        <conditional name="input_type">
+            <param name="input_type" type="select" label="Input from FASTA file">
+                <option value="False">Provide a single RNA sequence string as text</option>
+                <option value="True">Provide a FASTA file</option>
+            </param>
+            <when value="False">
+                <param name="rna_input_string" label="Sequence(s) to fold" type="text" value="GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA" help="Enter RNA sequences. Separate multiple RNA sequences by commas.">
+                    <sanitizer>
+                        <valid>
+                            <add value="ACGUacgu,"/>
+                        </valid>
+                    </sanitizer>
+                </param>
+            </when>
+            <when value="True">
+                <param format="fasta" name="fasta_input" type="data" label="Sequence to fold (FASTA file)"/>
+            </when>
+        </conditional>
+    </inputs>
+    <outputs>
+        <data name="output" format="txt" label="output"/>
+    </outputs>
+    <tests>
+        <test>
+            <param name="input_type" value="False"/>
+            <param name="rna_input_string" value="GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA"/>
+            <output name="output" file="rna_2d_pred_out.txt"/>
+        </test>
+        <test>
+            <param name="input_type" value="True"/>
+            <param name="fasta_input" value="fasta_input1.fa"/>
+            <output name="output" file="rna_2d_pred_FASTA.txt"/>
+        </test>
+    </tests>
+    <help><![CDATA[
+    **RNAformer**
+        This tool reads RNA sequences and predicts their secondary structure using RNAformer. Note that unlike conventional methods, RNAformer is capable of predicting all possible secondary structure motifs, including pseudoknots and multiplets. These cannot be represented in dot-bracket notation and thus the output will be partial in these cases, excluding these which will be noted in the output file below the (partial) dot-bracket structure. 
+
+    **Input format**
+
+    RNAformer requires one or more RNA sequences either as a single FASTA file or as plain text.
+
+    **Outputs**
+
+    - Predicted secondary structure as a text file containing the following:
+        - RNA input sequence
+        - Base pairs of predicted secondary structure
+        - Predicted secondary structure in dot-bracket notation
+    ]]></help>
+    <expand macro="citations" />
+</tool>
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/macros.xml	Sun Jan 26 08:45:36 2025 +0000
@@ -0,0 +1,16 @@
+<macros>
+    <token name="@EXECUTABLE@">RNAformer</token>
+    <token name="@TOOL_VERSION@">1.0.0</token>
+    <token name="@profile@">22.05</token>
+    <xml name="requirements">
+        <requirements>
+            <requirement type="package" version="0.0.1">rnaformer</requirement>
+            <yield/>
+        </requirements>
+    </xml>
+    <xml name="citations">
+        <citations>
+            <citation type="doi">10.1101/2024.02.12.579881</citation>
+        </citations>
+    </xml>
+</macros>
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/fasta_input1.fa	Sun Jan 26 08:45:36 2025 +0000
@@ -0,0 +1,4 @@
+>Anolis_caro_chrUn_GL343590.trna2_AlaAGC (218800-218872)  Ala (AGC) 73 bp  Sc: 49.55
+UGGGAAUUAGCUCAAAUGGUAGAGCGCUCGCUUAGCAUGUGAGAGGUAGUGGGAUCGAUGCCCACAUUCUCCA
+>Anolis_caro_chrUn_GL343207.trna3_AlaAGC (1513626-1513698)  Ala (AGC) 73 bp  Sc: 56.15
+GGGGAAUUAGCUCAAAUGGUAGAGCGCUCGCUUAGCAUGCGAGAGGUAGCGGGAUUGAUGCCCGCAUUCUCCA
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/fasta_input_false1.fa	Sun Jan 26 08:45:36 2025 +0000
@@ -0,0 +1,4 @@
+>Anolis_caro_chrUn_GL343590.trna2_AlaAGC (218800-218872)  Ala (AGC) 73 bp  Sc: 49.55
+TGGGAATTAGCTCAAATGGAAGAGCGCTCGCTTAGCATGTGAGAGGTAGTGGGATCGATGCCCACATTCTCCA
+>Anolis_caro_chrUn_GL343207.trna3_AlaAGC (1513626-1513698)  Ala (AGC) 73 bp  Sc: 56.15
+GGGAATTAGCTCAAATGGAAGAGCGCTCGCTTAGCATGCGAGAGGTAGCGGGATTGATGCCCGCATTCTCCA
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/rna_2d_pred_FASTA.txt	Sun Jan 26 08:45:36 2025 +0000
@@ -0,0 +1,8 @@
+UGGGAAUUAGCUCAAAUGGUAGAGCGCUCGCUUAGCAUGUGAGAGGUAGUGGGAUCGAUGCCCACAUUCUCCA
+Base pairs: [[0, 71], [1, 70], [2, 69], [3, 68], [4, 67], [5, 66], [6, 65], [7, 13], [7, 14], [7, 20], [7, 47], [8, 22], [9, 24], [9, 44], [10, 23], [11, 22], [12, 21], [13, 7], [13, 20], [17, 54], [18, 55], [20, 7], [21, 12], [21, 45], [22, 8], [22, 11], [23, 10], [24, 9], [25, 43], [26, 42], [27, 41], [28, 40], [29, 39], [30, 38], [31, 37], [32, 36], [36, 32], [37, 31], [38, 30], [39, 29], [40, 28], [41, 27], [42, 26], [43, 25], [44, 9], [45, 21], [48, 64], [49, 63], [50, 62], [51, 61], [52, 60], [53, 57], [54, 17], [55, 18], [57, 53], [60, 52], [61, 51], [62, 50], [63, 49], [64, 48], [65, 6], [66, 5], [67, 4], [68, 3], [69, 2], [70, 1], [71, 0]]
+Predicted Structure: (((((((((((.()...((..))))((((((((...))))))))....(((((()).)..)))))))))))).
+NOTE: 39 pseudoknots and/or multiplets present in predicted structure excluded from dot-bracket notation: [[7, 14], [7, 20], [7, 47], [9, 44], [11, 22], [13, 7], [13, 20], [20, 7], [21, 12], [21, 45], [22, 8], [22, 11], [23, 10], [24, 9], [36, 32], [37, 31], [38, 30], [39, 29], [40, 28], [41, 27], [42, 26], [43, 25], [44, 9], [45, 21], [54, 17], [55, 18], [57, 53], [60, 52], [61, 51], [62, 50], [63, 49], [64, 48], [65, 6], [66, 5], [67, 4], [68, 3], [69, 2], [70, 1], [71, 0]]
+GGGGAAUUAGCUCAAAUGGUAGAGCGCUCGCUUAGCAUGCGAGAGGUAGCGGGAUUGAUGCCCGCAUUCUCCA
+Base pairs: [[0, 71], [1, 70], [2, 69], [3, 68], [4, 67], [5, 66], [6, 65], [7, 13], [7, 14], [7, 20], [8, 22], [9, 24], [9, 44], [10, 23], [11, 22], [12, 21], [13, 7], [13, 20], [14, 7], [20, 7], [20, 13], [21, 12], [21, 45], [22, 8], [22, 11], [23, 10], [24, 9], [25, 43], [26, 42], [27, 41], [28, 40], [29, 39], [30, 38], [31, 37], [36, 32], [37, 31], [38, 30], [39, 29], [40, 28], [41, 27], [42, 26], [43, 25], [44, 9], [45, 21], [48, 64], [49, 63], [50, 62], [51, 61], [52, 60], [60, 52], [61, 51], [62, 50], [63, 49], [64, 48], [65, 6], [66, 5], [67, 4], [68, 3], [69, 2], [70, 1], [71, 0]]
+Predicted Structure: (((((((((((.().......))))((((((()...()))))))....(((((.......)))))))))))).
+NOTE: 36 pseudoknots and/or multiplets present in predicted structure excluded from dot-bracket notation: [[7, 14], [7, 20], [9, 44], [11, 22], [13, 7], [13, 20], [14, 7], [20, 7], [20, 13], [21, 12], [21, 45], [22, 8], [22, 11], [23, 10], [24, 9], [37, 31], [38, 30], [39, 29], [40, 28], [41, 27], [42, 26], [43, 25], [44, 9], [45, 21], [60, 52], [61, 51], [62, 50], [63, 49], [64, 48], [65, 6], [66, 5], [67, 4], [68, 3], [69, 2], [70, 1], [71, 0]]
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/rna_2d_pred_out.txt	Sun Jan 26 08:45:36 2025 +0000
@@ -0,0 +1,4 @@
+GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA
+Base pairs: [[0, 76], [1, 75], [2, 74], [3, 73], [4, 72], [5, 71], [6, 70], [7, 13], [7, 21], [8, 23], [9, 25], [10, 24], [11, 23], [12, 22], [13, 7], [17, 59], [18, 60], [21, 7], [21, 13], [22, 12], [23, 8], [23, 11], [24, 10], [25, 9], [39, 30], [40, 29], [42, 50], [43, 49], [44, 48], [48, 44], [49, 43], [50, 42], [53, 69], [54, 68], [55, 67], [56, 66], [57, 65], [58, 62], [59, 17], [60, 18], [62, 58], [65, 57], [66, 56], [67, 55], [68, 54], [69, 53], [70, 6], [71, 5], [72, 4], [73, 3], [74, 2], [75, 1], [76, 0]]
+Predicted Structure: (((((((((((.()...((...))))...))........((.(((...)))..(((((()).)..))))))))))))....
+NOTE: 28 pseudoknots and/or multiplets present in predicted structure excluded from dot-bracket notation: [[7, 21], [11, 23], [13, 7], [21, 7], [21, 13], [22, 12], [23, 8], [23, 11], [24, 10], [25, 9], [48, 44], [49, 43], [50, 42], [59, 17], [60, 18], [62, 58], [65, 57], [66, 56], [67, 55], [68, 54], [69, 53], [70, 6], [71, 5], [72, 4], [73, 3], [74, 2], [75, 1], [76, 0]]
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/rna_2d_pred_text.txt	Sun Jan 26 08:45:36 2025 +0000
@@ -0,0 +1,4 @@
+GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA
+Base pairs: [[0, 76], [1, 75], [2, 74], [3, 73], [4, 72], [5, 71], [6, 70], [7, 13], [7, 21], [8, 23], [9, 25], [10, 24], [11, 23], [12, 22], [13, 7], [17, 59], [18, 60], [21, 7], [21, 13], [22, 12], [23, 8], [23, 11], [24, 10], [25, 9], [39, 30], [40, 29], [42, 50], [43, 49], [44, 48], [48, 44], [49, 43], [50, 42], [53, 69], [54, 68], [55, 67], [56, 66], [57, 65], [58, 62], [59, 17], [60, 18], [62, 58], [65, 57], [66, 56], [67, 55], [68, 54], [69, 53], [70, 6], [71, 5], [72, 4], [73, 3], [74, 2], [75, 1], [76, 0]]
+Predicted Structure: (((((((((((.()...((...))))...))........((.(((...)))..(((((()).)..))))))))))))....
+NOTE: 28 pseudoknots and/or multiplets present in predicted structure excluded from dot-bracket notation: [[7, 21], [11, 23], [13, 7], [21, 7], [21, 13], [22, 12], [23, 8], [23, 11], [24, 10], [25, 9], [48, 44], [49, 43], [50, 42], [59, 17], [60, 18], [62, 58], [65, 57], [66, 56], [67, 55], [68, 54], [69, 53], [70, 6], [71, 5], [72, 4], [73, 3], [74, 2], [75, 1], [76, 0]]