|
| 1 | + |
| 2 | +import json |
| 3 | +import os |
| 4 | +import torch |
| 5 | +import tensorflow as tf |
| 6 | +import argparse |
| 7 | +from mlagents.trainers.learn import run_cli, parse_command_line |
| 8 | +from mlagents.trainers.settings import RunOptions |
| 9 | +from mlagents.trainers.stats import StatsReporter |
| 10 | +from mlagents.trainers.ppo.trainer import TestingConfiguration |
| 11 | +from mlagents_envs.timers import _thread_timer_stacks |
| 12 | + |
| 13 | + |
| 14 | + |
| 15 | + |
def run_experiment(
    name: str,
    steps: int,
    use_torch: bool,
    num_torch_threads: int,
    use_gpu: bool,
    num_envs: int = 1,
    config_name=None,
):
    """Run a single ML-Agents PPO training run and collect timing statistics.

    Configures the global TestingConfiguration, launches training via
    ``run_cli``, then parses the emitted ``timers.json`` for the run.

    :param name: Environment name (also used to build the run id).
    :param steps: Maximum number of training steps.
    :param use_torch: If True, train with the torch backend, else TensorFlow.
    :param num_torch_threads: Value passed to ``torch.set_num_threads`` (torch only).
    :param use_gpu: If True, request CUDA; if CUDA is unavailable the run is
        skipped and placeholder "na" columns are returned.
    :param num_envs: Number of parallel environments.
    :param config_name: YAML config basename under ``config/ppo``; defaults to ``name``.
    :return: A 13-tuple of strings:
        (name, steps, use_torch, num_torch_threads, num_envs, use_gpu,
         total, tc_advance_total, tc_advance_count,
         update_total, update_count, evaluate_total, evaluate_count).
    """
    TestingConfiguration.env_name = name
    TestingConfiguration.max_steps = steps
    TestingConfiguration.use_torch = use_torch
    TestingConfiguration.device = "cuda:0" if use_gpu else "cpu"
    # NOTE: the original code called `tf.device("/GPU:0")` / `tf.device("/device:CPU:0")`
    # as bare statements. `tf.device` returns a context manager, so those calls had
    # no effect; TF/torch device visibility is actually controlled by the
    # CUDA_VISIBLE_DEVICES environment variable set in main(). The dead calls
    # were removed.
    if use_gpu and not torch.cuda.is_available():
        # GPU requested but not present: emit placeholder timing columns so the
        # output table keeps a consistent 13-column shape.
        return (
            name,
            str(steps),
            str(use_torch),
            str(num_torch_threads),
            str(num_envs),
            str(use_gpu),
        ) + ("na",) * 7
    if config_name is None:
        config_name = name
    run_options = parse_command_line(
        [f"config/ppo/{config_name}.yaml", "--num-envs", f"{num_envs}"]
    )
    backend = "torch" if use_torch else "tf"
    run_options.checkpoint_settings.run_id = f"{name}_test_{steps}_{backend}"
    run_options.checkpoint_settings.force = True
    # Threaded trainers would skew per-section timings; force synchronous mode.
    for trainer_settings in run_options.behaviors.values():
        trainer_settings.threaded = False
    timers_path = os.path.join(
        "results", run_options.checkpoint_settings.run_id, "run_logs", "timers.json"
    )
    if use_torch:
        torch.set_num_threads(num_torch_threads)
    run_cli(run_options)
    # Reset process-global state so consecutive runs in this process don't
    # accumulate stats/timers from earlier runs.
    StatsReporter.writers.clear()
    StatsReporter.stats_dict.clear()
    _thread_timer_stacks.clear()
    with open(timers_path) as timers_json_file:
        timers_json = json.load(timers_json_file)
    total = timers_json["total"]
    # Navigate the timer tree once; the original repeated this long chain.
    tc_advance = timers_json["children"]["TrainerController.start_learning"][
        "children"
    ]["TrainerController.advance"]
    evaluate = tc_advance["children"]["env_step"]["children"][
        "SubprocessEnvManager._take_step"
    ]["children"]
    update = tc_advance["children"]["trainer_advance"]["children"]["_update_policy"][
        "children"
    ]
    if use_torch:
        update_key, evaluate_key = "TorchPPOOptimizer.update", "TorchPolicy.evaluate"
    else:
        update_key, evaluate_key = "TFPPOOptimizer.update", "NNPolicy.evaluate"
    # todo: do total / count
    return (
        name,
        str(steps),
        str(use_torch),
        str(num_torch_threads),
        str(num_envs),
        str(use_gpu),
        str(total),
        str(tc_advance["total"]),
        str(tc_advance["count"]),
        str(update[update_key]["total"]),
        str(update[update_key]["count"]),
        str(evaluate[evaluate_key]["total"]),
        str(evaluate[evaluate_key]["count"]),
    )
| 62 | + |
| 63 | + |
def main():
    """Benchmark torch vs TensorFlow trainers across several environments.

    Parses CLI flags, runs each (env, config) pair with the torch backend
    (1 thread, optionally also 8 threads) and the TF backend, and writes one
    space-separated row per run to a results text file while also printing
    the full table to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", default=25000, type=int, help="The number of steps")
    parser.add_argument("--num-envs", default=1, type=int, help="The number of envs")
    parser.add_argument(
        "--gpu", default=False, action="store_true", help="If true, will use the GPU"
    )
    parser.add_argument(
        "--threads",
        default=False,
        action="store_true",
        help="If true, will try both 1 and 8 threads for torch",
    )
    parser.add_argument(
        "--ball", default=False, action="store_true", help="If true, will only do 3dball"
    )
    args = parser.parse_args()

    # Expose GPU 0 (or hide all GPUs) before TF/torch initialize CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0" if args.gpu else "-1"

    envs_config_tuples = [
        ("3DBall", "3DBall"),
        ("GridWorld", "GridWorld"),
        ("PushBlock", "PushBlock"),
        ("Hallway", "Hallway"),
        ("CrawlerStaticTarget", "CrawlerStatic"),
        ("VisualHallway", "VisualHallway"),
    ]
    if args.ball:
        envs_config_tuples = [("3DBall", "3DBall")]

    labels = (
        "name",
        "steps",
        "use_torch",
        "num_torch_threads",
        "num_envs",
        "use_gpu",
        "total",
        "tc_advance_total",
        "tc_advance_count",
        "update_total",
        "update_count",
        "evaluate_total",
        "evaluate_count",
    )

    results = [labels]
    out_path = (
        f"result_data_steps_{args.steps}_envs_{args.num_envs}"
        f"_gpu_{args.gpu}_thread_{args.threads}.txt"
    )
    # `with` guarantees the file is closed (and buffered rows flushed) even if
    # a training run raises; the original leaked the handle on error.
    with open(out_path, "w") as f:
        f.write(" ".join(labels) + "\n")

        for env_name, config_name in envs_config_tuples:
            # Torch backend: always 1 thread, plus 8 threads when requested.
            torch_thread_counts = [1, 8] if args.threads else [1]
            for n_threads in torch_thread_counts:
                data = run_experiment(
                    name=env_name,
                    steps=args.steps,
                    use_torch=True,
                    num_torch_threads=n_threads,
                    use_gpu=args.gpu,
                    num_envs=args.num_envs,
                    config_name=config_name,
                )
                results.append(data)
                f.write(" ".join(data) + "\n")

            # TensorFlow backend (thread count is unused there).
            data = run_experiment(
                name=env_name,
                steps=args.steps,
                use_torch=False,
                num_torch_threads=1,
                use_gpu=args.gpu,
                num_envs=args.num_envs,
                config_name=config_name,
            )
            results.append(data)
            f.write(" ".join(data) + "\n")

    for r in results:
        print(*r)
| 107 | + |
| 108 | + |
# Script entry point: run the benchmark when executed directly.
if __name__ == "__main__":
    main()
| 111 | + |
0 commit comments