annotate test-data/tf-script.py @ 2:6708db9ee47e draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/jupyter_job commit c21261bb8373090c26cf5195134b30538b5bc714
author bgruening
date Fri, 20 Jan 2023 10:48:52 +0000
parents ff9bb9df06a7
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
ff9bb9df06a7 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/jupyter_job commit f945b1bff5008ba01da31c7de64e5326579394d6"
bgruening
parents:
diff changeset
import numpy as np
import tensorflow as tf

# Only the training split is used; the second tuple from load_data()
# (the test split) is discarded into `_`.
(mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data()
# Truncate to the first 128 samples -- presumably to keep this test job fast.
mnist_images, mnist_labels = mnist_images[:128], mnist_labels[:128]

# Input pipeline: pixels are divided by 255 and given a trailing channel axis
# before being cast to float32; labels are cast to int64 for the sparse loss.
dataset = tf.data.Dataset.from_tensor_slices(
    (
        tf.cast(mnist_images[..., tf.newaxis] / 255, tf.float32),
        tf.cast(mnist_labels, tf.int64),
    )
)
dataset = dataset.shuffle(buffer_size=1000).batch(batch_size=32)

tot_loss = []  # one entry per epoch: the mean of that epoch's batch losses
epochs = 1

# Two 3x3 ReLU convolutions, global average pooling, then 10 raw logits
# (no softmax -- the loss below is built with from_logits=True).
mnist_model = tf.keras.Sequential(
    layers=[
        tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(units=10),
    ]
)

optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Hand-written training loop: forward pass under a GradientTape, then one
# optimizer step per batch.
for epoch in range(epochs):
    loss_history = []
    for batch, (images, labels) in enumerate(dataset):
        with tf.GradientTape() as tape:
            logits = mnist_model(images, training=True)
            loss_value = loss_object(labels, logits)
        # Host-side bookkeeping only; not part of the differentiated computation.
        loss_history.append(loss_value.numpy().mean())
        grads = tape.gradient(loss_value, mnist_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, mnist_model.trainable_variables))
    tot_loss.append(np.mean(loss_history))