3 changes: 3 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
@@ -10,6 +10,9 @@ and this project adheres to
### Major Changes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Added the Random Network Distillation (RND) intrinsic reward signal to the PyTorch
trainers. To use RND, add a `rnd` section to the `reward_signals` section of your
yaml configuration file. [More information here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-Configuration-File.md#rnd-intrinsic-reward)
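As a minimal illustration of that change (the behavior name below is a placeholder; the full example config added by this PR appears in the next file), the addition looks roughly like:

```yaml
behaviors:
  MyBehavior:            # placeholder behavior name
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
      rnd:                # new intrinsic reward signal (PyTorch trainers only)
        gamma: 0.99
        strength: 0.01
```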

### Minor Changes
#### com.unity.ml-agents (C#)
32 changes: 32 additions & 0 deletions config/ppo/PyramidsRND.yaml
@@ -0,0 +1,32 @@
behaviors:
  Pyramids:
    trainer_type: ppo
    hyperparameters:
      batch_size: 128
      buffer_size: 2048
      learning_rate: 0.0003
      beta: 0.01
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: false
      hidden_units: 512
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
      rnd:
        gamma: 0.99
        strength: 0.01
        encoding_size: 64
        learning_rate: 0.0001
    keep_checkpoints: 5
    max_steps: 3000000
    time_horizon: 128
    summary_freq: 30000
    framework: pytorch
    threaded: true
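As a usage note, a config like this is typically passed to the trainer CLI, e.g. `mlagents-learn config/ppo/PyramidsRND.yaml --run-id=PyramidsRND`, where the run ID is an arbitrary label for the training run.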
24 changes: 23 additions & 1 deletion docs/ML-Agents-Overview.md
@@ -13,6 +13,7 @@
- [A Quick Note on Reward Signals](#a-quick-note-on-reward-signals)
- [Deep Reinforcement Learning](#deep-reinforcement-learning)
- [Curiosity for Sparse-reward Environments](#curiosity-for-sparse-reward-environments)
- [RND for Sparse-reward Environments](#rnd-for-sparse-reward-environments)
- [Imitation Learning](#imitation-learning)
- [GAIL (Generative Adversarial Imitation Learning)](#gail-generative-adversarial-imitation-learning)
- [Behavioral Cloning (BC)](#behavioral-cloning-bc)
@@ -359,7 +360,7 @@ The total reward that the agent will learn to maximize can be a mix of extrinsic
and intrinsic reward signals.

The ML-Agents Toolkit allows reward signals to be defined in a modular way, and
we provide three reward signals that can the mixed and matched to help shape
we provide four reward signals that can be mixed and matched to help shape
your agent's behavior:

- `extrinsic`: represents the rewards defined in your environment, and is
@@ -369,6 +370,9 @@ your agent's behavior:
- `curiosity`: represents an intrinsic reward signal that encourages exploration
in sparse-reward environments that is defined by the Curiosity module (see
below).
- `rnd`: represents an intrinsic reward signal that encourages exploration
  in sparse-reward environments, and is defined by the Random Network
  Distillation (RND) module (see below). (Not available for TensorFlow trainers.)

### Deep Reinforcement Learning

@@ -417,6 +421,24 @@ model is, the larger the reward will be.
For more information, see our dedicated
[blog post on the Curiosity module](https://blogs.unity3d.com/2018/06/26/solving-sparse-reward-tasks-with-curiosity/).

#### RND for Sparse-reward Environments

Similar to Curiosity, Random Network Distillation (RND) is useful in environments with
sparse or rare rewards, as it helps the Agent explore. The RND module is implemented following
the paper [Exploration by Random Network Distillation](https://arxiv.org/abs/1810.12894).
RND uses two networks:
- The first is a network with fixed random weights that takes observations as inputs and
generates an encoding
- The second is a network with similar architecture that is trained to predict the
outputs of the first network and uses the observations the Agent collects as training data.

The loss (the squared difference between the predicted and actual encoded observations)
of the trained model is used as the intrinsic reward. The more often an Agent visits a state, the
more accurate the prediction and the lower the reward, which encourages the Agent to
explore new states with higher prediction errors.

__Note:__ RND is not available for the TensorFlow trainers; it is supported only by the PyTorch trainers.
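To make the mechanism concrete, here is a minimal, self-contained PyTorch sketch of the two-network idea described above. It is not the ML-Agents implementation (that appears later in this diff); the helper name `make_encoder`, the network sizes, and the learning rate are illustrative.

```python
import torch

# Two small MLPs with the same architecture: the target is frozen at its
# random initialization, the predictor is trained to match it.
def make_encoder(obs_size: int, enc_size: int) -> torch.nn.Module:
    return torch.nn.Sequential(
        torch.nn.Linear(obs_size, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, enc_size),
    )

obs_size, enc_size = 10, 64
target = make_encoder(obs_size, enc_size)
predictor = make_encoder(obs_size, enc_size)
for p in target.parameters():          # the random network is never updated
    p.requires_grad_(False)
optimizer = torch.optim.Adam(predictor.parameters(), lr=1e-4)

def rnd_reward(obs: torch.Tensor) -> torch.Tensor:
    # Intrinsic reward = squared prediction error per observation.
    with torch.no_grad():
        return ((predictor(obs) - target(obs)) ** 2).sum(dim=1)

def rnd_update(obs: torch.Tensor) -> float:
    # Train the predictor to match the frozen random network on visited states;
    # frequently visited states end up with low error, hence low intrinsic reward.
    loss = ((predictor(obs) - target(obs)) ** 2).sum(dim=1).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Revisiting the same batch of observations drives the reward down.
batch = torch.randn(32, obs_size)
before = rnd_reward(batch).mean()
for _ in range(100):
    rnd_update(batch)
after = rnd_reward(batch).mean()  # expected to be smaller than `before`
```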

### Imitation Learning

It is often more intuitive to simply demonstrate the behavior we want an agent
13 changes: 13 additions & 0 deletions docs/Training-Configuration-File.md
@@ -10,6 +10,7 @@
- [Extrinsic Rewards](#extrinsic-rewards)
- [Curiosity Intrinsic Reward](#curiosity-intrinsic-reward)
- [GAIL Intrinsic Reward](#gail-intrinsic-reward)
- [RND Intrinsic Reward](#rnd-intrinsic-reward)
- [Reward Signal Settings for SAC](#reward-signal-settings-for-sac)
- [Behavioral Cloning](#behavioral-cloning)
- [Memory-enhanced Agents using Recurrent Neural Networks](#memory-enhanced-agents-using-recurrent-neural-networks)
@@ -118,6 +119,18 @@ settings:
| `gail -> use_actions` | (default = `false`) Determines whether the discriminator should discriminate based on both observations and actions, or just observations. Set to True if you want the agent to mimic the actions from the demonstrations, and False if you'd rather have the agent visit the same states as in the demonstrations but with possibly different actions. Setting to False is more likely to be stable, especially with imperfect demonstrations, but may learn slower. |
| `gail -> use_vail` | (default = `false`) Enables a variational bottleneck within the GAIL discriminator. This forces the discriminator to learn a more general representation and reduces its tendency to be "too good" at discriminating, making learning more stable. However, it does increase training time. Enable this if you notice your imitation learning is unstable, or unable to learn the task at hand. |

### RND Intrinsic Reward

Random Network Distillation (RND) is only available for the PyTorch trainers.
Contributor: Can you explain in more detail why a user would want to enable this? What scenarios does it help solve better, and what are the drawbacks/weaknesses?

Contributor: What happens if you try to enable this with tensorflow? Or without torch installed?

Contributor Author: You get an error that says the intrinsic reward signal type is not recognized.

To enable RND, provide these settings:

| **Setting** | **Description** |
| :--------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `rnd -> strength` | (default = `1.0`) Magnitude of the RND reward generated by the intrinsic RND module. This should be scaled in order to ensure it is large enough to not be overwhelmed by extrinsic reward signals in the environment. Likewise, it should not be so large that it overwhelms the extrinsic reward signal. <br><br>Typical range: `0.001` - `0.01` |
| `rnd -> gamma` | (default = `0.99`) Discount factor for future rewards. <br><br>Typical range: `0.8` - `0.995` |
| `rnd -> encoding_size` | (default = `64`) Size of the encoding used by the intrinsic RND model. <br><br>Typical range: `64` - `256` |
| `rnd -> learning_rate` | (default = `1e-4`) Learning rate used to update the RND module. This should be large enough for the RND module to quickly learn the state representation, but small enough to allow for stable learning. <br><br>Typical range: `1e-5` - `1e-3` |
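As a sketch, a `reward_signals` block that combines the extrinsic reward with RND using the defaults above might look like this (the surrounding behavior settings are omitted):

```yaml
reward_signals:
  extrinsic:
    gamma: 0.99
    strength: 1.0
  rnd:
    strength: 0.01       # keep small relative to the extrinsic reward
    gamma: 0.99
    encoding_size: 64
    learning_rate: 0.0001
```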


## Behavioral Cloning

3 changes: 3 additions & 0 deletions ml-agents/mlagents/trainers/model_saver/tf_model_saver.py
@@ -74,6 +74,9 @@ def export(self, output_filepath: str, behavior_name: str) -> None:
        # only on worker-0 if there are multiple workers
        if self.policy and self.policy.rank is not None and self.policy.rank != 0:
            return
        if self.graph is None:
            logger.info("No model to export")
            return
        export_policy_model(
            self.model_path, output_filepath, behavior_name, self.graph, self.sess
        )
8 changes: 8 additions & 0 deletions ml-agents/mlagents/trainers/settings.py
@@ -167,12 +167,14 @@ class RewardSignalType(Enum):
    EXTRINSIC: str = "extrinsic"
    GAIL: str = "gail"
    CURIOSITY: str = "curiosity"
    RND: str = "rnd"

    def to_settings(self) -> type:
        _mapping = {
            RewardSignalType.EXTRINSIC: RewardSignalSettings,
            RewardSignalType.GAIL: GAILSettings,
            RewardSignalType.CURIOSITY: CuriositySettings,
            RewardSignalType.RND: RNDSettings,
        }
        return _mapping[self]

@@ -214,6 +216,12 @@ class CuriositySettings(RewardSignalSettings):
    learning_rate: float = 3e-4


@attr.s(auto_attribs=True)
class RNDSettings(RewardSignalSettings):
    encoding_size: int = 64
    learning_rate: float = 1e-4


# SAMPLERS #############################################################################
class ParameterRandomizationType(Enum):
    UNIFORM: str = "uniform"
@@ -0,0 +1,69 @@
import numpy as np
import pytest
from mlagents.torch_utils import torch
from mlagents.trainers.torch.components.reward_providers import (
    RNDRewardProvider,
    create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents.trainers.settings import RNDSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
    create_agent_buffer,
)

SEED = [42]


@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
    ],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:
    curiosity_settings = RNDSettings(32, 0.01)
    curiosity_settings.strength = 0.1
    curiosity_rp = RNDRewardProvider(behavior_spec, curiosity_settings)
    assert curiosity_rp.strength == 0.1
    assert curiosity_rp.name == "RND"


@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,), (64, 66, 1)], ActionType.DISCRETE, (2, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
    ],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:
    curiosity_settings = RNDSettings(32, 0.01)
    curiosity_rp = create_reward_provider(
        RewardSignalType.RND, behavior_spec, curiosity_settings
    )
    assert curiosity_rp.name == "RND"


@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
    ],
)
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    rnd_settings = RNDSettings(32, 0.01)
    rnd_rp = RNDRewardProvider(behavior_spec, rnd_settings)
    buffer = create_agent_buffer(behavior_spec, 5)
    rnd_rp.update(buffer)
    reward_old = rnd_rp.evaluate(buffer)[0]
    for _ in range(100):
        rnd_rp.update(buffer)
    reward_new = rnd_rp.evaluate(buffer)[0]
    assert reward_new < reward_old
3 changes: 3 additions & 0 deletions ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py
@@ -10,6 +10,9 @@
from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import (  # noqa F401
    GAILRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.rnd_reward_provider import (  # noqa F401
    RNDRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.reward_provider_factory import (  # noqa F401
    create_reward_provider,
)
4 changes: 4 additions & 0 deletions ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py
@@ -15,13 +15,17 @@
from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import (
    GAILRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.rnd_reward_provider import (
    RNDRewardProvider,
)

from mlagents_envs.base_env import BehaviorSpec

NAME_TO_CLASS: Dict[RewardSignalType, Type[BaseRewardProvider]] = {
    RewardSignalType.EXTRINSIC: ExtrinsicRewardProvider,
    RewardSignalType.CURIOSITY: CuriosityRewardProvider,
    RewardSignalType.GAIL: GAILRewardProvider,
    RewardSignalType.RND: RNDRewardProvider,
}


78 changes: 78 additions & 0 deletions ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py
@@ -0,0 +1,78 @@
import numpy as np
from typing import Dict
from mlagents.torch_utils import torch

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
    BaseRewardProvider,
)
from mlagents.trainers.settings import RNDSettings

from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.settings import NetworkSettings, EncoderType


class RNDRewardProvider(BaseRewardProvider):
    """
    Implementation of Random Network Distillation : https://arxiv.org/pdf/1810.12894.pdf
    """

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__(specs, settings)
        self._ignore_done = True
        self._random_network = RNDNetwork(specs, settings)
        self._training_network = RNDNetwork(specs, settings)
        self.optimizer = torch.optim.Adam(
            self._training_network.parameters(), lr=settings.learning_rate
        )

    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        with torch.no_grad():
            target = self._random_network(mini_batch)
            prediction = self._training_network(mini_batch)
            rewards = torch.sum((prediction - target) ** 2, dim=1)
        return rewards.detach().cpu().numpy()

    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
        with torch.no_grad():
            target = self._random_network(mini_batch)
        prediction = self._training_network(mini_batch)
        loss = torch.mean(torch.sum((prediction - target) ** 2, dim=1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"Losses/RND Loss": loss.detach().cpu().numpy()}


class RNDNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__()
        self._policy_specs = specs
        state_encoder_settings = NetworkSettings(
            normalize=True,
            hidden_units=settings.encoding_size,
            num_layers=3,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._encoder = NetworkBody(specs.observation_shapes, state_encoder_settings)

    def forward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        n_vis = len(self._encoder.visual_processors)
        hidden, _ = self._encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["visual_obs%d" % i], dtype=torch.float
                )
                for i in range(n_vis)
            ],
        )
        self._encoder.update_normalization(torch.tensor(mini_batch["vector_obs"]))
        return hidden