diff --git a/.gitignore b/.gitignore index cf8183463613..f018a111ea33 100644 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,6 @@ tags *.lock # DS_Store (MacOS) -.DS_Store \ No newline at end of file +.DS_Store +# RL pipelines may produce mp4 outputs +*.mp4 \ No newline at end of file diff --git a/examples/community/pipeline.py b/examples/community/pipeline.py new file mode 100644 index 000000000000..7e3f2b832b1f --- /dev/null +++ b/examples/community/pipeline.py @@ -0,0 +1,99 @@ +import torch + +import tqdm +from diffusers import DiffusionPipeline +from diffusers.models.unet_1d import UNet1DModel +from diffusers.utils.dummy_pt_objects import DDPMScheduler + + +class ValueGuidedDiffuserPipeline(DiffusionPipeline): + def __init__( + self, + value_function: UNet1DModel, + unet: UNet1DModel, + scheduler: DDPMScheduler, + env, + ): + super().__init__() + self.value_function = value_function + self.unet = unet + self.scheduler = scheduler + self.env = env + self.data = env.get_dataset() + self.means = dict() + for key in self.data.keys(): + try: + self.means[key] = self.data[key].mean() + except: + pass + self.stds = dict() + for key in self.data.keys(): + try: + self.stds[key] = self.data[key].std() + except: + pass + self.state_dim = env.observation_space.shape[0] + self.action_dim = env.action_space.shape[0] + + def normalize(self, x_in, key): + return (x_in - self.means[key]) / self.stds[key] + + def de_normalize(self, x_in, key): + return x_in * self.stds[key] + self.means[key] + + def to_torch(self, x_in): + if type(x_in) is dict: + return {k: self.to_torch(v) for k, v in x_in.items()} + elif torch.is_tensor(x_in): + return x_in.to(self.unet.device) + return torch.tensor(x_in, device=self.unet.device) + + def reset_x0(self, x_in, cond, act_dim): + for key, val in cond.items(): + x_in[:, key, act_dim:] = val.clone() + return x_in + + def run_diffusion(self, x, conditions, n_guide_steps, scale): + batch_size = x.shape[0] + y = None + for i in tqdm.tqdm(self.scheduler.timesteps): + # create batch of timesteps to pass into model + timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) + # 3. call the sample function + for _ in range(n_guide_steps): + with torch.enable_grad(): + x.requires_grad_() + y = self.value_function(x.permute(0, 2, 1), timesteps).sample + grad = torch.autograd.grad([y.sum()], [x])[0] + + posterior_variance = self.scheduler._get_variance(i) + model_std = torch.exp(0.5 * posterior_variance) + grad = model_std * grad + grad[timesteps < 2] = 0 + x = x.detach() + x = x + scale * grad + x = self.reset_x0(x, conditions, self.action_dim) + prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) + x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] + + # 4. 
apply conditions to the trajectory + x = self.reset_x0(x, conditions, self.action_dim) + x = self.to_torch(x) + return x, y + + def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): + obs = self.normalize(obs, "observations") + obs = obs[None].repeat(batch_size, axis=0) + conditions = {0: self.to_torch(obs)} + shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) + x1 = torch.randn(shape, device=self.unet.device) + x = self.reset_x0(x1, conditions, self.action_dim) + x = self.to_torch(x) + x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) + sorted_idx = y.argsort(0, descending=True).squeeze() + sorted_values = x[sorted_idx] + actions = sorted_values[:, :, : self.action_dim] + actions = actions.detach().cpu().numpy() + denorm_actions = self.de_normalize(actions, key="actions") + denorm_actions = denorm_actions[0, 0] + return denorm_actions diff --git a/examples/community/value_guided_diffuser.py b/examples/community/value_guided_diffuser.py new file mode 100644 index 000000000000..7e3f2b832b1f --- /dev/null +++ b/examples/community/value_guided_diffuser.py @@ -0,0 +1,99 @@ +import torch + +import tqdm +from diffusers import DiffusionPipeline +from diffusers.models.unet_1d import UNet1DModel +from diffusers.utils.dummy_pt_objects import DDPMScheduler + + +class ValueGuidedDiffuserPipeline(DiffusionPipeline): + def __init__( + self, + value_function: UNet1DModel, + unet: UNet1DModel, + scheduler: DDPMScheduler, + env, + ): + super().__init__() + self.value_function = value_function + self.unet = unet + self.scheduler = scheduler + self.env = env + self.data = env.get_dataset() + self.means = dict() + for key in self.data.keys(): + try: + self.means[key] = self.data[key].mean() + except: + pass + self.stds = dict() + for key in self.data.keys(): + try: + self.stds[key] = self.data[key].std() + except: + pass + self.state_dim = env.observation_space.shape[0] + self.action_dim = env.action_space.shape[0] + + def normalize(self, x_in, key): + return (x_in - self.means[key]) / self.stds[key] + + def de_normalize(self, x_in, key): + return x_in * self.stds[key] + self.means[key] + + def to_torch(self, x_in): + if type(x_in) is dict: + return {k: self.to_torch(v) for k, v in x_in.items()} + elif torch.is_tensor(x_in): + return x_in.to(self.unet.device) + return torch.tensor(x_in, device=self.unet.device) + + def reset_x0(self, x_in, cond, act_dim): + for key, val in cond.items(): + x_in[:, key, act_dim:] = val.clone() + return x_in + + def run_diffusion(self, x, conditions, n_guide_steps, scale): + batch_size = x.shape[0] + y = None + for i in tqdm.tqdm(self.scheduler.timesteps): + # create batch of timesteps to pass into model + timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) + # 3. call the sample function + for _ in range(n_guide_steps): + with torch.enable_grad(): + x.requires_grad_() + y = self.value_function(x.permute(0, 2, 1), timesteps).sample + grad = torch.autograd.grad([y.sum()], [x])[0] + + posterior_variance = self.scheduler._get_variance(i) + model_std = torch.exp(0.5 * posterior_variance) + grad = model_std * grad + grad[timesteps < 2] = 0 + x = x.detach() + x = x + scale * grad + x = self.reset_x0(x, conditions, self.action_dim) + prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) + x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] + + # 4. 
apply conditions to the trajectory + x = self.reset_x0(x, conditions, self.action_dim) + x = self.to_torch(x) + return x, y + + def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): + obs = self.normalize(obs, "observations") + obs = obs[None].repeat(batch_size, axis=0) + conditions = {0: self.to_torch(obs)} + shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) + x1 = torch.randn(shape, device=self.unet.device) + x = self.reset_x0(x1, conditions, self.action_dim) + x = self.to_torch(x) + x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) + sorted_idx = y.argsort(0, descending=True).squeeze() + sorted_values = x[sorted_idx] + actions = sorted_values[:, :, : self.action_dim] + actions = actions.detach().cpu().numpy() + denorm_actions = self.de_normalize(actions, key="actions") + denorm_actions = denorm_actions[0, 0] + return denorm_actions diff --git a/examples/diffuser/run_diffuser.py b/examples/diffuser/run_diffuser.py new file mode 100644 index 000000000000..b29d89992dfc --- /dev/null +++ b/examples/diffuser/run_diffuser.py @@ -0,0 +1,122 @@ +import numpy as np +import torch + +import d4rl # noqa +import gym +import tqdm +import train_diffuser +from diffusers import DDPMScheduler, UNet1DModel + + +env_name = "hopper-medium-expert-v2" +env = gym.make(env_name) +data = env.get_dataset() # dataset is only used for normalization in this colab + +DEVICE = "cpu" +DTYPE = torch.float + +# diffusion model settings +n_samples = 4 # number of trajectories planned via diffusion +horizon = 128 # length of sampled trajectories +state_dim = env.observation_space.shape[0] +action_dim = env.action_space.shape[0] +num_inference_steps = 100 # number of difusion steps + + +# Two generators for different parts of the diffusion loop to work in colab +generator_cpu = torch.Generator(device="cpu") + +scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2") + +# 3 different pretrained models are available for this task. +# The horizion represents the length of trajectories used in training. 
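+# A longer training horizon lets the planner look further ahead, at the cost of
+# denoising longer trajectories per step; only the horizon-128 checkpoint is loaded here.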
+network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE) +# network = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor256").to(device=DEVICE) +# network = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor512").to(device=DEVICE) + + +# network specific constants for inference +clip_denoised = network.clip_denoised +predict_epsilon = network.predict_epsilon + +# [ observation_dim ] --> [ n_samples x observation_dim ] +obs = env.reset() +total_reward = 0 +done = False +T = 300 +rollout = [obs.copy()] + +try: + for t in tqdm.tqdm(range(T)): + obs_raw = obs + + # normalize observations for forward passes + obs = train_diffuser.normalize(obs, data, "observations") + obs = obs[None].repeat(n_samples, axis=0) + conditions = {0: train_diffuser.to_torch(obs, device=DEVICE)} + + # constants for inference + batch_size = len(conditions[0]) + shape = (batch_size, horizon, state_dim + action_dim) + + # sample random initial noise vector + x1 = torch.randn(shape, device=DEVICE, generator=generator_cpu) + + # this model is conditioned from an initial state, so you will see this function + # multiple times to change the initial state of generated data to the state + # generated via env.reset() above or env.step() below + x = train_diffuser.reset_x0(x1, conditions, action_dim) + + # convert a np observation to torch for model forward pass + x = train_diffuser.to_torch(x) + + eta = 1.0 # noise factor for sampling reconstructed state + + # run the diffusion process + # for i in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): + for i in tqdm.tqdm(scheduler.timesteps): + # create batch of timesteps to pass into model + timesteps = torch.full((batch_size,), i, device=DEVICE, dtype=torch.long) + + # 1. generate prediction from model + with torch.no_grad(): + residual = network(x, timesteps).sample + + # 2. use the model prediction to reconstruct an observation (de-noise) + obs_reconstruct = scheduler.step(residual, i, x, predict_epsilon=predict_epsilon)["prev_sample"] + + # 3. [optional] add posterior noise to the sample + if eta > 0: + noise = torch.randn(obs_reconstruct.shape, generator=generator_cpu).to(obs_reconstruct.device) + posterior_variance = scheduler._get_variance(i) # * noise + # no noise when t == 0 + # NOTE: original implementation missing sqrt on posterior_variance + obs_reconstruct = ( + obs_reconstruct + int(i > 0) * (0.5 * posterior_variance) * eta * noise + ) # MJ had as log var, exponentiated + + # 4. 
apply conditions to the trajectory + obs_reconstruct_postcond = train_diffuser.reset_x0(obs_reconstruct, conditions, action_dim) + x = train_diffuser.to_torch(obs_reconstruct_postcond) + plans = train_diffuser.helpers.to_np(x[:, :, :action_dim]) + # select random plan + idx = np.random.randint(plans.shape[0]) + # select action at correct time + action = plans[idx, 0, :] + actions = train_diffuser.de_normalize(action, data, "actions") + # execute action in environment + next_observation, reward, terminal, _ = env.step(action) + + # update return + total_reward += reward + print(f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}") + + # save observations for rendering + rollout.append(next_observation.copy()) + obs = next_observation +except KeyboardInterrupt: + pass + +print(f"Total reward: {total_reward}") +render = train_diffuser.MuJoCoRenderer(env) +train_diffuser.show_sample(render, np.expand_dims(np.stack(rollout), axis=0)) diff --git a/examples/diffuser/run_diffuser_value_guided.py b/examples/diffuser/run_diffuser_value_guided.py new file mode 100644 index 000000000000..4272ec2c3106 --- /dev/null +++ b/examples/diffuser/run_diffuser_value_guided.py @@ -0,0 +1,94 @@ +import d4rl # noqa +import gym +import tqdm + +# import train_diffuser +from diffusers import DDPMScheduler, DiffusionPipeline, UNet1DModel + + +config = dict( + n_samples=64, + horizon=32, + num_inference_steps=20, + n_guide_steps=2, + scale_grad_by_std=True, + scale=0.1, + eta=0.0, + t_grad_cutoff=2, + device="cpu", +) + + +def _run(): + env_name = "hopper-medium-v2" + env = gym.make(env_name) + + # Cuda settings for colab + # torch.cuda.get_device_name(0) + DEVICE = config["device"] + + # Two generators for different parts of the diffusion loop to work in colab + scheduler = DDPMScheduler( + num_train_timesteps=config["num_inference_steps"], + beta_schedule="squaredcos_cap_v2", + clip_sample=False, + variance_type="fixed_small_log", + ) + + # 3 different pretrained models are available for this task. + # The horizion represents the length of trajectories used in training. + # network = ValueFunction(training_horizon=horizon, dim=32, dim_mults=(1, 2, 4, 8), transition_dim=14, cond_dim=11) + + network = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32").to(device=DEVICE).eval() + unet = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-unet-hor32").to(device=DEVICE).eval() + pipeline = DiffusionPipeline.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", + value_function=network, + unet=unet, + scheduler=scheduler, + env=env, + custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community", + ) + # unet = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE) + # network = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor512").to(device=DEVICE) + + # add a batch dimension and repeat for multiple samples + # [ observation_dim ] --> [ n_samples x observation_dim ] + env.seed(0) + obs = env.reset() + total_reward = 0 + total_score = 0 + T = 1000 + rollout = [obs.copy()] + try: + for t in tqdm.tqdm(range(T)): + # 1. 
Call the policy + # normalize observations for forward passes + denorm_actions = pipeline(obs, planning_horizon=32) + + # execute action in environment + next_observation, reward, terminal, _ = env.step(denorm_actions) + score = env.get_normalized_score(total_reward) + # update return + total_reward += reward + total_score += score + print( + f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" + f" {total_score}" + ) + # save observations for rendering + rollout.append(next_observation.copy()) + + obs = next_observation + except KeyboardInterrupt: + pass + + print(f"Total reward: {total_reward}") + + +def run(): + _run() + + +if __name__ == "__main__": + run() diff --git a/examples/diffuser/train_diffuser.py b/examples/diffuser/train_diffuser.py new file mode 100644 index 000000000000..b063a0456d97 --- /dev/null +++ b/examples/diffuser/train_diffuser.py @@ -0,0 +1,312 @@ +import os +import warnings + +import numpy as np +import torch + +import d4rl # noqa +import gym +import mediapy as media +import mujoco_py as mjc +import tqdm +from diffusers import DDPMScheduler, UNet1DModel + + +# Define some helper functions + + +DTYPE = torch.float + + +def normalize(x_in, data, key): + means = data[key].mean(axis=0) + stds = data[key].std(axis=0) + return (x_in - means) / stds + + +def de_normalize(x_in, data, key): + means = data[key].mean(axis=0) + stds = data[key].std(axis=0) + return x_in * stds + means + + +def to_torch(x_in, dtype=None, device="cuda"): + dtype = dtype or DTYPE + device = device + if type(x_in) is dict: + return {k: to_torch(v, dtype, device) for k, v in x_in.items()} + elif torch.is_tensor(x_in): + return x_in.to(device).type(dtype) + return torch.tensor(x_in, dtype=dtype, device=device) + + +def reset_x0(x_in, cond, act_dim): + for key, val in cond.items(): + x_in[:, key, act_dim:] = val.clone() + return x_in + + +def run_diffusion(x, scheduler, network, unet, conditions, action_dim, config): + y = None + for i in tqdm.tqdm(scheduler.timesteps): + # create batch of timesteps to pass into model + timesteps = torch.full((config["n_samples"],), i, device=config["device"], dtype=torch.long) + # 3. call the sample function + for _ in range(config["n_guide_steps"]): + with torch.enable_grad(): + x.requires_grad_() + y = network(x, timesteps).sample + grad = torch.autograd.grad([y.sum()], [x])[0] + if config["scale_grad_by_std"]: + posterior_variance = scheduler._get_variance(i) + model_std = torch.exp(0.5 * posterior_variance) + grad = model_std * grad + grad[timesteps < config["t_grad_cutoff"]] = 0 + x = x.detach() + x = x + config["scale"] * grad + x = reset_x0(x, conditions, action_dim) + # with torch.no_grad(): + prev_x = unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) + x = scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] + + # 3. [optional] add posterior noise to the sample + if config["eta"] > 0: + noise = torch.randn(x.shape).to(x.device) + posterior_variance = scheduler._get_variance(i) # * noise + # no noise when t == 0 + # NOTE: original implementation missing sqrt on posterior_variance + x = x + int(i > 0) * (0.5 * posterior_variance) * config["eta"] * noise # MJ had as log var, exponentiated + + # 4. 
apply conditions to the trajectory + x = reset_x0(x, conditions, action_dim) + x = to_torch(x, device=config["device"]) + # y = network(x, timesteps).sample + return x, y + + +def to_np(x_in): + if torch.is_tensor(x_in): + x_in = x_in.detach().cpu().numpy() + return x_in + + +# from MJ's Diffuser code +# https://github.com/jannerm/diffuser/blob/76ae49ae85ba1c833bf78438faffdc63b8b4d55d/diffuser/utils/colab.py#L79 +def mkdir(savepath): + """ + returns `True` iff `savepath` is created + """ + if not os.path.exists(savepath): + os.makedirs(savepath) + return True + else: + return False + + +def show_sample(renderer, observations, filename="sample.mp4", savebase="videos"): + """ + observations : [ batch_size x horizon x observation_dim ] + """ + + mkdir(savebase) + savepath = os.path.join(savebase, filename) + + images = [] + for rollout in observations: + # [ horizon x height x width x channels ] + img = renderer._renders(rollout, partial=True) + images.append(img) + + # [ horizon x height x (batch_size * width) x channels ] + images = np.concatenate(images, axis=2) + media.write_video(savepath, images, fps=60) + media.show_video(images, codec="h264", fps=60) + return images + + +# Code adapted from Michael Janner +# source: https://github.com/jannerm/diffuser/blob/main/diffuser/utils/rendering.py + + +def env_map(env_name): + """ + map D4RL dataset names to custom fully-observed + variants for rendering + """ + if "halfcheetah" in env_name: + return "HalfCheetahFullObs-v2" + elif "hopper" in env_name: + return "HopperFullObs-v2" + elif "walker2d" in env_name: + return "Walker2dFullObs-v2" + else: + return env_name + + +def get_image_mask(img): + background = (img == 255).all(axis=-1, keepdims=True) + mask = ~background.repeat(3, axis=-1) + return mask + + +def atmost_2d(x): + while x.ndim > 2: + x = x.squeeze(0) + return x + + +def set_state(env, state): + qpos_dim = env.sim.data.qpos.size + qvel_dim = env.sim.data.qvel.size + if not state.size == qpos_dim + qvel_dim: + warnings.warn( + f"[ utils/rendering ] Expected state of size {qpos_dim + qvel_dim}, but got state of size {state.size}" + ) + state = state[: qpos_dim + qvel_dim] + + env.set_state(state[:qpos_dim], state[qpos_dim:]) + + +class MuJoCoRenderer: + """ + default mujoco renderer + """ + + def __init__(self, env): + if type(env) is str: + env = env_map(env) + self.env = gym.make(env) + else: + self.env = env + # - 1 because the envs in renderer are fully-observed + # @TODO : clean up + self.observation_dim = np.prod(self.env.observation_space.shape) - 1 + self.action_dim = np.prod(self.env.action_space.shape) + try: + self.viewer = mjc.MjRenderContextOffscreen(self.env.sim) + except: + print("[ utils/rendering ] Warning: could not initialize offscreen renderer") + self.viewer = None + + def pad_observation(self, observation): + state = np.concatenate( + [ + np.zeros(1), + observation, + ] + ) + return state + + def pad_observations(self, observations): + qpos_dim = self.env.sim.data.qpos.size + # xpos is hidden + xvel_dim = qpos_dim - 1 + xvel = observations[:, xvel_dim] + xpos = np.cumsum(xvel) * self.env.dt + states = np.concatenate( + [ + xpos[:, None], + observations, + ], + axis=-1, + ) + return states + + def render(self, observation, dim=256, partial=False, qvel=True, render_kwargs=None, conditions=None): + if type(dim) == int: + dim = (dim, dim) + + if self.viewer is None: + return np.zeros((*dim, 3), np.uint8) + + if render_kwargs is None: + xpos = observation[0] if not partial else 0 + render_kwargs = 
{"trackbodyid": 2, "distance": 3, "lookat": [xpos, -0.5, 1], "elevation": -20} + + for key, val in render_kwargs.items(): + if key == "lookat": + self.viewer.cam.lookat[:] = val[:] + else: + setattr(self.viewer.cam, key, val) + + if partial: + state = self.pad_observation(observation) + else: + state = observation + + qpos_dim = self.env.sim.data.qpos.size + if not qvel or state.shape[-1] == qpos_dim: + qvel_dim = self.env.sim.data.qvel.size + state = np.concatenate([state, np.zeros(qvel_dim)]) + + set_state(self.env, state) + + self.viewer.render(*dim) + data = self.viewer.read_pixels(*dim, depth=False) + data = data[::-1, :, :] + return data + + def _renders(self, observations, **kwargs): + images = [] + for observation in observations: + img = self.render(observation, **kwargs) + images.append(img) + return np.stack(images, axis=0) + + def renders(self, samples, partial=False, **kwargs): + if partial: + samples = self.pad_observations(samples) + partial = False + + sample_images = self._renders(samples, partial=partial, **kwargs) + + composite = np.ones_like(sample_images[0]) * 255 + + for img in sample_images: + mask = get_image_mask(img) + composite[mask] = img[mask] + + return composite + + def __call__(self, *args, **kwargs): + return self.renders(*args, **kwargs) + + +env_name = "hopper-medium-expert-v2" +env = gym.make(env_name) +data = env.get_dataset() # dataset is only used for normalization in this colab + +# Cuda settings for colab +# torch.cuda.get_device_name(0) +DEVICE = "cpu" +DTYPE = torch.float + +# diffusion model settings +n_samples = 4 # number of trajectories planned via diffusion +horizon = 128 # length of sampled trajectories +state_dim = env.observation_space.shape[0] +action_dim = env.action_space.shape[0] +num_inference_steps = 100 # number of difusion steps + +obs = env.reset() +obs_raw = obs + +# normalize observations for forward passes +obs = normalize(obs, data, "observations") + + +# Two generators for different parts of the diffusion loop to work in colab +generator = torch.Generator(device="cuda") +generator_cpu = torch.Generator(device="cpu") +network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE) + +scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2") +optimizer = torch.optim.AdamW( + network.parameters(), + lr=0.001, + betas=(0.95, 0.99), + weight_decay=1e-6, + eps=1e-8, +) + +# TODO: Flesh this out using accelerate library (a la other examples) diff --git a/scripts/convert_models_diffuser_to_diffusers.py b/scripts/convert_models_diffuser_to_diffusers.py new file mode 100644 index 000000000000..b154295e9726 --- /dev/null +++ b/scripts/convert_models_diffuser_to_diffusers.py @@ -0,0 +1,77 @@ +import json +import os + +import torch + +from diffusers import UNet1DModel + + +os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True) +os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True) + +os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True) + + +def unet(hor): + if hor == 128: + down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D") + block_out_channels = (32, 128, 256) + up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D") + + elif hor == 32: + down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D") + block_out_channels = (32, 64, 128, 256) + up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D") + model = 
torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch") + state_dict = model.state_dict() + config = dict( + down_block_types=down_block_types, + block_out_channels=block_out_channels, + up_block_types=up_block_types, + layers_per_block=1, + ) + hf_value_function = UNet1DModel(**config) + print(f"length of state dict: {len(state_dict.keys())}") + print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") + mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys())) + for k, v in mapping.items(): + state_dict[v] = state_dict.pop(k) + hf_value_function.load_state_dict(state_dict) + + torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin") + with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f: + json.dump(config, f) + + +def value_function(): + config = dict( + in_channels=14, + down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + up_block_types=(), + out_block_type="ValueFunction", + block_out_channels=(32, 64, 128, 256), + layers_per_block=1, + always_downsample=True, + ) + + model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch") + state_dict = model + hf_value_function = UNet1DModel(**config) + print(f"length of state dict: {len(state_dict.keys())}") + print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") + + mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys())) + for k, v in mapping.items(): + state_dict[v] = state_dict.pop(k) + + hf_value_function.load_state_dict(state_dict) + + torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin") + with open("hub/hopper-medium-v2/value_function/config.json", "w") as f: + json.dump(config, f) + + +if __name__ == "__main__": + unet(32) + # unet(128) + value_function() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fa97effaaf0a..7088e560dd66 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -18,7 +18,7 @@ if is_torch_available(): from .modeling_utils import ModelMixin - from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel + from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction, VQModel from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index c5d53b2feb4b..b771aaac8467 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -19,6 +19,7 @@ from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel + from .unet_rl import ValueFunction from .vae import AutoencoderKL, VQModel if is_flax_available(): diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 3ede756c9b3d..b720c78b8833 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -17,14 +17,12 @@ import torch import torch.nn as nn -from diffusers.models.resnet import ResidualTemporalBlock1D -from diffusers.models.unet_1d_blocks import get_down_block, get_up_block +from diffusers.models.unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block from ..configuration_utils 
import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .embeddings import TimestepEmbedding, Timesteps -from .resnet import rearrange_dims @dataclass @@ -62,10 +60,13 @@ def __init__( out_channels: int = 14, down_block_types: Tuple[str] = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), up_block_types: Tuple[str] = ("UpResnetBlock1D", "UpResnetBlock1D"), + mid_block_types: Tuple[str] = ("MidResTemporalBlock1D", "MidResTemporalBlock1D"), + out_block_type: str = "OutConv1DBlock", block_out_channels: Tuple[int] = (32, 128, 256), act_fn: str = "mish", norm_num_groups: int = 8, layers_per_block: int = 1, + always_downsample: bool = False, ): super().__init__() @@ -95,14 +96,30 @@ def __init__( in_channels=input_channel, out_channels=output_channel, temb_channels=block_out_channels[0], - add_downsample=not is_final_block, + add_downsample=not is_final_block or always_downsample, ) self.down_blocks.append(down_block) # mid - self.mid_block1 = ResidualTemporalBlock1D(mid_dim, mid_dim, embed_dim=block_out_channels[0]) - self.mid_block2 = ResidualTemporalBlock1D(mid_dim, mid_dim, embed_dim=block_out_channels[0]) - + self.mid_blocks = nn.ModuleList([]) + for i, mid_block_type in enumerate(mid_block_types): + if always_downsample: + mid_block = get_mid_block( + mid_block_type, + in_channels=mid_dim // (i + 1), + out_channels=mid_dim // ((i + 1) * 2), + embed_dim=block_out_channels[0], + add_downsample=True, + ) + else: + mid_block = get_mid_block( + mid_block_type, + in_channels=mid_dim, + out_channels=mid_dim, + embed_dim=block_out_channels[0], + add_downsample=False, + ) + self.mid_blocks.append(mid_block) # up reversed_block_out_channels = list(reversed(block_out_channels)) for i, up_block_type in enumerate(up_block_types): @@ -123,13 +140,14 @@ def __init__( # out num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) - self.final_conv1d_1 = nn.Conv1d(block_out_channels[0], block_out_channels[0], 5, padding=2) - self.final_conv1d_gn = nn.GroupNorm(num_groups_out, block_out_channels[0]) - if act_fn == "silu": - self.final_conv1d_act = nn.SiLU() - if act_fn == "mish": - self.final_conv1d_act = nn.Mish() - self.final_conv1d_2 = nn.Conv1d(block_out_channels[0], out_channels, 1) + self.out_block = get_out_block( + out_block_type=out_block_type, + num_groups_out=num_groups_out, + embed_dim=block_out_channels[0], + out_channels=out_channels, + act_fn=act_fn, + fc_dim=mid_dim // 4, + ) def forward( self, @@ -166,20 +184,15 @@ def forward( down_block_res_samples.append(res_samples[0]) # 3. mid - sample = self.mid_block1(sample, temb) - sample = self.mid_block2(sample, temb) + for mid_block in self.mid_blocks: + sample = mid_block(sample, temb) # 4. up for up_block in self.up_blocks: sample = up_block(hidden_states=sample, res_hidden_states=down_block_res_samples.pop(), temb=temb) # 5. 
post-process - sample = self.final_conv1d_1(sample) - sample = rearrange_dims(sample) - sample = self.final_conv1d_gn(sample) - sample = rearrange_dims(sample) - sample = self.final_conv1d_act(sample) - sample = self.final_conv1d_2(sample) + sample = self.out_block(sample, temb) if not return_dict: return (sample,) diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index 40e25fb43afb..a00372faf7d9 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from torch import nn -from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D +from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims class DownResnetBlock1D(nn.Module): @@ -173,6 +173,66 @@ class UpBlock1DNoSkip(nn.Module): pass +class MidResTemporalBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, embed_dim, add_downsample): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_downsample = add_downsample + self.resnet = ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim) + + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True) + else: + self.downsample = nn.Identity() + + def forward(self, sample, temb): + sample = self.resnet(sample, temb) + sample = self.downsample(sample) + return sample + + +class OutConv1DBlock(nn.Module): + def __init__(self, num_groups_out, out_channels, embed_dim, act_fn): + super().__init__() + self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) + self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) + if act_fn == "silu": + self.final_conv1d_act = nn.SiLU() + if act_fn == "mish": + self.final_conv1d_act = nn.Mish() + self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) + + def forward(self, sample, t): + sample = self.final_conv1d_1(sample) + sample = rearrange_dims(sample) + sample = self.final_conv1d_gn(sample) + sample = rearrange_dims(sample) + sample = self.final_conv1d_act(sample) + sample = self.final_conv1d_2(sample) + return sample + + +class OutValueFunctionBlock(nn.Module): + def __init__(self, fc_dim, embed_dim): + super().__init__() + self.final_block = nn.ModuleList( + [ + nn.Linear(fc_dim + embed_dim, fc_dim // 2), + nn.Mish(), + nn.Linear(fc_dim // 2, 1), + ] + ) + + def forward(self, sample, t): + sample = sample.view(sample.shape[0], -1) + sample = torch.cat((sample, t), dim=-1) + for layer in self.final_block: + sample = layer(sample) + + return sample + + def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample): if down_block_type == "DownResnetBlock1D": return DownResnetBlock1D( @@ -195,5 +255,19 @@ def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_chan temb_channels=temb_channels, add_upsample=add_upsample, ) - + elif up_block_type == "Identity": + return nn.Identity() raise ValueError(f"{up_block_type} does not exist.") + + +def get_mid_block(mid_block_type, in_channels, out_channels, embed_dim, add_downsample): + if mid_block_type == "MidResTemporalBlock1D": + return MidResTemporalBlock1D(in_channels, out_channels, embed_dim, add_downsample) + raise ValueError(f"{mid_block_type} does not exist.") + + +def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim): + if out_block_type == "OutConv1DBlock": + return OutConv1DBlock(num_groups_out, out_channels, embed_dim, 
act_fn) + elif out_block_type == "ValueFunction": + return OutValueFunctionBlock(fc_dim, embed_dim) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py new file mode 100644 index 000000000000..66822f99b198 --- /dev/null +++ b/src/diffusers/models/unet_rl.py @@ -0,0 +1,135 @@ +# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py +from dataclasses import dataclass +from typing import Tuple, Union + +import torch +import torch.nn as nn + +from diffusers.models.resnet import Downsample1D, ResidualTemporalBlock1D +from diffusers.models.unet_1d_blocks import get_down_block + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..utils import BaseOutput +from .embeddings import TimestepEmbedding, Timesteps + + +@dataclass +class ValueFunctionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch, horizon, 1)`): + Hidden states output. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class ValueFunction(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels=14, + down_block_types: Tuple[str] = ( + "DownResnetBlock1D", + "DownResnetBlock1D", + "DownResnetBlock1D", + "DownResnetBlock1D", + ), + block_out_channels: Tuple[int] = (32, 64, 128, 256), + act_fn: str = "mish", + norm_num_groups: int = 8, + layers_per_block: int = 1, + ): + super().__init__() + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(num_channels=block_out_channels[0], flip_sin_to_cos=False, downscale_freq_shift=1) + self.time_mlp = TimestepEmbedding( + channel=block_out_channels[0], time_embed_dim=time_embed_dim, act_fn="mish", out_dim=block_out_channels[0] + ) + + self.blocks = nn.ModuleList([]) + mid_dim = block_out_channels[-1] + + output_channel = in_channels + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + + down_block_type = down_block_types[i] + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_downsample=True, + ) + self.blocks.append(down_block) + + ## + self.mid_block1 = ResidualTemporalBlock1D(mid_dim, mid_dim // 2, embed_dim=block_out_channels[0]) + self.mid_down1 = Downsample1D(mid_dim // 2, use_conv=True) + ## + self.mid_block2 = ResidualTemporalBlock1D(mid_dim // 2, mid_dim // 4, embed_dim=block_out_channels[0]) + self.mid_down2 = Downsample1D(mid_dim // 4, use_conv=True) + ## + fc_dim = mid_dim // 4 + self.final_block = nn.ModuleList( + [ + nn.Linear(fc_dim + block_out_channels[0], fc_dim // 2), + nn.Mish(), + nn.Linear(fc_dim // 2, 1), + ] + ) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + return_dict: bool = True, + ) -> Union[ValueFunctionOutput, Tuple]: + """r + Args: + sample (`torch.FloatTensor`): (batch, horizon, obs_dimension + action_dimension) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): batch (batch) timesteps + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_rl.ValueFunctionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_rl.ValueFunctionOutput`] or `tuple`: [`~models.unet_rl.ValueFunctionOutput`] if + `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is the sample tensor. + """ + sample = sample.permute(0, 2, 1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + t = self.time_proj(timesteps) + t = self.time_mlp(t) + down_block_res_samples = [] + + # 2. down + for downsample_block in self.blocks: + sample, res_samples = downsample_block(hidden_states=sample, temb=t) + down_block_res_samples.append(res_samples[0]) + + # 3. mid + sample = self.mid_block1(sample, t) + sample = self.mid_down1(sample) + sample = self.mid_block2(sample, t) + sample = self.mid_down2(sample) + + sample = sample.view(sample.shape[0], -1) + sample = torch.cat((sample, t), dim=-1) + for layer in self.final_block: + sample = layer(sample) + + if not return_dict: + return (sample,) + + return ValueFunctionOutput(sample=sample) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 04c92904a660..06596bd6091f 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -200,6 +200,7 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): # for rl-diffuser https://arxiv.org/abs/2205.09991 elif variance_type == "fixed_small_log": variance = torch.log(torch.clamp(variance, min=1e-20)) + variance = torch.exp(0.5 * variance) elif variance_type == "fixed_large": variance = self.betas[t] elif variance_type == "fixed_large_log": @@ -283,7 +284,10 @@ def step( noise = torch.randn( model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator ).to(model_output.device) - variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise + if self.variance_type == "fixed_small_log": + variance = self._get_variance(t, predicted_variance=predicted_variance) * noise + else: + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise pred_prev_sample = pred_prev_sample + variance diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index e1dbdfaa4611..55f373af8a9b 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -20,7 +20,7 @@ import torch -from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel +from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction from diffusers.utils import floats_tensor, slow, torch_device from .test_modeling_common import ModelTesterMixin @@ -524,3 +524,86 @@ def test_output_pretrained(self): def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass + + +class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): + model_class = ValueFunction + + @property + def dummy_input(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 14, 16) + + @property + def output_shape(self): + return (4, 14, 1) + + def test_ema_training(self): + pass + + def test_training(self): + pass + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64, 128, 256), + "in_channels": 14, + "out_channels": 14, + } + inputs_dict = 
self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + unet, loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-unet-hor32", output_loading_info=True + ) + value_function, vf_loading_info = ValueFunction.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True + ) + self.assertIsNotNone(unet) + self.assertEqual(len(loading_info["missing_keys"]), 0) + self.assertIsNotNone(value_function) + self.assertEqual(len(vf_loading_info["missing_keys"]), 0) + + unet.to(torch_device) + value_function.to(torch_device) + image = value_function(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + value_function, vf_loading_info = ValueFunction.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True + ) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = value_function.in_channels + seq_len = 14 + noise = torch.randn((1, seq_len, num_features)).permute( + 0, 2, 1 + ) # match original, we can update values and remove + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = value_function(noise, time_step).sample + + # fmt: off + expected_output_slice = torch.tensor([207.0272] * seq_len) + # fmt: on + self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) + + def test_forward_with_norm_groups(self): + # Not implemented yet for this UNet + pass
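A minimal sketch of how the converted checkpoints and the community pipeline added in this PR fit together, mirroring examples/diffuser/run_diffuser_value_guided.py; the relative custom_pipeline path is an assumption (the script above hard-codes a local absolute path), and the hub checkpoint IDs are the ones referenced in the diff:

import d4rl  # noqa
import gym

from diffusers import DDPMScheduler, DiffusionPipeline, UNet1DModel

env = gym.make("hopper-medium-v2")

# scheduler settings match run_diffuser_value_guided.py
scheduler = DDPMScheduler(
    num_train_timesteps=20,
    beta_schedule="squaredcos_cap_v2",
    clip_sample=False,
    variance_type="fixed_small_log",
)

# the value function guides sampling; the UNet denoises state-action trajectories
value_function = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32").eval()
unet = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-unet-hor32").eval()

pipeline = DiffusionPipeline.from_pretrained(
    "bglick13/hopper-medium-v2-value-function-hor32",
    value_function=value_function,
    unet=unet,
    scheduler=scheduler,
    env=env,
    custom_pipeline="examples/community",  # assumed local path to ValueGuidedDiffuserPipeline
)

obs = env.reset()
action = pipeline(obs, planning_horizon=32)  # de-normalized action for the current state
obs, reward, terminal, _ = env.step(action)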