diff transformer_network.py @ 5:9ec705bd11cb draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
author bgruening
date Sun, 16 Oct 2022 11:51:32 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/transformer_network.py	Sun Oct 16 11:51:32 2022 +0000
@@ -0,0 +1,39 @@
+import tensorflow as tf
+from tensorflow.keras.layers import (Dense, Dropout, Embedding, Layer,
+                                     LayerNormalization, MultiHeadAttention)
+from tensorflow.keras.models import Sequential
+
+
+class TransformerBlock(Layer):
+    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
+        super(TransformerBlock, self).__init__()
+        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=rate)
+        self.ffn = Sequential(
+            [Dense(ff_dim, activation="relu"), Dense(embed_dim)]
+        )
+        self.layernorm1 = LayerNormalization(epsilon=1e-6)
+        self.layernorm2 = LayerNormalization(epsilon=1e-6)
+        self.dropout1 = Dropout(rate)
+        self.dropout2 = Dropout(rate)
+
+    def call(self, inputs, training):
+        attn_output, attention_scores = self.att(inputs, inputs, inputs, return_attention_scores=True, training=training)
+        attn_output = self.dropout1(attn_output, training=training)
+        out1 = self.layernorm1(inputs + attn_output)
+        ffn_output = self.ffn(out1)
+        ffn_output = self.dropout2(ffn_output, training=training)
+        return self.layernorm2(out1 + ffn_output), attention_scores
+
+
+class TokenAndPositionEmbedding(Layer):
+    def __init__(self, maxlen, vocab_size, embed_dim):
+        super(TokenAndPositionEmbedding, self).__init__()
+        self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim, mask_zero=True)
+        self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim, mask_zero=True)
+
+    def call(self, x):
+        maxlen = tf.shape(x)[-1]
+        positions = tf.range(start=0, limit=maxlen, delta=1)
+        positions = self.pos_emb(positions)
+        x = self.token_emb(x)
+        return x + positions