tf.keras.utils.PyDataset
Base class for defining a parallel dataset using Python code.
tf.keras.utils.PyDataset(
    workers=1, use_multiprocessing=False, max_queue_size=10
)
Every PyDataset must implement the __getitem__() and __len__() methods. If you want to modify your dataset between epochs, you may additionally implement on_epoch_end().
The __getitem__() method should return a complete batch (not a single sample), and the __len__() method should return the number of batches in the dataset (rather than the number of samples).
Args:
  workers: Number of workers to use in multithreading or multiprocessing.
  use_multiprocessing: Whether to use Python multiprocessing for parallelism. Setting this to True means that your dataset will be replicated in multiple forked processes. This is necessary to gain compute-level (rather than I/O-level) benefits from parallelism. However, it can only be set to True if your dataset can be safely pickled.
  max_queue_size: Maximum number of batches to keep in the queue when iterating over the dataset in a multithreaded or multiprocessed setting. Reduce this value to reduce the CPU memory consumption of your dataset. Defaults to 10.
Notes:
- PyDataset is a safer way to do multiprocessing. This structure guarantees that the model will only train once on each sample per epoch, which is not the case with Python generators.
- The arguments workers, use_multiprocessing, and max_queue_size exist to configure how fit() uses parallelism to iterate over the dataset. They are not used by the PyDataset class directly. When you are manually iterating over a PyDataset, no parallelism is applied.
Example:
import math

import keras
import numpy as np
from skimage.io import imread
from skimage.transform import resize

# Here, `x_set` is a list of paths to the images
# and `y_set` are the associated classes.

class CIFAR10PyDataset(keras.utils.PyDataset):

    def __init__(self, x_set, y_set, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # Return number of batches.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        # Return x, y for batch idx.
        low = idx * self.batch_size
        # Cap upper bound at array length; the last batch may be smaller
        # if the total number of items is not a multiple of batch size.
        high = min(low + self.batch_size, len(self.x))
        batch_x = self.x[low:high]
        batch_y = self.y[low:high]

        return np.array([
            resize(imread(file_name), (200, 200))
            for file_name in batch_x]), np.array(batch_y)
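The parallelism arguments from the Args table are forwarded to the base class through **kwargs in __init__(). A minimal usage sketch, assuming x_paths, y_labels, and model are illustrative names for data lists and an already-compiled Keras model prepared elsewhere:

# Illustrative names: `x_paths`, `y_labels`, and `model` are assumed to exist.
train_dataset = CIFAR10PyDataset(
    x_paths, y_labels, batch_size=32,
    workers=4, use_multiprocessing=False, max_queue_size=10,
)

# fit() consumes the dataset batch by batch; the parallelism settings above
# only affect how fit() prefetches batches, not manual indexing.
model.fit(train_dataset, epochs=10)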
Attributes:
  max_queue_size
  num_batches: Number of batches in the PyDataset.
  use_multiprocessing
  workers
Methods
on_epoch_end
on_epoch_end()
Method called at the end of every epoch.
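A sketch of how on_epoch_end() might be overridden to reshuffle the sample order between epochs; the subclass and its index-shuffling scheme are illustrative, not part of the API:

import math
import random

import keras
import numpy as np


class ShuffledPyDataset(keras.utils.PyDataset):
    # Illustrative subclass (not part of the API) that reshuffles
    # the sample order after each epoch via on_epoch_end().

    def __init__(self, x_set, y_set, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        # Index permutation defining the sample order for the current epoch.
        self.indices = list(range(len(self.x)))

    def __len__(self):
        # Number of batches per epoch.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        low = idx * self.batch_size
        high = min(low + self.batch_size, len(self.x))
        batch_ids = self.indices[low:high]
        return (np.array([self.x[i] for i in batch_ids]),
                np.array([self.y[i] for i in batch_ids]))

    def on_epoch_end(self):
        # Reshuffle so the next epoch visits the samples in a new order.
        random.shuffle(self.indices)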
__getitem__
__getitem__(index)
Gets batch at position index.
Args:
  index: Position of the batch in the PyDataset.

Returns:
  A batch.
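Since no parallelism is applied outside of fit(), a batch can also be fetched directly by index. A brief sketch, assuming the train_dataset instance from the usage example above:

# Runs synchronously in the calling process; `workers`,
# `use_multiprocessing`, and `max_queue_size` have no effect here.
batch_x, batch_y = train_dataset[0]
print(batch_x.shape, batch_y.shape)  # e.g. (32, 200, 200, 3) (32,)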