[WIP] Demonstration provider #4988
Changes from all commits: eb4821e, 6181acb, c4d0852, b934340, d95b126, 46f99e6, 92377b5
`mlagents/trainers/demonstrations/demonstration_proto_utils.py` (new file, +96 lines; the path is taken from the imports in the provider file below):

```python
import os
from typing import List, Tuple
import numpy as np
from mlagents.trainers.buffer import AgentBuffer
from mlagents_envs.communicator_objects.agent_info_action_pair_pb2 import (
    AgentInfoActionPairProto,
)
from mlagents.trainers.trajectory import ObsUtil
from mlagents_envs.rpc_utils import behavior_spec_from_proto, steps_from_proto
from mlagents_envs.base_env import BehaviorSpec
from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
    DemonstrationMetaProto,
)
from mlagents_envs.timers import timed, hierarchical_timer
from google.protobuf.internal.decoder import _DecodeVarint32  # type: ignore
from google.protobuf.internal.encoder import _EncodeVarint  # type: ignore


INITIAL_POS = 33
SUPPORTED_DEMONSTRATION_VERSIONS = frozenset([0, 1])


@timed
def load_demonstration(
    file_paths: List[str],
) -> Tuple[BehaviorSpec, List[AgentInfoActionPairProto]]:
    """
    Loads and parses one or more demonstration files.
    :param file_paths: Locations of demonstration files (.demo).
    :return: BehaviorSpec and list of AgentInfoActionPairProto containing
             demonstration data.
    """

    # First 32 bytes of file dedicated to meta-data.
    behavior_spec = None
    brain_param_proto = None
    info_action_pairs = []
    total_expected = 0
    for _file_path in file_paths:
        with open(_file_path, "rb") as fp:
            with hierarchical_timer("read_file"):
                data = fp.read()
            # Messages are varint-length-delimited: the first is a
            # DemonstrationMetaProto, the second a BrainParametersProto,
            # and the rest are AgentInfoActionPairProtos.
            next_pos, pos, obs_decoded = 0, 0, 0
            while pos < len(data):
                next_pos, pos = _DecodeVarint32(data, pos)
                if obs_decoded == 0:
                    meta_data_proto = DemonstrationMetaProto()
                    meta_data_proto.ParseFromString(data[pos : pos + next_pos])
                    if (
                        meta_data_proto.api_version
                        not in SUPPORTED_DEMONSTRATION_VERSIONS
                    ):
                        raise RuntimeError(
                            f"Can't load Demonstration data from an unsupported version ({meta_data_proto.api_version})"
                        )
                    total_expected += meta_data_proto.number_steps
                    pos = INITIAL_POS
                if obs_decoded == 1:
                    brain_param_proto = BrainParametersProto()
                    brain_param_proto.ParseFromString(data[pos : pos + next_pos])
                    pos += next_pos
                if obs_decoded > 1:
                    agent_info_action = AgentInfoActionPairProto()
                    agent_info_action.ParseFromString(data[pos : pos + next_pos])
                    if behavior_spec is None:
                        behavior_spec = behavior_spec_from_proto(
                            brain_param_proto, agent_info_action.agent_info
                        )
                    info_action_pairs.append(agent_info_action)
                    if len(info_action_pairs) == total_expected:
                        break
                    pos += next_pos
                obs_decoded += 1
    if not behavior_spec:
        raise RuntimeError(
            f"No BrainParameters found in demonstration file(s) at {file_paths}."
        )
    return behavior_spec, info_action_pairs


def write_delimited(f, message):
    msg_string = message.SerializeToString()
    msg_size = len(msg_string)
    _EncodeVarint(f.write, msg_size)
    f.write(msg_string)


def write_demo(demo_path, meta_data_proto, brain_param_proto, agent_info_protos):
    with open(demo_path, "wb") as f:
        # write metadata
        write_delimited(f, meta_data_proto)
        f.seek(INITIAL_POS)
        write_delimited(f, brain_param_proto)

        for agent in agent_info_protos:
            write_delimited(f, agent)
```
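For context, a minimal usage sketch of these helpers. Only `load_demonstration`, `write_demo`, and the metadata fields `api_version`/`number_steps` come from the code above; the file paths are hypothetical placeholders:

```python
from mlagents.trainers.demonstrations.demonstration_proto_utils import (
    load_demonstration,
    write_demo,
)
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
    DemonstrationMetaProto,
)

# "expert.demo" is a placeholder path, not a file shipped with the repo.
behavior_spec, pairs = load_demonstration(["expert.demo"])
print(f"Loaded {len(pairs)} AgentInfoActionPair messages")

# Writing uses the same varint-delimited layout: metadata first, then the
# BrainParametersProto at byte offset INITIAL_POS, then one message per step.
meta = DemonstrationMetaProto()
meta.api_version = 1            # must be in SUPPORTED_DEMONSTRATION_VERSIONS
meta.number_steps = len(pairs)  # load_demonstration stops after this many pairs
# load_demonstration parses a BrainParametersProto internally but does not
# return it, so `brain_params` is assumed to be obtained elsewhere:
# write_demo("copy.demo", meta, brain_params, pairs)
```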
`mlagents/trainers/demonstrations/demonstration_provider.py` (new file, +72 lines):

```python
import abc

import numpy as np

from typing import List, NamedTuple

from mlagents_envs.base_env import ActionTuple, BehaviorSpec

from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.trajectory import ObsUtil


class DemonstrationExperience(NamedTuple):
    obs: List[np.ndarray]
    reward: float
    done: bool
    action: ActionTuple
    prev_action: np.ndarray
    interrupted: bool


class DemonstrationTrajectory(NamedTuple):
    experiences: List[DemonstrationExperience]

    def to_agentbuffer(self) -> AgentBuffer:
        """
        Converts the trajectory to an AgentBuffer. Note that the length of the
        AgentBuffer will be one less than the trajectory, as the next
        observation needs to be populated from the last step of the trajectory.
        """
        agent_buffer_trajectory = AgentBuffer()
        for exp in self.experiences:
            for i, obs in enumerate(exp.obs):
                agent_buffer_trajectory[ObsUtil.get_name_at(i)].append(obs)

            # TODO Not in demo_loader
            agent_buffer_trajectory[BufferKey.MASKS].append(1.0)
            agent_buffer_trajectory[BufferKey.DONE].append(exp.done)

            agent_buffer_trajectory[BufferKey.CONTINUOUS_ACTION].append(
                exp.action.continuous
            )
            agent_buffer_trajectory[BufferKey.DISCRETE_ACTION].append(
                exp.action.discrete
            )

            agent_buffer_trajectory[BufferKey.PREV_ACTION].append(exp.prev_action)
            agent_buffer_trajectory[BufferKey.ENVIRONMENT_REWARDS].append(exp.reward)

        return agent_buffer_trajectory


class DemonstrationProvider(abc.ABC):
    @abc.abstractmethod
    def get_behavior_spec(self) -> BehaviorSpec:
        pass

    @abc.abstractmethod
    def pop_trajectories(self) -> List[DemonstrationTrajectory]:
        pass

    def to_agentbuffer(self, training_length: int) -> AgentBuffer:
        buffer_out = AgentBuffer()
        trajectories = self.pop_trajectories()
        for trajectory in trajectories:
            temp_buffer = trajectory.to_agentbuffer()
            temp_buffer.resequence_and_append(
                buffer_out, batch_size=None, training_length=training_length
            )
        return buffer_out
```

Review comments on `class DemonstrationExperience`:

> These are trimmed down versions of the AgentExperience and Trajectory classes, based on what's currently in demo_loader.

> Hmm, I feel like we shouldn't duplicate the conversion code here between AgentExperience and Trajectory (esp. with the teammate observations coming in, it becomes quite a fat function - and at some point I imagine we'll have teammate demonstrations as well). Wonder if we can have a BaseAgentExperience be the base class that is used here and in the AgentProcessor, and have the AgentExperience (PolicyAgentExperience?) inherit from it? Or some other way of composing these two. (See the sketch below.)
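To make that suggestion concrete, here is one possible shape, as a non-authoritative sketch. `BaseAgentExperience` and `PolicyAgentExperience` are the names floated in the comment, the policy-side extras are assumptions, and dataclasses are used because `typing.NamedTuple` cannot be subclassed with additional fields:

```python
from dataclasses import dataclass
from typing import List, Optional

import numpy as np

from mlagents_envs.base_env import ActionTuple


@dataclass
class BaseAgentExperience:
    # Fields shared by demonstration and policy experiences (taken from the
    # DemonstrationExperience definition above).
    obs: List[np.ndarray]
    reward: float
    done: bool
    action: ActionTuple
    prev_action: np.ndarray
    interrupted: bool


@dataclass
class PolicyAgentExperience(BaseAgentExperience):
    # Hypothetical policy-only extras; the real AgentExperience carries
    # additional fields such as log-probs and memories.
    action_probs: Optional[np.ndarray] = None
    memory: Optional[np.ndarray] = None
```

The trajectory-to-buffer conversion could then live in one place, keyed off `BaseAgentExperience`, instead of being duplicated between demo_loader and the AgentProcessor path.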
New file, +146 lines (the capture does not show its path; the class name suggests `mlagents/trainers/demonstrations/local_demonstration_provider.py`):

```python
import os
from typing import List
import numpy as np


from mlagents_envs.base_env import ActionTuple, BehaviorSpec, ActionSpec
from mlagents_envs.communicator_objects.agent_info_action_pair_pb2 import (
    AgentInfoActionPairProto,
)
from mlagents_envs.rpc_utils import steps_from_proto


from mlagents.trainers.demonstrations.demonstration_provider import (
    DemonstrationProvider,
    DemonstrationExperience,
    DemonstrationTrajectory,
)
from mlagents.trainers.demonstrations.demonstration_proto_utils import (
    load_demonstration,
)


class LocalDemonstrationProvider(DemonstrationProvider):
    def __init__(self, file_path: str):
        super().__init__()

        demo_paths = self._get_demo_files(file_path)
        behavior_spec, info_action_pairs = load_demonstration(demo_paths)
        self._behavior_spec = behavior_spec
        self._info_action_pairs = info_action_pairs

    def get_behavior_spec(self) -> BehaviorSpec:
        return self._behavior_spec

    def pop_trajectories(self) -> List[DemonstrationTrajectory]:
        trajectories = LocalDemonstrationProvider._info_action_pairs_to_trajectories(
            self._behavior_spec, self._info_action_pairs
        )
        self._info_action_pairs = []
        return trajectories

    @staticmethod
    def _get_demo_files(path: str) -> List[str]:
        """
        Retrieves the demonstration file(s) from a path.
        :param path: Path of demonstration file or directory.
        :return: List of demonstration files.

        Raises errors if |path| is invalid.
        """
        if os.path.isfile(path):
            if not path.endswith(".demo"):
                raise ValueError("The path provided is not a '.demo' file.")
            return [path]
        elif os.path.isdir(path):
            paths = [
                os.path.join(path, name)
                for name in os.listdir(path)
                if name.endswith(".demo")
            ]
            if not paths:
                raise ValueError(
                    "There are no '.demo' files in the provided directory."
                )
            return paths
        else:
            raise FileNotFoundError(
                f"The demonstration file or directory {path} does not exist."
            )

    @staticmethod
    def _info_action_pairs_to_trajectories(
        behavior_spec: BehaviorSpec, info_action_pairs: List[AgentInfoActionPairProto]
    ) -> List[DemonstrationTrajectory]:
        trajectories_out: List[DemonstrationTrajectory] = []
        current_experiences = []
        previous_action = np.zeros(
            behavior_spec.action_spec.continuous_size, dtype=np.float32
        )  # TODO or discrete?
        for pair_index, pair in enumerate(info_action_pairs):

            # Extract the observations from the decision/terminal steps
            current_decision_step, current_terminal_step = steps_from_proto(
                [pair.agent_info], behavior_spec
            )
            if len(current_terminal_step) == 1:
                obs = list(current_terminal_step.values())[0].obs
            else:
                obs = list(current_decision_step.values())[0].obs

            action_tuple = LocalDemonstrationProvider._get_action_tuple(
                pair, behavior_spec.action_spec
            )

            exp = DemonstrationExperience(
                obs=obs,
                reward=pair.agent_info.reward,  # TODO next step's reward?
                done=pair.agent_info.done,
                action=action_tuple,
                prev_action=previous_action,
                interrupted=pair.agent_info.max_step_reached,
            )
            current_experiences.append(exp)
            previous_action = np.array(
                pair.action_info.vector_actions_deprecated, dtype=np.float32
            )
            if pair.agent_info.done or pair_index == len(info_action_pairs) - 1:
                trajectories_out.append(
                    DemonstrationTrajectory(experiences=current_experiences)
                )
                current_experiences = []

        return trajectories_out

    @staticmethod
    def _get_action_tuple(
        pair: AgentInfoActionPairProto, action_spec: ActionSpec
    ) -> ActionTuple:
        continuous_actions = None
        discrete_actions = None

        if (
            len(pair.action_info.continuous_actions) == 0
            and len(pair.action_info.discrete_actions) == 0
        ):
            # Older demos only populate the deprecated combined action field;
            # route it to continuous or discrete based on the action spec.
            if action_spec.continuous_size > 0:
                continuous_actions = pair.action_info.vector_actions_deprecated
            else:
                discrete_actions = pair.action_info.vector_actions_deprecated
        else:
            if action_spec.continuous_size > 0:
                continuous_actions = pair.action_info.continuous_actions
            if action_spec.discrete_size > 0:
                discrete_actions = pair.action_info.discrete_actions

        # TODO 2D?
        continuous_np = (
            np.array(continuous_actions, dtype=np.float32)
            if continuous_actions
            else None
        )
        discrete_np = (
            np.array(discrete_actions, dtype=np.float32) if discrete_actions else None
        )

        return ActionTuple(continuous_np, discrete_np)
```

Review comment on `pop_trajectories`:

> Need to add docstrings here. But the idea is that GAIL, etc. could be converted to use pop_trajectories() directly. Then, if we want DemonstrationProviders to be able to load new demonstrations on the fly, the logic can be kept in the DemonstrationProvider and the consumer doesn't need to know about it; it just gets a fresh batch of trajectories. (A consumer sketch follows at the end.)

Review comment on `_get_demo_files`:

> from demo_loader.get_demo_files
Review comment:

> was demo_loader.load_demonstration
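Building on the pop_trajectories() idea from the review, a minimal consumer sketch. The import path and the "demos/" directory are hypothetical placeholders; only `LocalDemonstrationProvider` and its methods come from the diff above:

```python
from mlagents.trainers.demonstrations.local_demonstration_provider import (
    LocalDemonstrationProvider,
)

# "demos/" is a placeholder directory of .demo files.
provider = LocalDemonstrationProvider("demos/")

spec = provider.get_behavior_spec()
print(f"Continuous action size: {spec.action_spec.continuous_size}")

# A GAIL-style consumer just asks for a fresh batch of trajectories; any
# on-the-fly reloading stays inside the provider.
for trajectory in provider.pop_trajectories():
    demo_buffer = trajectory.to_agentbuffer()
    # hand demo_buffer to the reward signal / BC update here

# Or flatten everything into a single buffer for training:
full_buffer = provider.to_agentbuffer(training_length=1)
```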