Voir sur TensorFlow.org | Exécuter dans Google Colab | Afficher sur GitHub | Télécharger le cahier |
Aperçu
Ce guide fournit une liste des meilleures pratiques pour écrire du code à l'aide de TensorFlow 2 (TF2). Il est destiné aux utilisateurs qui ont récemment basculé depuis TensorFlow 1 (TF1). Reportez-vous à la section migration du guide pour plus d'informations sur la migration de votre code TF1 vers TF2.
Installer
Importez TensorFlow et d'autres dépendances pour les exemples de ce guide.
import tensorflow as tf
import tensorflow_datasets as tfds
Recommandations pour TensorFlow 2 idiomatique
Refactorisez votre code en modules plus petits
Une bonne pratique consiste à refactoriser votre code en fonctions plus petites qui sont appelées selon les besoins. Pour de meilleures performances, vous devez essayer de décorer les plus grands blocs de calcul que vous pouvez dans un tf.function
(notez que les fonctions python imbriquées appelées par un tf.function
ne nécessitent pas leurs propres décorations séparées, sauf si vous souhaitez utiliser différents jit_compile
paramètres de la tf.function
.). Selon votre cas d'utilisation, il peut s'agir de plusieurs étapes d'entraînement ou même de toute votre boucle d'entraînement. Pour les cas d'utilisation d'inférence, il peut s'agir d'une seule passe avant de modèle.
Ajuster le taux d'apprentissage par défaut pour certains tf.keras.optimizer
s
Certains optimiseurs Keras ont des taux d'apprentissage différents dans TF2. Si vous constatez un changement dans le comportement de convergence de vos modèles, vérifiez les taux d'apprentissage par défaut.
Il n'y a aucun changement pour les optimizers.SGD
, les optimizers.Adam
ou les optimizers.RMSprop
.
Les taux d'apprentissage par défaut suivants ont changé :
-
optimizers.Adagrad
de0.01
à0.001
-
optimizers.Adadelta
de1.0
à0.001
-
optimizers.Adamax
de0.002
à0.001
-
optimizers.Nadam
de0.002
à0.001
Utiliser les tf.Module
s et Keras pour gérer les variables
tf.Module
s et tf.keras.layers.Layer
s offrent les variables
pratiques et les propriétés trainable_variables
, qui rassemblent de manière récursive toutes les variables dépendantes. Cela facilite la gestion des variables localement là où elles sont utilisées.
Les couches/modèles Keras héritent de tf.train.Checkpointable
et sont intégrés à @tf.function
, ce qui permet de contrôler directement ou d'exporter des SavedModels à partir d'objets Keras. Vous n'avez pas nécessairement besoin d'utiliser l'API Model.fit
de Keras pour tirer parti de ces intégrations.
Lisez la section sur l'apprentissage par transfert et le réglage fin du guide Keras pour savoir comment collecter un sous-ensemble de variables pertinentes à l'aide de Keras.
Combinez tf.data.Dataset
s et tf.function
Le package TensorFlow Datasets ( tfds
) contient des utilitaires permettant de charger des ensembles de données prédéfinis en tant tf.data.Dataset
. Pour cet exemple, vous pouvez charger le jeu de données MNIST à l'aide tfds
:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Préparez ensuite les données pour la formation :
- Redimensionnez chaque image.
- Mélangez l'ordre des exemples.
- Collectez des lots d'images et d'étiquettes.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
Pour que l'exemple reste court, découpez l'ensemble de données pour ne renvoyer que 5 lots :
train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)
STEPS_PER_EPOCH = 5
train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2021-12-08 17:15:01.637157: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Utilisez l'itération Python régulière pour itérer sur les données d'entraînement qui tiennent dans la mémoire. Sinon, tf.data.Dataset
est le meilleur moyen de diffuser des données d'entraînement à partir du disque. Les ensembles de données sont des itérables (et non des itérateurs) et fonctionnent comme les autres itérables Python dans une exécution hâtive. Vous pouvez utiliser pleinement les fonctionnalités de prélecture/diffusion asynchrones des ensembles de données en enveloppant votre code dans tf.function
, qui remplace l'itération Python par les opérations de graphe équivalentes à l'aide d'AutoGraph.
@tf.function
def train(model, dataset, optimizer):
for x, y in dataset:
with tf.GradientTape() as tape:
# training=True is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
prediction = model(x, training=True)
loss = loss_fn(prediction, y)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
Si vous utilisez l'API Keras Model.fit
, vous n'aurez pas à vous soucier de l'itération de l'ensemble de données.
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)
Utiliser les boucles d'entraînement Keras
Si vous n'avez pas besoin d'un contrôle de bas niveau de votre processus d'entraînement, il est recommandé d'utiliser les méthodes d' fit
, d' evaluate
et de predict
intégrées de Keras. Ces méthodes fournissent une interface uniforme pour former le modèle quelle que soit l'implémentation (séquentielle, fonctionnelle ou sous-classée).
Les avantages de ces méthodes incluent :
- Ils acceptent les tableaux Numpy, les générateurs Python et
tf.data.Datasets
. - Ils appliquent automatiquement la régularisation et les pertes d'activation.
- Ils prennent en charge
tf.distribute
où le code de formation reste le même quelle que soit la configuration matérielle . - Ils prennent en charge les callables arbitraires comme les pertes et les métriques.
- Ils prennent en charge les rappels tels que
tf.keras.callbacks.TensorBoard
et les rappels personnalisés. - Ils sont performants, utilisant automatiquement les graphes TensorFlow.
Voici un exemple d'entraînement d'un modèle à l'aide d'un Dataset
. Pour plus de détails sur la façon dont cela fonctionne, consultez les didacticiels .
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
# Model is the full model w/o custom layers
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5 5/5 [==============================] - 9s 7ms/step - loss: 1.5762 - accuracy: 0.4938 Epoch 2/5 2021-12-08 17:15:11.145429: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 6ms/step - loss: 0.5087 - accuracy: 0.8969 Epoch 3/5 2021-12-08 17:15:11.559374: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 2s 5ms/step - loss: 0.3348 - accuracy: 0.9469 Epoch 4/5 2021-12-08 17:15:13.860407: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 5ms/step - loss: 0.2445 - accuracy: 0.9688 Epoch 5/5 2021-12-08 17:15:14.269850: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 6ms/step - loss: 0.2006 - accuracy: 0.9719 2021-12-08 17:15:14.717552: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 1s 4ms/step - loss: 1.4553 - accuracy: 0.5781 Loss 1.4552843570709229, Accuracy 0.578125 2021-12-08 17:15:15.862684: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Personnalisez la formation et écrivez votre propre boucle
Si les modèles Keras fonctionnent pour vous, mais que vous avez besoin de plus de flexibilité et de contrôle de l'étape d'entraînement ou des boucles d'entraînement externes, vous pouvez mettre en œuvre vos propres étapes d'entraînement ou même des boucles d'entraînement entières. Consultez le guide Keras sur la personnalisation de l' fit
pour en savoir plus.
Vous pouvez également implémenter de nombreuses choses en tant que tf.keras.callbacks.Callback
.
Cette méthode présente de nombreux avantages mentionnés précédemment , mais vous donne le contrôle de l'étape du train et même de la boucle extérieure.
Une boucle d'entraînement standard comporte trois étapes :
- Itérez sur un générateur Python ou
tf.data.Dataset
pour obtenir des lots d'exemples. - Utilisez
tf.GradientTape
pour collecter les dégradés. - Utilisez l'un des
tf.keras.optimizers
pour appliquer des mises à jour de poids aux variables du modèle.
Rappelles toi:
- Incluez toujours un argument de
training
sur la méthode d'call
des couches et modèles sous-classés. - Assurez-vous d'appeler le modèle avec l'argument d'
training
défini correctement. - Selon l'utilisation, les variables de modèle peuvent ne pas exister tant que le modèle n'est pas exécuté sur un lot de données.
- Vous devez gérer manuellement des choses comme les pertes de régularisation pour le modèle.
Il n'est pas nécessaire d'exécuter des initialiseurs de variables ou d'ajouter des dépendances de contrôle manuel. tf.function
gère pour vous les dépendances de contrôle automatique et l'initialisation des variables lors de la création.
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
@tf.function
def train_step(inputs, labels):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
regularization_loss=tf.math.add_n(model.losses)
pred_loss=loss_fn(labels, predictions)
total_loss=pred_loss + regularization_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
for epoch in range(NUM_EPOCHS):
for inputs, labels in train_data:
train_step(inputs, labels)
print("Finished epoch", epoch)
2021-12-08 17:15:16.714849: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 0 2021-12-08 17:15:17.097043: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 1 2021-12-08 17:15:17.502480: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 2 2021-12-08 17:15:17.873701: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 3 Finished epoch 4 2021-12-08 17:15:18.344196: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Tirez parti de tf.function
avec le flux de contrôle Python
tf.function
fournit un moyen de convertir le flux de contrôle dépendant des données en équivalents en mode graphique comme tf.cond
et tf.while_loop
.
Un endroit commun où le flux de contrôle dépendant des données apparaît est dans les modèles de séquence. tf.keras.layers.RNN
enveloppe une cellule RNN, vous permettant de dérouler la récurrence de manière statique ou dynamique. Par exemple, vous pouvez réimplémenter le déroulement dynamique comme suit.
class DynamicRNN(tf.keras.Model):
def __init__(self, rnn_cell):
super(DynamicRNN, self).__init__(self)
self.cell = rnn_cell
@tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
def call(self, input_data):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
timesteps = tf.shape(input_data)[0]
batch_size = tf.shape(input_data)[1]
outputs = tf.TensorArray(tf.float32, timesteps)
state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
for i in tf.range(timesteps):
output, state = self.cell(input_data[i], state)
outputs = outputs.write(i, output)
return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)
my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)
Lisez le guide des tf.function
pour plus d'informations.
Métriques et pertes de style nouveau
Les métriques et les pertes sont à la fois des objets qui fonctionnent avec impatience et dans tf.function
s.
Un objet loss est appelable et attend ( y_true
, y_pred
) comme arguments :
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815
Utiliser des métriques pour collecter et afficher des données
Vous pouvez utiliser tf.metrics
pour agréger les données et tf.summary
pour consigner les résumés et les rediriger vers un rédacteur à l'aide d'un gestionnaire de contexte. Les résumés sont émis directement au rédacteur, ce qui signifie que vous devez fournir la valeur du step
au site d'appel.
summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
tf.summary.scalar('loss', 0.1, step=42)
Utilisez tf.metrics
pour agréger les données avant de les enregistrer sous forme de résumés. Les métriques sont avec état ; ils accumulent des valeurs et renvoient un résultat cumulé lorsque vous appelez la méthode result
(comme Mean.result
). Effacez les valeurs accumulées avec Model.reset_states
.
def train(model, optimizer, dataset, log_freq=10):
avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
for images, labels in dataset:
loss = train_step(model, optimizer, images, labels)
avg_loss.update_state(loss)
if tf.equal(optimizer.iterations % log_freq, 0):
tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
avg_loss.reset_states()
def test(model, test_x, test_y, step_num):
# training=False is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
loss = loss_fn(model(test_x, training=False), test_y)
tf.summary.scalar('loss', loss, step=step_num)
train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')
with train_summary_writer.as_default():
train(model, optimizer, dataset)
with test_summary_writer.as_default():
test(model, test_x, test_y, optimizer.iterations)
Visualisez les résumés générés en faisant pointer TensorBoard vers le répertoire du journal des résumés :
tensorboard --logdir /tmp/summaries
Utilisez l'API tf.summary
pour écrire des données récapitulatives à visualiser dans TensorBoard. Pour plus d'informations, lisez le guide tf.summary
.
# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
@tf.function
def train_step(inputs, labels):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
regularization_loss=tf.math.add_n(model.losses)
pred_loss=loss_fn(labels, predictions)
total_loss=pred_loss + regularization_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
# Update the metrics
loss_metric.update_state(total_loss)
accuracy_metric.update_state(labels, predictions)
for epoch in range(NUM_EPOCHS):
# Reset the metrics
loss_metric.reset_states()
accuracy_metric.reset_states()
for inputs, labels in train_data:
train_step(inputs, labels)
# Get the metric results
mean_loss=loss_metric.result()
mean_accuracy = accuracy_metric.result()
print('Epoch: ', epoch)
print(' loss: {:.3f}'.format(mean_loss))
print(' accuracy: {:.3f}'.format(mean_accuracy))
2021-12-08 17:15:19.339736: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 0 loss: 0.142 accuracy: 0.991 2021-12-08 17:15:19.781743: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 1 loss: 0.125 accuracy: 0.997 2021-12-08 17:15:20.219033: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 2 loss: 0.110 accuracy: 0.997 2021-12-08 17:15:20.598085: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 3 loss: 0.099 accuracy: 0.997 Epoch: 4 loss: 0.085 accuracy: 1.000 2021-12-08 17:15:20.981787: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Noms des métriques Keras
Les modèles Keras sont cohérents quant à la gestion des noms de métriques. Lorsque vous transmettez une chaîne dans la liste des métriques, cette chaîne exacte est utilisée comme name
de la métrique . Ces noms sont visibles dans l'objet historique renvoyé par model.fit
et dans les journaux transmis à keras.callbacks
. est défini sur la chaîne que vous avez transmise dans la liste des métriques.
model.compile(
optimizer = tf.keras.optimizers.Adam(0.001),
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 5ms/step - loss: 0.0963 - acc: 0.9969 - accuracy: 0.9969 - my_accuracy: 0.9969 2021-12-08 17:15:21.942940: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])
Débogage
Utilisez une exécution rapide pour exécuter votre code étape par étape afin d'inspecter les formes, les types de données et les valeurs. Certaines API, comme tf.function
, tf.keras
, etc. sont conçues pour utiliser l'exécution de Graph, pour les performances et la portabilité. Lors du débogage, utilisez tf.config.run_functions_eagerly(True)
pour utiliser une exécution rapide dans ce code.
Par example:
@tf.function
def f(x):
if x > 0:
import pdb
pdb.set_trace()
x = x + 1
return x
tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
6 @tf.function
7 def f(x):
8 if x > 0:
9 import pdb
10 pdb.set_trace()
11 -> x = x + 1
12 return x
13
14 tf.config.run_functions_eagerly(True)
15 f(tf.constant(1))
[EOF]
Cela fonctionne également à l'intérieur des modèles Keras et d'autres API qui prennent en charge l'exécution rapide :
class CustomModel(tf.keras.models.Model):
@tf.function
def call(self, input_data):
if tf.reduce_mean(input_data) > 0:
return input_data
else:
import pdb
pdb.set_trace()
return input_data // 2
tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
10 if tf.reduce_mean(input_data) > 0:
11 return input_data
12 else:
13 import pdb
14 pdb.set_trace()
15 -> return input_data // 2
16
17
18 tf.config.run_functions_eagerly(True)
19 model = CustomModel()
20 model(tf.constant([-2, -4]))
Remarques:
Les méthodes
tf.keras.Model
telles quefit
,predict
etevaluate
s'exécutent sous forme de graphiques avectf.function
sous le capot.Lors de l'utilisation
tf.keras.Model.compile
, définissezrun_eagerly = True
pour empêcher la logique duModel
d'être enveloppée dans unetf.function
.Utilisez
tf.data.experimental.enable_debug_mode
pour activer le mode débogage pourtf.data
. Lisez la documentation de l' API pour plus de détails.
Ne gardez pas tf.Tensors
dans vos objets
Ces objets tenseurs peuvent être créés soit dans une tf.function
soit dans le contexte impatient, et ces tenseurs se comportent différemment. Utilisez toujours tf.Tensor
s uniquement pour les valeurs intermédiaires.
Pour suivre l'état, utilisez tf.Variable
s car ils sont toujours utilisables dans les deux contextes. Lisez le guide tf.Variable
pour en savoir plus.
Ressources et lectures complémentaires
Lisez les guides et tutoriels TF2 pour en savoir plus sur l'utilisation de TF2.
Si vous utilisiez auparavant TF1.x, il est fortement recommandé de migrer votre code vers TF2. Lisez les guides de migration pour en savoir plus.