|
9 | 9 | MultiCategoricalDistribution,
|
10 | 10 | )
|
11 | 11 | from mlagents.trainers.exception import UnityTrainerException
|
| 12 | +from mlagents.trainers.models import EncoderType |
12 | 13 |
|
13 | 14 | ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
|
14 | 15 | EncoderFunction = Callable[
|
|
18 | 19 | EPSILON = 1e-7
|
19 | 20 |
|
20 | 21 |
|
21 |
| -class EncoderType(Enum): |
22 |
| - SIMPLE = "simple" |
23 |
| - NATURE_CNN = "nature_cnn" |
24 |
| - RESNET = "resnet" |
25 |
| - |
26 |
| - |
27 | 22 | class ActionType(Enum):
|
28 | 23 | DISCRETE = "discrete"
|
29 | 24 | CONTINUOUS = "continuous"
|
@@ -113,7 +108,7 @@ def forward(self, vec_inputs, vis_inputs, memories=None, sequence_length=1):
|
113 | 108 | hidden = encoder(vis_input)
|
114 | 109 | vis_embeds.append(hidden)
|
115 | 110 |
|
116 |
| - #embedding = vec_embeds[0] |
| 111 | + # embedding = vec_embeds[0] |
117 | 112 | if len(vec_embeds) > 0:
|
118 | 113 | vec_embeds = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
|
119 | 114 | if len(vis_embeds) > 0:
|
@@ -254,7 +249,14 @@ def forward(
|
254 | 249 | vec_inputs, vis_inputs, masks, memories, sequence_length
|
255 | 250 | )
|
256 | 251 | sampled_actions = self.sample_action(dists)
|
257 |
| - return sampled_actions, dists[0].pdf(sampled_actions), self.version_number, self.memory_size, self.is_continuous_int, self.act_size_vector |
| 252 | + return ( |
| 253 | + sampled_actions, |
| 254 | + dists[0].pdf(sampled_actions), |
| 255 | + self.version_number, |
| 256 | + self.memory_size, |
| 257 | + self.is_continuous_int, |
| 258 | + self.act_size_vector, |
| 259 | + ) |
258 | 260 |
|
259 | 261 |
|
260 | 262 | class Critic(nn.Module):
|
@@ -444,7 +446,9 @@ def __init__(self, height, width, initial_channels, final_hidden):
|
444 | 446 | self.layers = []
|
445 | 447 | last_channel = initial_channels
|
446 | 448 | for _, channel in enumerate(n_channels):
|
447 |
| - self.layers.append(nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1)) |
| 449 | + self.layers.append( |
| 450 | + nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1) |
| 451 | + ) |
448 | 452 | self.layers.append(nn.MaxPool2d([3, 3], [2, 2]))
|
449 | 453 | height, width = pool_out_shape((height, width), 3)
|
450 | 454 | for _ in range(n_blocks):
|
@@ -473,7 +477,7 @@ def forward_block(input_hidden, block_layers):
|
473 | 477 | def forward(self, visual_obs):
|
474 | 478 | batch_size = visual_obs.shape[0]
|
475 | 479 | hidden = visual_obs
|
476 |
| - for idx, layer in enumerate(self.layers): |
| 480 | + for layer in self.layers: |
477 | 481 | if isinstance(layer, nn.Module):
|
478 | 482 | hidden = layer(hidden)
|
479 | 483 | elif isinstance(layer, list):
|
@@ -503,6 +507,7 @@ def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
|
503 | 507 | EncoderType.NATURE_CNN: NatureVisualEncoder,
|
504 | 508 | EncoderType.RESNET: ResNetVisualEncoder,
|
505 | 509 | }
|
| 510 | + print(encoder_type, ENCODER_FUNCTION_BY_TYPE.get(encoder_type)) |
506 | 511 | return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)
|
507 | 512 |
|
508 | 513 | @staticmethod
|
|
0 commit comments