@@ -7,6 +7,7 @@
 from mlagents.trainers.policy.torch_policy import TorchPolicy
 from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
 from mlagents.trainers.settings import TrainerSettings, PPOSettings
+from mlagents.trainers.models_torch import list_to_tensor
 
 
 class TorchPPOOptimizer(TorchOptimizer):
@@ -91,18 +92,18 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         returns = {}
         old_values = {}
         for name in self.reward_signals:
-            old_values[name] = torch.as_tensor(batch["{}_value_estimates".format(name)])
-            returns[name] = torch.as_tensor(batch["{}_returns".format(name)])
+            old_values[name] = list_to_tensor(batch["{}_value_estimates".format(name)])
+            returns[name] = list_to_tensor(batch["{}_returns".format(name)])
 
-        vec_obs = [torch.as_tensor(batch["vector_obs"])]
-        act_masks = torch.as_tensor(batch["action_mask"])
+        vec_obs = [list_to_tensor(batch["vector_obs"])]
+        act_masks = list_to_tensor(batch["action_mask"])
         if self.policy.use_continuous_act:
-            actions = torch.as_tensor(batch["actions"]).unsqueeze(-1)
+            actions = list_to_tensor(batch["actions"]).unsqueeze(-1)
         else:
-            actions = torch.as_tensor(batch["actions"], dtype=torch.long)
+            actions = list_to_tensor(batch["actions"], dtype=torch.long)
 
         memories = [
-            torch.as_tensor(batch["memory"][i])
+            list_to_tensor(batch["memory"][i])
             for i in range(0, len(batch["memory"]), self.policy.sequence_length)
         ]
         if len(memories) > 0:
@@ -113,7 +114,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
             for idx, _ in enumerate(
                 self.policy.actor_critic.network_body.visual_encoders
             ):
-                vis_ob = torch.as_tensor(batch["visual_obs%d" % idx])
+                vis_ob = list_to_tensor(batch["visual_obs%d" % idx])
                 vis_obs.append(vis_ob)
         else:
             vis_obs = []
@@ -127,10 +128,10 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         )
         value_loss = self.ppo_value_loss(values, old_values, returns)
         policy_loss = self.ppo_policy_loss(
-            torch.as_tensor(batch["advantages"]),
+            list_to_tensor(batch["advantages"]),
             log_probs,
-            torch.as_tensor(batch["action_probs"]),
-            torch.as_tensor(batch["masks"], dtype=torch.int32),
+            list_to_tensor(batch["action_probs"]),
+            list_to_tensor(batch["masks"], dtype=torch.int32),
         )
         loss = (
             policy_loss
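
For context, the imported helper is presumably something along these lines (a minimal sketch; the exact definition in mlagents.trainers.models_torch may differ). The apparent motivation for the swap is that an AgentBuffer field is a Python list of numpy arrays, and calling torch.as_tensor on that list directly falls into a slow per-element conversion path; funneling the list through a single numpy array first is much faster.

    from typing import List, Optional

    import numpy as np
    import torch


    def list_to_tensor(
        ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = None
    ) -> torch.Tensor:
        # Sketch: collapse the list into one ndarray before handing it to torch,
        # avoiding torch's element-by-element handling of Python lists.
        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)

With a helper like this, every batch[...] field in the hunks above is converted through one numpy array instead of being iterated item by item, while keeping the same call sites (including the dtype arguments for the discrete actions and masks).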