Random Network Distillation for Torch #4473
@@ -0,0 +1,32 @@
```yaml
behaviors:
  Pyramids:
    trainer_type: ppo
    hyperparameters:
      batch_size: 128
      buffer_size: 2048
      learning_rate: 0.0003
      beta: 0.01
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: false
      hidden_units: 512
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
      rnd:
        gamma: 0.99
        strength: 0.01
        encoding_size: 64
        learning_rate: 0.0001
    keep_checkpoints: 5
    max_steps: 3000000
    time_horizon: 128
    summary_freq: 30000
    framework: pytorch
    threaded: true
```
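For context, a config like this would normally be passed to the standard `mlagents-learn` CLI. The file path and run ID below are placeholders rather than anything specified in this PR:

```
mlagents-learn config/ppo/PyramidsRND.yaml --run-id=PyramidsRND
```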
@@ -10,6 +10,7 @@
- [Extrinsic Rewards](#extrinsic-rewards)
- [Curiosity Intrinsic Reward](#curiosity-intrinsic-reward)
- [GAIL Intrinsic Reward](#gail-intrinsic-reward)
- [RND Intrinsic Reward](#rnd-intrinsic-reward)
- [Reward Signal Settings for SAC](#reward-signal-settings-for-sac)
- [Behavioral Cloning](#behavioral-cloning)
- [Memory-enhanced Agents using Recurrent Neural Networks](#memory-enhanced-agents-using-recurrent-neural-networks)
@@ -118,6 +119,18 @@ settings:
| `gail -> use_actions` | (default = `false`) Determines whether the discriminator should discriminate based on both observations and actions, or just observations. Set to True if you want the agent to mimic the actions from the demonstrations, and False if you'd rather have the agent visit the same states as in the demonstrations but with possibly different actions. Setting to False is more likely to be stable, especially with imperfect demonstrations, but may learn slower. |
| `gail -> use_vail` | (default = `false`) Enables a variational bottleneck within the GAIL discriminator. This forces the discriminator to learn a more general representation and reduces its tendency to be "too good" at discriminating, making learning more stable. However, it does increase training time. Enable this if you notice your imitation learning is unstable, or unable to learn the task at hand. |
### RND Intrinsic Reward

Random Network Distillation (RND) is only available for the PyTorch trainers.

**Reviewer:** What happens if you try to enable this with tensorflow? Or without torch installed?

**Author:** You get an error that says the intrinsic reward signal type is not recognized.

To enable RND, provide these settings:
| **Setting** | **Description** |
| :--- | :--- |
| `rnd -> strength` | (default = `1.0`) Magnitude of the reward generated by the intrinsic RND module. This should be scaled so that it is large enough not to be overwhelmed by the extrinsic reward signal in the environment, but not so large that it overwhelms the extrinsic reward signal. <br><br>Typical range: `0.001` - `0.01` |
| `rnd -> gamma` | (default = `0.99`) Discount factor for future rewards. <br><br>Typical range: `0.8` - `0.995` |
| `rnd -> encoding_size` | (default = `64`) Size of the encoding used by the intrinsic RND model. <br><br>Typical range: `64` - `256` |
| `rnd -> learning_rate` | (default = `3e-4`) Learning rate used to update the RND module. This should be large enough for the RND module to quickly learn the state representation, but small enough to allow for stable learning. <br><br>Typical range: `1e-5` - `1e-3` |
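For reference, a minimal `reward_signals` entry using RND looks like the following; the values mirror the Pyramids example config added in this PR:

```yaml
reward_signals:
  extrinsic:
    gamma: 0.99
    strength: 1.0
  rnd:
    gamma: 0.99
    strength: 0.01
    encoding_size: 64
    learning_rate: 0.0001
```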
## Behavioral Cloning
@@ -0,0 +1,69 @@
```python
import numpy as np
import pytest
from mlagents.torch_utils import torch
from mlagents.trainers.torch.components.reward_providers import (
    RNDRewardProvider,
    create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents.trainers.settings import RNDSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
    create_agent_buffer,
)

SEED = [42]


@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
    ],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:
    curiosity_settings = RNDSettings(32, 0.01)
    curiosity_settings.strength = 0.1
    curiosity_rp = RNDRewardProvider(behavior_spec, curiosity_settings)
    assert curiosity_rp.strength == 0.1
    assert curiosity_rp.name == "RND"


@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,), (64, 66, 1)], ActionType.DISCRETE, (2, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
    ],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:
    curiosity_settings = RNDSettings(32, 0.01)
    curiosity_rp = create_reward_provider(
        RewardSignalType.RND, behavior_spec, curiosity_settings
    )
    assert curiosity_rp.name == "RND"


@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
    ],
)
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    rnd_settings = RNDSettings(32, 0.01)
    rnd_rp = RNDRewardProvider(behavior_spec, rnd_settings)
    buffer = create_agent_buffer(behavior_spec, 5)
    rnd_rp.update(buffer)
    reward_old = rnd_rp.evaluate(buffer)[0]
    for _ in range(100):
        rnd_rp.update(buffer)
    reward_new = rnd_rp.evaluate(buffer)[0]
    assert reward_new < reward_old
```
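Assuming the new file sits alongside the other reward-provider tests (the path and filename below are inferred from the imports, not stated in the diff), the tests can be run directly with pytest:

```
pytest ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py
```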
@@ -0,0 +1,78 @@
```python
import numpy as np
from typing import Dict
from mlagents.torch_utils import torch

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
    BaseRewardProvider,
)
from mlagents.trainers.settings import RNDSettings

from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.settings import NetworkSettings, EncoderType


class RNDRewardProvider(BaseRewardProvider):
    """
    Implementation of Random Network Distillation: https://arxiv.org/pdf/1810.12894.pdf
    """

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__(specs, settings)
        self._ignore_done = True
        # Two networks with the same architecture: a fixed, randomly initialized
        # target and a predictor. Only the predictor's parameters are optimized;
        # the random network is never updated.
        self._random_network = RNDNetwork(specs, settings)
        self._training_network = RNDNetwork(specs, settings)
        self.optimizer = torch.optim.Adam(
            self._training_network.parameters(), lr=settings.learning_rate
        )

    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        # Intrinsic reward = squared error between the predictor and the fixed
        # random target: large for novel observations, small for familiar ones.
        with torch.no_grad():
            target = self._random_network(mini_batch)
            prediction = self._training_network(mini_batch)
            rewards = torch.sum((prediction - target) ** 2, dim=1)
        return rewards.detach().cpu().numpy()

    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
        # Train the predictor to match the frozen target; only the target output
        # is computed without gradients.
        with torch.no_grad():
            target = self._random_network(mini_batch)
        prediction = self._training_network(mini_batch)
        loss = torch.mean(torch.sum((prediction - target) ** 2, dim=1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"Losses/RND Loss": loss.detach().cpu().numpy()}


class RNDNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__()
        self._policy_specs = specs
        state_encoder_settings = NetworkSettings(
            normalize=True,
            hidden_units=settings.encoding_size,
            num_layers=3,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._encoder = NetworkBody(specs.observation_shapes, state_encoder_settings)

    def forward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        # Encode vector and visual observations from the mini batch into a single
        # hidden representation of size `encoding_size`.
        n_vis = len(self._encoder.visual_processors)
        hidden, _ = self._encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["visual_obs%d" % i], dtype=torch.float
                )
                for i in range(n_vis)
            ],
        )
        self._encoder.update_normalization(torch.tensor(mini_batch["vector_obs"]))
        return hidden
```
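For readers unfamiliar with the technique, here is a minimal, self-contained sketch of the RND idea the provider implements: a frozen, randomly initialized target network plus a trained predictor, with the predictor's squared error used as the intrinsic reward, which shrinks for states seen often. The network, sizes, and data here are illustrative assumptions, independent of the ML-Agents classes above.

```python
import torch


class SmallNet(torch.nn.Module):
    """Tiny MLP used for both the frozen target and the trained predictor."""

    def __init__(self, obs_size: int = 10, enc_size: int = 64) -> None:
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(obs_size, enc_size),
            torch.nn.ReLU(),
            torch.nn.Linear(enc_size, enc_size),
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


target = SmallNet()      # fixed random network, never trained
predictor = SmallNet()   # trained to imitate the target
optimizer = torch.optim.Adam(predictor.parameters(), lr=1e-4)

obs = torch.randn(32, 10)  # a batch of made-up observations

# Intrinsic reward: per-state squared prediction error (novel states score high).
with torch.no_grad():
    reward = ((predictor(obs) - target(obs)) ** 2).sum(dim=1)

# Update: minimize the same error, detaching the target output; repeated updates
# on the same states drive their intrinsic reward toward zero.
loss = ((predictor(obs) - target(obs).detach()) ** 2).sum(dim=1).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```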
**Reviewer:** Can you explain in more detail why a user would want to enable this? What scenarios does it help solve better, and what are the drawbacks/weaknesses?