@@ -57,6 +57,50 @@ def test_load_and_set(dummy_config, use_discrete):
57
57
np .testing .assert_array_equal (w , lw )
58
58
59
59
60
def test_resume(dummy_config, tmp_path):
    """Check that a second GhostTrainer built over the same directory matches the first.

    A PPO-backed GhostTrainer is created with policies for two teams of the
    same brain and ``save_model()`` is called. A second trainer pair is then
    constructed over the same output path, the same two policies are created
    and added, and the team-1 policy weights of both trainers are compared
    for exact equality.

    :param dummy_config: trainer configuration fixture.
    :param tmp_path: pytest temporary-directory fixture used as the model path.
    """
    mock_specs = mb.setup_test_behavior_specs(
        True, False, vector_action_space=[2], vector_obs_space=1
    )
    behavior_id_team0 = "test_brain?team=0"
    behavior_id_team1 = "test_brain?team=1"
    parsed_behavior_id0 = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team0)
    parsed_behavior_id1 = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team1)
    brain_name = parsed_behavior_id0.brain_name

    # Keep the fixture name intact; the trainers are handed a POSIX string path.
    model_dir = tmp_path.as_posix()

    controller = GhostController(100)
    ppo_trainer = PPOTrainer(brain_name, 0, dummy_config, True, False, 0, model_dir)
    trainer = GhostTrainer(
        ppo_trainer, brain_name, controller, 0, dummy_config, True, model_dir
    )

    # Register team 0 first, then team 1 (create -> add for each, in order).
    for parsed_behavior_id in (parsed_behavior_id0, parsed_behavior_id1):
        policy = trainer.create_policy(parsed_behavior_id, mock_specs)
        trainer.add_policy(parsed_behavior_id, policy)

    trainer.save_model()

    # Make a new trainer, check that the policies are the same.
    # NOTE(review): the only constructor difference is the fifth positional
    # argument (False -> True); presumably the load/resume flag — confirm
    # against the PPOTrainer signature.
    ppo_trainer2 = PPOTrainer(brain_name, 0, dummy_config, True, True, 0, model_dir)
    trainer2 = GhostTrainer(
        ppo_trainer2, brain_name, controller, 0, dummy_config, True, model_dir
    )
    for parsed_behavior_id in (parsed_behavior_id0, parsed_behavior_id1):
        policy = trainer2.create_policy(parsed_behavior_id, mock_specs)
        trainer2.add_policy(parsed_behavior_id, policy)

    trainer1_policy = trainer.get_policy(parsed_behavior_id1.behavior_id)
    trainer2_policy = trainer2.get_policy(parsed_behavior_id1.behavior_id)
    weights = trainer1_policy.get_weights()
    weights2 = trainer2_policy.get_weights()

    # zip() stops at the shorter input, so a resumed policy with missing
    # weight arrays would otherwise go unnoticed by the loop below.
    assert len(weights) == len(weights2)

    for w, lw in zip(weights, weights2):
        np.testing.assert_array_equal(w, lw)
102
+
103
+
60
104
def test_process_trajectory (dummy_config ):
61
105
mock_specs = mb .setup_test_behavior_specs (
62
106
True , False , vector_action_space = [2 ], vector_obs_space = 1
0 commit comments