"""
Using Replay Buffers
====================
**Author**: `Vincent Moens <https://github.com/vmoens>`_
.. _rb_tuto:
"""
######################################################################
# Replay buffers are a central piece of any RL or control algorithm.
# Supervised learning methods are usually characterized by a training loop
# where data is randomly pulled from a static dataset and fed successively
# to the model and loss function.
# In RL, things are often slightly different: the data is gathered using the
# model, then temporarily stored in a dynamic structure (the experience
# replay buffer), which serves as the dataset for the loss module.
#
# As always, the context in which the buffer is used drastically conditions
# how it is built: some may wish to store trajectories while others will want
# to store single transitions. Specific sampling strategies may be preferable
# in certain contexts: some items can have a higher priority than others, or it
# can be important to sample with or without replacement.
# Computational factors may also come into play, such as the size of the buffer,
# which may exceed the available RAM.
#
# For these reasons, TorchRL's replay buffers are fully composable: although
# they come with "batteries included", requiring a minimal effort to be built,
# they also support many customizations such as storage type,
# sampling strategy or data transforms.
#
#
# In this tutorial, you will learn:
#
# - How to build a :ref:`Replay Buffer (RB) <tuto_rb_vanilla>` and use it with
# any datatype;
# - How to customize the :ref:`buffer's storage <tuto_rb_storage>`;
# - How to use :ref:`RBs with TensorDict <tuto_rb_td>`;
# - How to :ref:`sample from or iterate over a replay buffer <tuto_rb_sampling>`,
# and how to define the sampling strategy;
# - How to use :ref:`prioritized replay buffers <tuto_rb_prb>`;
# - How to :ref:`transform data <tuto_rb_transform>` coming in and out of
#   the buffer;
# - How to store :ref:`trajectories <tuto_rb_traj>` in the buffer.
#
#
# Basics: building a vanilla replay buffer
# ----------------------------------------
#
# .. _tuto_rb_vanilla:
#
# TorchRL's replay buffers are designed to prioritize modularity,
# composability, efficiency, and simplicity. For instance, creating a basic
# replay buffer is a straightforward process, as shown in the following
# example:
#
import gc
import tempfile
from torchrl.data import ReplayBuffer
buffer = ReplayBuffer()
######################################################################
# By default, this replay buffer will have a size of 1000. Let's check this
# by populating our buffer using the :meth:`~torchrl.data.ReplayBuffer.extend`
# method:
#
print("length before adding elements:", len(buffer))
buffer.extend(range(2000))
print("length after adding elements:", len(buffer))
######################################################################
# We have used the :meth:`~torchrl.data.ReplayBuffer.extend` method, which is
# designed to add multiple items all at once. If the object that is passed
# to ``extend`` has more than one dimension, its first dimension is
# considered to be the one to be split into separate elements in the buffer.
#
# This essentially means that when adding multidimensional tensors or
# tensordicts to the buffer, the buffer will only look at the first dimension
# when counting the elements it holds in memory.
# If the object passed is not iterable, an exception will be thrown.
#
# To add items one at a time, the :meth:`~torchrl.data.ReplayBuffer.add` method
# should be used instead.
#
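# As a quick illustration, :meth:`~torchrl.data.ReplayBuffer.add` writes a
# single element at a time (a small sketch reusing the buffer created above):
buffer.add(2000)
print("length after adding one more element:", len(buffer))
######################################################################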
# Customizing the storage
# -----------------------
#
# .. _tuto_rb_storage:
#
# We see that the buffer has been capped to the first 1000 elements that we
# passed to it.
# To change the size, we need to customize our storage.
#
# TorchRL offers three types of storage:
#
# - The :class:`~torchrl.data.ListStorage` stores elements independently in a
#   list. It supports any data type, but this flexibility comes at the cost
#   of efficiency;
# - The :class:`~torchrl.data.LazyTensorStorage` stores tensor data
#   structures contiguously.
#   It works naturally with :class:`~tensordict.TensorDict`
#   (or :class:`~tensordict.tensorclass`)
#   objects. The storage is contiguous on a per-tensor basis, meaning that
#   sampling will be more efficient than when using a list, but the
#   implicit restriction is that any data passed to it must have the same
#   basic properties (such as shape and dtype) as the first batch of data that
#   was used to instantiate the buffer.
#   Passing data that does not match this requirement will either raise an
#   exception or lead to undefined behavior.
# - The :class:`~torchrl.data.LazyMemmapStorage` works like the
#   :class:`~torchrl.data.LazyTensorStorage` in that it is lazy (i.e., the
#   first batch of data is needed to instantiate it), and it requires every
#   stored batch to match the first in shape and dtype. What makes this
#   storage unique is that it points to disk files (or uses the filesystem
#   storage), meaning that it can support very large datasets while still
#   accessing data in a contiguous manner.
#
# Let us see how we can use each of these storages:
from torchrl.data import LazyMemmapStorage, LazyTensorStorage, ListStorage
# We define the maximum size of the buffer
size = 100
######################################################################
# A buffer with a list storage can store any kind of data (but we must
# change the ``collate_fn`` since the default expects numerical data):
buffer_list = ReplayBuffer(storage=ListStorage(size), collate_fn=lambda x: x)
buffer_list.extend(["a", 0, "b"])
print(buffer_list.sample(3))
######################################################################
# Because it makes the fewest assumptions about the data, the
# :class:`~torchrl.data.ListStorage` is the default storage in TorchRL.
#
# A :class:`~torchrl.data.LazyTensorStorage` can store data contiguously.
# This should be the preferred option when dealing with complicated but
# unchanging data structures of medium size:
buffer_lazytensor = ReplayBuffer(storage=LazyTensorStorage(size))
######################################################################
# Let us create a batch of data of size ``torch.Size([3])`` with 2 tensors
# stored in it:
#
import torch
from tensordict import TensorDict
data = TensorDict(
{
"a": torch.arange(12).view(3, 4),
("b", "c"): torch.arange(15).view(3, 5),
},
batch_size=[3],
)
print(data)
######################################################################
# The first call to :meth:`~torchrl.data.ReplayBuffer.extend` will
# instantiate the storage. The first dimension of the data is unbound into
# separate datapoints:
buffer_lazytensor.extend(data)
print(f"The buffer has {len(buffer_lazytensor)} elements")
######################################################################
# Let us sample from the buffer, and print the data:
#
sample = buffer_lazytensor.sample(5)
print("samples", sample["a"], sample["b", "c"])
######################################################################
# A :class:`~torchrl.data.LazyMemmapStorage` is created in the same manner.
# We can also customize the storage location on disk:
#
with tempfile.TemporaryDirectory() as tempdir:
buffer_lazymemmap = ReplayBuffer(
storage=LazyMemmapStorage(size, scratch_dir=tempdir)
)
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
print(
"the 'a' tensor is stored in", buffer_lazymemmap._storage._storage["a"].filename
)
print(
"the ('b', 'c') tensor is stored in",
buffer_lazymemmap._storage._storage["b", "c"].filename,
)
    sample = buffer_lazymemmap.sample(5)
print("samples: a=", sample["a"], "\n('b', 'c'):", sample["b", "c"])
del buffer_lazymemmap
######################################################################
# Integration with TensorDict
# ---------------------------
#
# .. _tuto_rb_td:
#
# The tensor locations follow the same structure as the TensorDict that
# contains them: this makes it easy to save and load buffers during training,
# as sketched below.
#
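# As a minimal sketch of that idea (assuming the ``dumps``/``loads``
# checkpointing methods of :class:`~torchrl.data.ReplayBuffer` described in
# the documentation):
#
#   .. code-block:: python
#
#      >>> with tempfile.TemporaryDirectory() as tempdir:
#      ...     buffer_lazytensor.dumps(tempdir)  # write storage and metadata to disk
#      ...     buffer_restored = ReplayBuffer(storage=LazyTensorStorage(size))
#      ...     buffer_restored.loads(tempdir)  # read them back into a fresh buffer
#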
# To use :class:`~tensordict.TensorDict` as a data carrier at its fullest
# potential, the :class:`~torchrl.data.TensorDictReplayBuffer` class can
# be used.
# One of its key benefits is its ability to handle the organization of sampled
# data, along with any additional information that may be required
# (such as sample indices).
#
# It can be built in the same manner as a standard
# :class:`~torchrl.data.ReplayBuffer` and can
# generally be used interchangeably.
#
from torchrl.data import TensorDictReplayBuffer
with tempfile.TemporaryDirectory() as tempdir:
buffer_lazymemmap = TensorDictReplayBuffer(
storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12
)
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
sample = buffer_lazymemmap.sample()
print("sample:", sample)
del buffer_lazymemmap
######################################################################
# Our sample now has an extra ``"index"`` key that indicates what indices
# were sampled.
# Let us have a look at these indices:
print(sample["index"])
######################################################################
# Integration with tensorclass
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The ReplayBuffer class and associated subclasses also work natively with
# :class:`~tensordict.tensorclass` classes, which can conveniently be used to
# encode datasets in a more explicit manner:
from tensordict import tensorclass
@tensorclass
class MyData:
images: torch.Tensor
labels: torch.Tensor
data = MyData(
images=torch.randint(
255,
(10, 64, 64, 3),
),
labels=torch.randint(100, (10,)),
batch_size=[10],
)
buffer_lazy = ReplayBuffer(storage=LazyTensorStorage(size), batch_size=12)
buffer_lazy.extend(data)
print(f"The buffer has {len(buffer_lazy)} elements")
sample = buffer_lazy.sample()
print("sample:", sample)
######################################################################
# As expected, the data has the proper class and shape!
#
# Integration with other tensor structures (PyTrees)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# TorchRL's replay buffers also work with any pytree data structure.
# A PyTree is a nested structure of arbitrary depth made of dicts, lists and/or
# tuples where the leaves are tensors.
# This means that one can store in contiguous memory any such tree structure!
# Various storages can be used:
# :class:`~torchrl.data.replay_buffers.TensorStorage`,
# :class:`~torchrl.data.replay_buffers.LazyMemmapStorage`
# or :class:`~torchrl.data.replay_buffers.LazyTensorStorage` all accept this
# kind of data.
#
# Here is a brief demonstration of what this feature looks like:
#
from torch.utils._pytree import tree_map
######################################################################
# Let's build our replay buffer in RAM:
rb = ReplayBuffer(storage=LazyTensorStorage(size))
data = {
"a": torch.randn(3),
"b": {"c": (torch.zeros(2), [torch.ones(1)])},
30: -torch.ones(()), # non-string keys also work
}
rb.add(data)
# The sample has a similar structure to the data (with a leading dimension of 10 for each tensor)
sample = rb.sample(10)
######################################################################
# With pytrees, any callable can be used as a transform:
def transform(x):
# Zeros all the data in the pytree
return tree_map(lambda y: y * 0, x)
rb.append_transform(transform)
sample = rb.sample(batch_size=12)
######################################################################
# let's check that our transform did its job:
def assert0(x):
assert (x == 0).all()
tree_map(assert0, sample)
######################################################################
# Sampling and iterating over buffers
# -----------------------------------
#
# .. _tuto_rb_sampling:
#
# Replay Buffers support multiple sampling strategies:
#
# - If the batch-size is fixed and can be defined at construction time, it can
#   be passed as a keyword argument to the buffer;
# - With a fixed batch-size, the replay buffer can be iterated over to gather
#   samples;
# - If the batch-size is dynamic, it can be passed to the
#   :meth:`~torchrl.data.ReplayBuffer.sample` method
#   on-the-fly.
#
# Sampling can be done using multithreading, but this is incompatible with the
# last option (as it requires the buffer to know in advance the size of the
# next batch).
#
# Let us see a few examples:
#
# Fixed batch-size
# ~~~~~~~~~~~~~~~~
#
# If the batch-size is passed during construction, it should be omitted when
# sampling:
data = MyData(
images=torch.randint(
255,
(200, 64, 64, 3),
),
labels=torch.randint(100, (200,)),
batch_size=[200],
)
buffer_lazy = ReplayBuffer(storage=LazyTensorStorage(size), batch_size=128)
buffer_lazy.extend(data)
buffer_lazy.sample()
######################################################################
# This batch of data has the size that we wanted it to have (128).
#
# To enable multithreaded sampling, just pass a positive integer to the
# ``prefetch`` keyword argument during construction. This should speed up
# sampling considerably whenever sampling is time consuming (e.g., when
# using prioritized samplers):
buffer_lazy = ReplayBuffer(
storage=LazyTensorStorage(size), batch_size=128, prefetch=10
) # creates a queue of 10 elements to be prefetched in the background
buffer_lazy.extend(data)
print(buffer_lazy.sample())
######################################################################
# Iterating over the buffer with a fixed batch-size
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We can also iterate over the buffer like we would do with a regular
# dataloader, as long as the batch-size is predefined:
for i, data in enumerate(buffer_lazy):
if i == 3:
print(data)
break
del buffer_lazy
######################################################################
# Because our sampling technique is entirely random and does not prevent
# replacement, the iterator in question is infinite. However, we can use the
# :class:`~torchrl.data.replay_buffers.SamplerWithoutReplacement`
# instead, which will turn our buffer into a finite iterator:
#
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
buffer_lazy = ReplayBuffer(
storage=LazyTensorStorage(size), batch_size=32, sampler=SamplerWithoutReplacement()
)
######################################################################
# We create data that is big enough to get a couple of samples:
data = TensorDict(
{
"a": torch.arange(64).view(16, 4),
("b", "c"): torch.arange(128).view(16, 8),
},
batch_size=[16],
)
buffer_lazy.extend(data)
for _i, _ in enumerate(buffer_lazy):
continue
print(f"A total of {_i+1} batches have been collected")
del buffer_lazy
######################################################################
# Dynamic batch-size
# ~~~~~~~~~~~~~~~~~~
#
# In contrast to what we have seen earlier, the ``batch_size`` keyword
# argument can be omitted and passed directly to the ``sample`` method:
buffer_lazy = ReplayBuffer(
storage=LazyTensorStorage(size), sampler=SamplerWithoutReplacement()
)
buffer_lazy.extend(data)
print("sampling 3 elements:", buffer_lazy.sample(3))
print("sampling 5 elements:", buffer_lazy.sample(5))
del buffer_lazy
######################################################################
# Prioritized Replay buffers
# --------------------------
#
# .. _tuto_rb_prb:
#
# TorchRL also provides an interface for
# `prioritized replay buffers <https://arxiv.org/abs/1511.05952>`_.
# This buffer class samples data according to a priority signal that is passed
# through the data.
#
# Although this tool is compatible with non-tensordict data, we encourage
# using TensorDict instead as it makes it possible to carry meta-data in and
# out of the buffer with little effort.
#
# Let us first see how to build a prioritized replay buffer in the generic
# case. The :math:`\alpha` and :math:`\beta` hyperparameters
# have to be manually set:
from torchrl.data.replay_buffers.samplers import PrioritizedSampler
size = 100
rb = ReplayBuffer(
storage=ListStorage(size),
sampler=PrioritizedSampler(max_capacity=size, alpha=0.8, beta=1.1),
collate_fn=lambda x: x,
)
######################################################################
# Extending the replay buffer returns the indices of the items, which we will
# need later to update the priority:
indices = rb.extend([1, "foo", None])
######################################################################
# The sampler expects to have a priority for each element. When added to the
# buffer, the priority is set to a default value of 1. Once the priority has
# been computed (usually through the loss), it must be updated in the buffer.
#
# This is done via the :meth:`~torchrl.data.ReplayBuffer.update_priority`
# method, which requires the indices as well as the priority.
# We assign an artificially high priority to the second sample in the dataset
# to observe its effect on sampling:
#
rb.update_priority(index=indices, priority=torch.tensor([0, 1_000, 0.1]))
######################################################################
# We observe that sampling from the buffer returns mostly the second sample
# (``"foo"``):
#
sample, info = rb.sample(10, return_info=True)
print(sample)
######################################################################
# The info contains the relative weights of the items as well as the indices.
print(info)
######################################################################
# We see that using a prioritized replay buffer requires a series of extra
# steps in the training loop compared with a regular buffer:
#
# - After collecting data and extending the buffer, the priority of the
# items must be updated;
# - After computing the loss and getting a "priority signal" from it, we must
#   update the priority of the items in the buffer again.
#   This requires us to keep track of the indices, as sketched below.
#
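# Here is a minimal sketch of that manual workflow, reusing the small buffer
# above with a made-up priority in place of a real loss signal:
sample, info = rb.sample(4, return_info=True)
# ... the loss would be computed from ``sample`` here; we fake a priority instead
fake_priority = torch.rand(len(sample))
# the indices of the sampled items are returned in ``info``
rb.update_priority(index=info["index"], priority=fake_priority)
######################################################################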
# This drastically hampers the reusability of the buffer: if one is to write
# a training script where both a prioritized and a regular buffer can be
# created, she must add a considerable amount of control flow to make sure
# that the appropriate methods are called at the appropriate place, if and
# only if a prioritized buffer is being used.
#
# Let us see how we can improve this with :class:`~tensordict.TensorDict`.
# We saw that the :class:`~torchrl.data.TensorDictReplayBuffer` returns data
# augmented with their relative storage indices. One feature we did not mention
# is that this class also ensures that the priority
# signal, if present, is automatically passed to the prioritized sampler during
# extension.
#
# The combination of these features simplifies things in several ways:
#
# - When extending the buffer, the priority signal will automatically be
# parsed if present and the priority will accurately be assigned;
# - The indices will be stored in the sampled tensordicts, making it easy to
# update the priority after the loss computation.
# - When computing the loss, the priority signal will be registered in the
# tensordict passed to the loss module, making it possible to update the
# weights without effort:
#
#   .. code-block:: python
#
#      >>> data = replay_buffer.sample()
#      >>> loss_val = loss_module(data)
#      >>> replay_buffer.update_tensordict_priority(data)
#
# The following code illustrates these concepts. We build a replay buffer with
# a prioritized sampler, and indicate in the constructor the entry where
# the priority signal should be fetched:
rb = TensorDictReplayBuffer(
storage=ListStorage(size),
sampler=PrioritizedSampler(size, alpha=0.8, beta=1.1),
priority_key="td_error",
batch_size=1024,
)
######################################################################
# Let us choose a priority signal that is proportional to the storage index:
#
data["td_error"] = torch.arange(data.numel())
rb.extend(data)
sample = rb.sample()
######################################################################
# higher indices should occur more frequently:
from matplotlib import pyplot as plt
fig = plt.hist(sample["index"].numpy())
plt.show()
######################################################################
# Once we have worked with our sample, we update the priority key using
# the :meth:`torchrl.data.TensorDictReplayBuffer.update_tensordict_priority`
# method.
# For the sake of showing how this works, let us reverse the priority of the
# sampled items:
#
sample = rb.sample()
sample["td_error"] = data.numel() - sample["index"]
rb.update_tensordict_priority(sample)
######################################################################
# Now, higher indices should occur less frequently:
sample = rb.sample()
fig = plt.hist(sample["index"].numpy())
plt.show()
######################################################################
# Using transforms
# ----------------
#
# .. _tuto_rb_transform:
#
# The data stored in a replay buffer may not be ready to be presented to a
# loss module.
# In some cases, the data produced by a collector can be too heavy to be
# saved as-is, or it may need extra processing before being consumed by the
# loss. Examples include converting images from ``uint8`` to
# floating point tensors, or concatenating successive frames when using
# decision transformers.
#
# Data can be processed in and out of a buffer just by appending the
# appropriate transform to it.
# Here are a few examples:
#
# Saving raw images
# ~~~~~~~~~~~~~~~~~
#
# ``uint8``-typed tensors are much cheaper to store in memory than
# the floating point tensors we usually feed to our models. For this reason,
# it can be useful to save the raw images.
# The following script shows how one can build a collector that returns only
# the raw images but uses the transformed ones for inference, and how these
# transformations can be recycled in the replay buffer:
from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
Compose,
GrayScale,
Resize,
ToTensorImage,
TransformedEnv,
)
from torchrl.envs.utils import RandomPolicy
env = TransformedEnv(
GymEnv("CartPole-v1", from_pixels=True),
Compose(
ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
Resize(in_keys=["pixels_trsf"], w=64, h=64),
GrayScale(in_keys=["pixels_trsf"]),
),
)
######################################################################
# let us have a look at a rollout:
print(env.rollout(3))
######################################################################
# We have just created an environment that produces pixels. These images
# are processed to be fed to a policy.
# We would like to store the raw images, and not their transformed versions.
# To do this, we will append a transform to the collector to exclude the keys
# we do not want to keep:
from torchrl.envs.transforms import ExcludeTransform
collector = SyncDataCollector(
env,
RandomPolicy(env.action_spec),
frames_per_batch=10,
total_frames=1000,
postproc=ExcludeTransform("pixels_trsf", ("next", "pixels_trsf"), "collector"),
)
######################################################################
# Let us have a look at a batch of data, and check that the
# ``"pixels_trsf"`` keys have been discarded:
for data in collector:
print(data)
break
collector.shutdown()
######################################################################
# We create a replay buffer with the same transform as the environment.
# There is, however, a detail that needs to be addressed: transforms
# used without environments are oblivious to the data structure.
# When appending a transform to an environment, the data in the ``"next"``
# nested tensordict is transformed first and then copied at the root during
# the rollout execution. When working with static data, this is not the case.
# Nevertheless, our data comes with a nested "next" tensordict that will be
# ignored by our transform if we don't explicitly instruct it to take care of
# it. We manually add these keys to the transform:
t = Compose(
ToTensorImage(
in_keys=["pixels", ("next", "pixels")],
out_keys=["pixels_trsf", ("next", "pixels_trsf")],
),
Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000), transform=t, batch_size=16)
rb.extend(data)
######################################################################
# We can check that a ``sample`` method sees the transformed images reappear:
#
print(rb.sample())
######################################################################
# A more complex example: using CatFrames
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The :class:`~torchrl.envs.transforms.CatFrames` transform unfolds the observations
# through time, creating an n-back memory of past events that allows the model
# to take past events into account (in the case of POMDPs or with
# recurrent policies such as Decision Transformers). Storing these concatenated
# frames can consume a considerable amount of memory. It can also be
# problematic when the n-back window needs to be different (usually longer)
# during training and inference. We solve this problem by executing the
# ``CatFrames`` transform separately in the two phases.
from torchrl.envs import CatFrames, UnsqueezeTransform
######################################################################
# We create a standard list of transforms for environments that return pixel-based
# observations:
env = TransformedEnv(
GymEnv("CartPole-v1", from_pixels=True),
Compose(
ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
Resize(in_keys=["pixels_trsf"], w=64, h=64),
GrayScale(in_keys=["pixels_trsf"]),
UnsqueezeTransform(-4, in_keys=["pixels_trsf"]),
CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]),
),
)
collector = SyncDataCollector(
env,
RandomPolicy(env.action_spec),
frames_per_batch=10,
total_frames=1000,
)
for data in collector:
print(data)
break
collector.shutdown()
######################################################################
# The buffer transform looks pretty much like the environment one, but with
# extra ``("next", ...)`` keys like before:
#
t = Compose(
ToTensorImage(
in_keys=["pixels", ("next", "pixels")],
out_keys=["pixels_trsf", ("next", "pixels_trsf")],
),
Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
UnsqueezeTransform(-4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(size), transform=t, batch_size=16)
data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf"))
rb.add(data_exclude)
######################################################################
# Let us sample one batch from the buffer. The transformed pixel keys should
# have a size of 4 along the fourth dimension counting from the end:
#
s = rb.sample(1) # the buffer has only one element
print(s)
######################################################################
# After a bit of processing (excluding unused keys, etc.), we see that the
# data generated online and offline match!
assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all()
######################################################################
# Storing trajectories
# --------------------
#
# .. _tuto_rb_traj:
#
# In many cases, it is desirable to access trajectories from the buffer rather
# than simple transitions. TorchRL offers multiple ways of achieving this.
#
# The preferred way is currently to store trajectories along the first
# dimension of the buffer and use a :class:`~torchrl.data.SliceSampler` to
# sample these batches of data. This class only needs a couple of pieces of
# information about your data structure to do its job (note that as of now it
# is only compatible with tensordict-structured data): the number of slices or
# their length and some information about where the separation between the
# episodes can be found (e.g. :ref:`recall that <gs_storage_collector>` with a
# :ref:`DataCollector <ref_collectors>`, the trajectory id is stored in
# ``("collector", "traj_ids")``). In this simple example, we construct data
# with 4 consecutive short trajectories and sample 4 slices out of it, each of
# length 2 (since the batch size is 8, and 8 items // 4 slices = 2 time steps).
# We mark the steps as well.
from torchrl.data import SliceSampler
rb = TensorDictReplayBuffer(
storage=LazyTensorStorage(size),
sampler=SliceSampler(traj_key="episode", num_slices=4),
batch_size=8,
)
episode = torch.zeros(10, dtype=torch.int)
episode[:3] = 1
episode[3:5] = 2
episode[5:7] = 3
episode[7:] = 4
steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)])
data = TensorDict(
{
"episode": episode,
"obs": torch.randn((3, 4, 5)).expand(10, 3, 4, 5),
"act": torch.randn((20,)).expand(10, 20),
"other": torch.randn((20, 50)).expand(10, 20, 50),
"steps": steps,
},
[10],
)
rb.extend(data)
sample = rb.sample()
print("episode are grouped", sample["episode"])
print("steps are successive", sample["steps"])
gc.collect()
######################################################################
# Conclusion
# ----------
#
# We have seen how a replay buffer can be used in TorchRL, from its simplest
# usage to more advanced ones where the data needs to be transformed or stored
# in particular ways.
# You should now be able to:
#
# - Create a Replay Buffer, customize its storage, sampler and transforms;
# - Choose the best storage type for your problem (list, memory or disk-based);
# - Minimize the memory footprint of your buffer.
#
# Next steps
# ----------
#
# - Check the data API reference to learn about offline datasets in TorchRL,
# which are based on our Replay Buffer API;
# - Check other samplers such as
# :class:`~torchrl.data.SamplerWithoutReplacement`,
# :class:`~torchrl.data.PrioritizedSliceSampler` and
# :class:`~torchrl.data.SliceSamplerWithoutReplacement`, or other writers
# such as :class:`~torchrl.data.TensorDictMaxValueWriter`.
# - Check how to checkpoint ReplayBuffers in :ref:`the doc <checkpoint-rb>`.