Mercurial > repos > bgruening > create_tool_recommendation_model
annotate transformer_network.py @ 5:9ec705bd11cb draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
author | bgruening |
---|---|
date | Sun, 16 Oct 2022 11:51:32 +0000 |
parents | |
children |
rev | line source |
---|---|
5
9ec705bd11cb
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff
changeset
|
1 import tensorflow as tf |
9ec705bd11cb
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff
changeset
|
2 from tensorflow.keras.layers import (Dense, Dropout, Embedding, Layer, |
9ec705bd11cb
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff
changeset
|
3 LayerNormalization, MultiHeadAttention) |
9ec705bd11cb
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff
changeset
|
4 from tensorflow.keras.models import Sequential |
9ec705bd11cb
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff
changeset
|
5 |
9ec705bd11cb
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff
changeset
|
6 |
9ec705bd11cb
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff
changeset
|
class TransformerBlock(Layer):
    """Transformer encoder block: multi-head self-attention followed by a feed-forward network.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``, i.e.
    normalisation is applied after the residual addition.

    Args:
        embed_dim: Output width of the feed-forward network and the per-head
            ``key_dim`` of the attention layer. The residual additions expect
            the last dimension of the block input to equal this value.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        rate: Dropout rate used inside the attention layer and after each
            sub-layer.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``, ``trainable``).
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Remember constructor arguments so get_config() can rebuild the layer.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        # key_dim is the size of *each* head, so attention projects to
        # num_heads * embed_dim internally (not embed_dim // num_heads).
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=rate)
        self.ffn = Sequential(
            [Dense(ff_dim, activation="relu"), Dense(embed_dim)]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None):
        """Apply the block to ``inputs``.

        Args:
            inputs: Sequence tensor used as query, key and value of the
                self-attention (presumably ``(batch, seq_len, embed_dim)`` —
                TODO confirm against callers).
            training: Whether dropout is active. ``None`` lets Keras infer the
                mode from the enclosing call context.

        Returns:
            Tuple ``(outputs, attention_scores)``: the block output and the
            attention weights returned by ``MultiHeadAttention`` with
            ``return_attention_scores=True``.
        """
        attn_output, attention_scores = self.att(inputs, inputs, inputs, return_attention_scores=True, training=training)
        attn_output = self.dropout1(attn_output, training=training)
        # Residual connection + post-norm around the attention sub-layer.
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        # Residual connection + post-norm around the feed-forward sub-layer.
        return self.layernorm2(out1 + ffn_output), attention_scores

    def get_config(self):
        """Return the constructor arguments so Keras can serialize/clone the layer."""
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "ff_dim": self.ff_dim,
                "rate": self.rate,
            }
        )
        return config
9ec705bd11cb
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff
changeset
|
26 |
9ec705bd11cb
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff
changeset
|
27 |
9ec705bd11cb
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff
changeset
|
class TokenAndPositionEmbedding(Layer):
    """Sum of a learned token embedding and a learned absolute-position embedding.

    Args:
        maxlen: Number of positions the position table can embed. Sequences
            longer than this index past the end of that table.
        vocab_size: Number of token ids the token table can embed.
        embed_dim: Output width of both embeddings.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``, ``trainable``).
    """

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        # Remember constructor arguments so get_config() can rebuild the layer.
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim, mask_zero=True)
        # NOTE(review): mask_zero=True on the position table marks position 0
        # as padding, which looks unintended. The masks produced by both
        # embeddings are also not explicitly propagated by this layer (no
        # compute_mask override). Left unchanged to preserve behaviour —
        # confirm intent.
        self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim, mask_zero=True)

    def call(self, x):
        """Embed ``x`` and add a position embedding for every index along its last axis.

        Args:
            x: Token ids; presumably an integer tensor of shape
                ``(batch, seq_len)`` — TODO confirm against callers.

        Returns:
            ``token_emb(x) + pos_emb(range(seq_len))``, where the position
            term of shape ``(seq_len, embed_dim)`` broadcasts over the leading
            axes.
        """
        seq_len = tf.shape(x)[-1]
        positions = self.pos_emb(tf.range(start=0, limit=seq_len, delta=1))
        return self.token_emb(x) + positions

    def get_config(self):
        """Return the constructor arguments so Keras can serialize/clone the layer."""
        config = super().get_config()
        config.update(
            {
                "maxlen": self.maxlen,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config