"""
Task-specific policy in multi-task environments
================================================
This tutorial details how multi-task policies and batched environments can be used.
"""
##############################################################################
# At the end of this tutorial, you will be able to write policies that
# compute actions in diverse settings, each using its own set of weights.
# You will also be able to execute diverse environments in parallel.
# sphinx_gallery_start_ignore
import warnings
warnings.filterwarnings("ignore")
from torch import multiprocessing
# TorchRL prefers the "spawn" start method, which restricts the creation of
# ``~torchrl.envs.ParallelEnv`` to code guarded by ``if __name__ == "__main__"``.
# For ease of reading, this tutorial switches to "fork", which is also the
# default start method in Google's Colaboratory.
try:
    is_sphinx = __sphinx_build__
except NameError:
    is_sphinx = False

try:
    multiprocessing.set_start_method("spawn" if is_sphinx else "fork")
except RuntimeError:
    pass
# sphinx_gallery_end_ignore
from tensordict import LazyStackedTensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn
##############################################################################
from torchrl.envs import CatTensors, Compose, DoubleToFloat, ParallelEnv, TransformedEnv
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.modules import MLP
###############################################################################
# We design two environments, a humanoid that must complete the stand task
# and another that must learn to walk. The ``CatTensors`` transform gathers
# each env's observation entries both under a task-specific key
# ("observation_stand" / "observation_walk") and under a shared "observation"
# key that the common part of the policy will read.
env1 = DMControlEnv("humanoid", "stand")
env1_obs_keys = list(env1.observation_spec.keys())
env1 = TransformedEnv(
    env1,
    Compose(
        CatTensors(env1_obs_keys, "observation_stand", del_keys=False),
        CatTensors(env1_obs_keys, "observation"),
        DoubleToFloat(
            in_keys=["observation_stand", "observation"],
            in_keys_inv=["action"],
        ),
    ),
)
env2 = DMControlEnv("humanoid", "walk")
env2_obs_keys = list(env2.observation_spec.keys())
env2 = TransformedEnv(
    env2,
    Compose(
        CatTensors(env2_obs_keys, "observation_walk", del_keys=False),
        CatTensors(env2_obs_keys, "observation"),
        DoubleToFloat(
            in_keys=["observation_walk", "observation"],
            in_keys_inv=["action"],
        ),
    ),
)
###############################################################################
tdreset1 = env1.reset()
tdreset2 = env2.reset()
# With LazyStackedTensorDict, stacking is done in a lazy manner: the original tensordicts
# can still be recovered by indexing the main tensordict
tdreset = LazyStackedTensorDict.lazy_stack([tdreset1, tdreset2], 0)
assert tdreset[0] is tdreset1
###############################################################################
print(tdreset[0])
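###############################################################################
# As a quick sanity check (an illustrative addition, not part of the original
# tutorial): the stacked tensordict only exposes the keys shared by all of its
# entries, while indexing it recovers each env's task-specific keys.
print(tdreset.keys())  # only the common keys, e.g. "observation"
print(tdreset[1].keys())  # also contains "observation_walk"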
###############################################################################
# Policy
# ^^^^^^
#
# We will design a policy where a backbone reads the "observation" key.
# Then specific sub-components will read the "observation_stand" and
# "observation_walk" keys of the stacked tensordicts, if they are present,
# and pass them through the dedicated sub-network.
action_dim = env1.action_spec.shape[-1]
###############################################################################
policy_common = TensorDictModule(
    nn.Linear(67, 64), in_keys=["observation"], out_keys=["hidden"]
)
policy_stand = TensorDictModule(
    MLP(67 + 64, action_dim, depth=2),
    in_keys=["observation_stand", "hidden"],
    out_keys=["action"],
)
policy_walk = TensorDictModule(
    MLP(67 + 64, action_dim, depth=2),
    in_keys=["observation_walk", "hidden"],
    out_keys=["action"],
)
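# ``partial_tolerant=True`` lets the sequence run even if some modules' input
# keys are missing from the tensordict (or from some entries of a lazy stack):
# only the modules whose in_keys are present are executed.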
seq = TensorDictSequential(
    policy_common, policy_stand, policy_walk, partial_tolerant=True
)
###############################################################################
# Let's check that our sequence outputs actions for a single env (stand).
seq(env1.reset())
###############################################################################
# Let's check that our sequence outputs actions for a single env (walk).
seq(env2.reset())
###############################################################################
# This also works with the stack: now the stand and walk keys have
# disappeared, because they're not shared by all tensordicts. But the
# ``TensorDictSequential`` still performed the operations. Note that the
# backbone was executed in a vectorized way - not in a loop - which is more efficient.
seq(tdreset)
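###############################################################################
# A small check (an added sketch, not in the original tutorial): the outputs
# were written into each sub-tensordict of the lazy stack, so both the stand
# and the walk entries now carry an "action" key.
assert "action" in tdreset[0].keys()
assert "action" in tdreset[1].keys()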
###############################################################################
# Executing diverse tasks in parallel
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We can parallelize the operations if the common key-value pairs share the
# same specs (in particular, their shapes and dtypes must match: you can't do
# the following if the observation shapes differ but are stored under the
# same key).
#
# If ParallelEnv receives a single env-making function, it will assume that
# a single task has to be performed. If a list of functions is provided, it
# will assume that we are in a multi-task setting.
def env1_maker():
    return TransformedEnv(
        DMControlEnv("humanoid", "stand"),
        Compose(
            CatTensors(env1_obs_keys, "observation_stand", del_keys=False),
            CatTensors(env1_obs_keys, "observation"),
            DoubleToFloat(
                in_keys=["observation_stand", "observation"],
                in_keys_inv=["action"],
            ),
        ),
    )

def env2_maker():
    return TransformedEnv(
        DMControlEnv("humanoid", "walk"),
        Compose(
            CatTensors(env2_obs_keys, "observation_walk", del_keys=False),
            CatTensors(env2_obs_keys, "observation"),
            DoubleToFloat(
                in_keys=["observation_walk", "observation"],
                in_keys_inv=["action"],
            ),
        ),
    )

env = ParallelEnv(2, [env1_maker, env2_maker])
assert not env._single_task
tdreset = env.reset()
print(tdreset)
print(tdreset[0])
print(tdreset[1]) # should be different
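###############################################################################
# For contrast, a single-task sketch (an illustrative addition, not part of
# the original tutorial): passing a single env-making function makes
# ``ParallelEnv`` treat all workers as running the same task.
single_task_env = ParallelEnv(2, env1_maker)
assert single_task_env._single_task
del single_task_env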
###############################################################################
# Let's pass the output through our network.
tdreset = seq(tdreset)
print(tdreset)
print(tdreset[0])
print(tdreset[1]) # should be different but all have an "action" key
env.step(tdreset)  # executes a step in each env in parallel, using the actions computed above
print(tdreset)
print(tdreset[0])
print(tdreset[1])  # the step results (next observation, reward, done) have now been written
###############################################################################
# Rollout
# ^^^^^^^
td_rollout = env.rollout(100, policy=seq, return_contiguous=False)
###############################################################################
td_rollout[:, 0] # tensordict of the first step: only the common keys are shown
###############################################################################
td_rollout[0] # tensordict of the first env: the stand obs is present
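###############################################################################
# An added check (illustrative, not in the original tutorial): the
# task-specific observation keys only appear when indexing along the env
# dimension of the rollout.
assert "observation_stand" in td_rollout[0].keys()
assert "observation_walk" in td_rollout[1].keys()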
env.close()
del env