Commit aadaca9

Ervin Teng committed
Fix discrete actions and GridWorld
1 parent 3442de5 commit aadaca9

File tree: 2 files changed (+19 -11 lines)


ml-agents/mlagents/trainers/models_torch.py

Lines changed: 15 additions & 10 deletions
@@ -9,6 +9,7 @@
     MultiCategoricalDistribution,
 )
 from mlagents.trainers.exception import UnityTrainerException
+from mlagents.trainers.models import EncoderType
 
 ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
 EncoderFunction = Callable[
@@ -18,12 +19,6 @@
 EPSILON = 1e-7
 
 
-class EncoderType(Enum):
-    SIMPLE = "simple"
-    NATURE_CNN = "nature_cnn"
-    RESNET = "resnet"
-
-
 class ActionType(Enum):
     DISCRETE = "discrete"
     CONTINUOUS = "continuous"
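
Note: this hunk deduplicates EncoderType rather than dropping it; the definition now comes from mlagents.trainers.models through the import added in the first hunk. A minimal sketch of using the shared enum after this change, assuming only the member values shown in the removed lines:

    from mlagents.trainers.models import EncoderType

    # Members mirror the removed definition: SIMPLE, NATURE_CNN, RESNET.
    encoder_type = EncoderType("nature_cnn")
    assert encoder_type is EncoderType.NATURE_CNN
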
@@ -113,7 +108,7 @@ def forward(self, vec_inputs, vis_inputs, memories=None, sequence_length=1):
             hidden = encoder(vis_input)
             vis_embeds.append(hidden)
 
-        #embedding = vec_embeds[0]
+        # embedding = vec_embeds[0]
         if len(vec_embeds) > 0:
             vec_embeds = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
         if len(vis_embeds) > 0:
@@ -254,7 +249,14 @@ def forward(
             vec_inputs, vis_inputs, masks, memories, sequence_length
         )
         sampled_actions = self.sample_action(dists)
-        return sampled_actions, dists[0].pdf(sampled_actions), self.version_number, self.memory_size, self.is_continuous_int, self.act_size_vector
+        return (
+            sampled_actions,
+            dists[0].pdf(sampled_actions),
+            self.version_number,
+            self.memory_size,
+            self.is_continuous_int,
+            self.act_size_vector,
+        )
 
 
 class Critic(nn.Module):
@@ -444,7 +446,9 @@ def __init__(self, height, width, initial_channels, final_hidden):
         self.layers = []
         last_channel = initial_channels
         for _, channel in enumerate(n_channels):
-            self.layers.append(nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1))
+            self.layers.append(
+                nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1)
+            )
             self.layers.append(nn.MaxPool2d([3, 3], [2, 2]))
             height, width = pool_out_shape((height, width), 3)
         for _ in range(n_blocks):
@@ -473,7 +477,7 @@ def forward_block(input_hidden, block_layers):
     def forward(self, visual_obs):
         batch_size = visual_obs.shape[0]
         hidden = visual_obs
-        for idx, layer in enumerate(self.layers):
+        for layer in self.layers:
             if isinstance(layer, nn.Module):
                 hidden = layer(hidden)
             elif isinstance(layer, list):
@@ -503,6 +507,7 @@ def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
             EncoderType.NATURE_CNN: NatureVisualEncoder,
             EncoderType.RESNET: ResNetVisualEncoder,
         }
+        print(encoder_type, ENCODER_FUNCTION_BY_TYPE.get(encoder_type))
         return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)
 
     @staticmethod
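
For context, the lookup above dispatches on EncoderType through a plain dict, and the added print traces which encoder class is selected. A self-contained sketch of the same pattern with stand-in classes; SimpleVisualEncoder and its mapping are assumptions, since the SIMPLE entry falls outside this hunk:

    from enum import Enum

    class EncoderType(Enum):
        SIMPLE = "simple"
        NATURE_CNN = "nature_cnn"
        RESNET = "resnet"

    # Stand-ins for the real encoder modules defined in models_torch.py.
    class SimpleVisualEncoder: pass
    class NatureVisualEncoder: pass
    class ResNetVisualEncoder: pass

    ENCODER_FUNCTION_BY_TYPE = {
        EncoderType.SIMPLE: SimpleVisualEncoder,
        EncoderType.NATURE_CNN: NatureVisualEncoder,
        EncoderType.RESNET: ResNetVisualEncoder,
    }

    # dict.get returns None for unmapped types, matching the code above.
    assert ENCODER_FUNCTION_BY_TYPE.get(EncoderType.RESNET) is ResNetVisualEncoder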

ml-agents/mlagents/trainers/policy/torch_policy.py

Lines changed: 4 additions & 1 deletion
@@ -150,7 +150,10 @@ def sample_actions(self, vec_obs, vis_obs, masks=None, memories=None, seq_len=1)
 
         actions = self.actor_critic.sample_action(dists)
         log_probs, entropies = self.actor_critic.get_probs_and_entropy(actions, dists)
-        actions = torch.squeeze(actions)
+        if self.use_continuous_act:
+            actions = actions[:, :, 0]
+        else:
+            actions = actions[:, 0, :]
 
         return actions, log_probs, entropies, value_heads, memories
 
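
This hunk is the "fix discrete actions" half of the commit message: torch.squeeze removes every size-1 dimension, so with a single agent (as in GridWorld inference) it also collapsed the batch axis, while the explicit slicing keeps it. A minimal repro of the difference, assuming a (batch, 1, num_branches) layout for discrete actions as implied by the indexing above:

    import torch

    # One agent, two discrete action branches: shape (1, 1, 2).
    actions = torch.tensor([[[3, 1]]])

    print(torch.squeeze(actions).shape)  # torch.Size([2]) -- batch axis lost
    print(actions[:, 0, :].shape)        # torch.Size([1, 2]) -- batch axis kept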
