43 changes: 35 additions & 8 deletions experiment_torch.py
@@ -14,6 +14,7 @@ def run_experiment(
name: str,
steps: int,
use_torch: bool,
algo: str,
num_torch_threads: int,
use_gpu: bool,
num_envs: int = 1,
@@ -32,6 +33,7 @@ def run_experiment(
name,
str(steps),
str(use_torch),
algo,
str(num_torch_threads),
str(num_envs),
str(use_gpu),
@@ -46,7 +48,7 @@ def run_experiment(
if config_name is None:
config_name = name
run_options = parse_command_line(
[f"config/ppo/{config_name}.yaml", "--num-envs", f"{num_envs}"]
[f"config/{algo}/{config_name}.yaml", "--num-envs", f"{num_envs}"]
)
run_options.checkpoint_settings.run_id = (
f"{name}_test_" + str(steps) + "_" + ("torch" if use_torch else "tf")
@@ -87,20 +89,29 @@ def run_experiment(
tc_advance_total = tc_advance["total"]
tc_advance_count = tc_advance["count"]
if use_torch:
update_total = update["TorchPPOOptimizer.update"]["total"]
if algo == "ppo":
update_total = update["TorchPPOOptimizer.update"]["total"]
update_count = update["TorchPPOOptimizer.update"]["count"]
else:
update_total = update["SACTrainer._update_policy"]["total"]
update_count = update["SACTrainer._update_policy"]["count"]
evaluate_total = evaluate["TorchPolicy.evaluate"]["total"]
update_count = update["TorchPPOOptimizer.update"]["count"]
evaluate_count = evaluate["TorchPolicy.evaluate"]["count"]
else:
update_total = update["TFPPOOptimizer.update"]["total"]
if algo == "ppo":
update_total = update["TFPPOOptimizer.update"]["total"]
update_count = update["TFPPOOptimizer.update"]["count"]
else:
update_total = update["SACTrainer._update_policy"]["total"]
update_count = update["SACTrainer._update_policy"]["count"]
evaluate_total = evaluate["NNPolicy.evaluate"]["total"]
update_count = update["TFPPOOptimizer.update"]["count"]
evaluate_count = evaluate["NNPolicy.evaluate"]["count"]
# todo: do total / count
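# a minimal sketch of what the todo above could compute (names hypothetical,
# not part of this change):
#   update_mean = update_total / max(update_count, 1)
#   evaluate_mean = evaluate_total / max(evaluate_count, 1)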
return (
name,
str(steps),
str(use_torch),
algo,
str(num_torch_threads),
str(num_envs),
str(use_gpu),
@@ -133,28 +144,41 @@ def main():
action="store_true",
help="If true, will only do 3dball",
)
parser.add_argument(
"--sac",
default=False,
action="store_true",
help="If true, will run sac instead of ppo",
)
args = parser.parse_args()

if args.gpu:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
else:
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

algo = "ppo"
if args.sac:
algo = "sac"

envs_config_tuples = [
("3DBall", "3DBall"),
("GridWorld", "GridWorld"),
("PushBlock", "PushBlock"),
("Hallway", "Hallway"),
("CrawlerStaticTarget", "CrawlerStatic"),
("VisualHallway", "VisualHallway"),
]
if algo == "ppo":
envs_config_tuples += [("Hallway", "Hallway"),
("VisualHallway", "VisualHallway")]
if args.ball:
envs_config_tuples = [("3DBall", "3DBall")]


labels = (
"name",
"steps",
"use_torch",
"algorithm",
"num_torch_threads",
"num_envs",
"use_gpu",
@@ -170,7 +194,7 @@ def main():
results = []
results.append(labels)
f = open(
f"result_data_steps_{args.steps}_envs_{args.num_envs}_gpu_{args.gpu}_thread_{args.threads}.txt",
f"result_data_steps_{args.steps}_algo_{algo}_envs_{args.num_envs}_gpu_{args.gpu}_thread_{args.threads}.txt",
"w",
)
f.write(" ".join(labels) + "\n")
@@ -180,6 +204,7 @@ def main():
name=env_config[0],
steps=args.steps,
use_torch=True,
algo=algo,
num_torch_threads=1,
use_gpu=args.gpu,
num_envs=args.num_envs,
@@ -193,6 +218,7 @@ def main():
name=env_config[0],
steps=args.steps,
use_torch=True,
algo=algo,
num_torch_threads=8,
use_gpu=args.gpu,
num_envs=args.num_envs,
@@ -205,6 +231,7 @@ def main():
name=env_config[0],
steps=args.steps,
use_torch=False,
algo=algo,
num_torch_threads=1,
use_gpu=args.gpu,
num_envs=args.num_envs,
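For reference, a typical invocation of the updated benchmark script might look like the line below. Only --sac is introduced in this diff; the other flag spellings (--gpu, --steps, --num-envs) are assumptions inferred from the args.* attributes used above, and the step count is illustrative.

python experiment_torch.py --sac --gpu --steps 25000 --num-envs 1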
50 changes: 43 additions & 7 deletions ml-agents/mlagents/trainers/distributions_torch.py
@@ -13,13 +13,14 @@ def __init__(self, mean, std):
self.std = std

def sample(self):
- return self.mean + torch.randn_like(self.mean) * self.std
+ sample = self.mean + torch.randn_like(self.mean) * self.std
+ return sample

def log_prob(self, value):
var = self.std ** 2
- log_scale = self.std.log()
+ log_scale = torch.log(self.std + EPSILON)
return (
- -((value - self.mean) ** 2) / (2 * var)
+ -((value - self.mean) ** 2) / (2 * var + EPSILON)
- log_scale
- math.log(math.sqrt(2 * math.pi))
)
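For context, the return value above is the standard Gaussian log-density, log N(x; mu, sigma) = -(x - mu)^2 / (2 * sigma^2) - log(sigma) - 0.5 * log(2 * pi); the EPSILON terms added in this change only guard against log(0) and division by zero when sigma underflows to zero.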
@@ -29,7 +30,28 @@ def pdf(self, value):
return torch.exp(log_prob)

def entropy(self):
- return torch.log(2 * math.pi * math.e * self.std)
+ return torch.log(2 * math.pi * math.e * self.std + EPSILON)


class TanhGaussianDistInstance(GaussianDistInstance):
def __init__(self, mean, std):
super().__init__(mean, std)
self.transform = torch.distributions.transforms.TanhTransform(cache_size=1)

def sample(self):
unsquashed_sample = super().sample()
squashed = self.transform(unsquashed_sample)
return squashed

def _inverse_tanh(self, value):
capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
return 0.5 * torch.log((1 + capped_value) / (1 - capped_value) + EPSILON)

def log_prob(self, value):
unsquashed = self.transform.inv(value)
return super().log_prob(unsquashed) - self.transform.log_abs_det_jacobian(
unsquashed, value
)


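A note on the tanh-squashed distribution above: if x ~ N(mu, sigma) and y = tanh(x), the change-of-variables formula gives log p(y) = log p(x) - log(1 - tanh(x)^2). That correction is what log_prob subtracts via TanhTransform.log_abs_det_jacobian, and _inverse_tanh is the (clamped) identity atanh(y) = 0.5 * log((1 + y) / (1 - y)). A minimal sanity-check sketch, assuming a PyTorch version that ships torch.distributions.transforms.TanhTransform and that the classes in this file are importable:

import torch
mean, std = torch.zeros(1), torch.ones(1)
dist = TanhGaussianDistInstance(mean, std)
y = dist.sample()
x = dist.transform.inv(y)  # un-squash; cached from sample() when cache_size=1
# the class log_prob and the manual change-of-variables result should agree
# up to the EPSILON guards
manual = GaussianDistInstance(mean, std).log_prob(x) - torch.log(1 - y ** 2)
print(dist.log_prob(y), manual)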
class CategoricalDistInstance(nn.Module):
@@ -47,15 +69,26 @@ def pdf(self, value):
def log_prob(self, value):
return torch.log(self.pdf(value))

def all_log_prob(self):
return torch.log(self.probs)

def entropy(self):
return torch.sum(self.probs * torch.log(self.probs), dim=-1)


class GaussianDistribution(nn.Module):
- def __init__(self, hidden_size, num_outputs, conditional_sigma=False, **kwargs):
+ def __init__(
+ self,
+ hidden_size,
+ num_outputs,
+ conditional_sigma=False,
+ tanh_squash=False,
+ **kwargs
+ ):
super(GaussianDistribution, self).__init__(**kwargs)
self.conditional_sigma = conditional_sigma
self.mu = nn.Linear(hidden_size, num_outputs)
self.tanh_squash = tanh_squash
nn.init.xavier_uniform_(self.mu.weight, gain=0.01)
if conditional_sigma:
self.log_sigma = nn.Linear(hidden_size, num_outputs)
@@ -68,10 +101,13 @@ def __init__(self, hidden_size, num_outputs, conditional_sigma=False, **kwargs):
def forward(self, inputs):
mu = self.mu(inputs)
if self.conditional_sigma:
- log_sigma = self.log_sigma(inputs)
+ log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
Contributor (review comment): Should these be consts?

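Following up on that question, one way the clamp bounds could be lifted into constants, shown only as a sketch (the constant names are hypothetical and not part of this PR):

LOG_SIGMA_MIN = -20
LOG_SIGMA_MAX = 2
# then, inside forward():
log_sigma = torch.clamp(self.log_sigma(inputs), min=LOG_SIGMA_MIN, max=LOG_SIGMA_MAX)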
else:
log_sigma = self.log_sigma
- return [GaussianDistInstance(mu, torch.exp(log_sigma))]
+ if self.tanh_squash:
+ return [TanhGaussianDistInstance(mu, torch.exp(log_sigma))]
+ else:
+ return [GaussianDistInstance(mu, torch.exp(log_sigma))]


class MultiCategoricalDistribution(nn.Module):
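Finally, a minimal usage sketch of the new tanh_squash path in GaussianDistribution; the sizes and inputs below are illustrative, and it assumes the non-conditional log_sigma parameter defined in the unexpanded part of __init__ behaves like the conditional branch shown above:

import torch
dist_module = GaussianDistribution(hidden_size=64, num_outputs=2, tanh_squash=True)
dists = dist_module(torch.zeros(1, 64))  # forward() returns a list of distribution instances
action = dists[0].sample()               # squashed into (-1, 1) by the tanh transform
log_prob = dists[0].log_prob(action)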