Description
Hi, I am facing an issue with quantization of SDXL 1.0. When I export the UNet component to ONNX with the code below, the exported `unet.onnx` file is only about 4 MB, which seems far too small, and I then cannot load the model for inference.

My goal is to quantize SDXL 1.0 so I can run image-to-image inference on edge devices and save memory and time. Could you guide me through the full quantization process?
```python
!pip install -q diffusers transformers accelerate onnx onnxruntime

from diffusers import UNet2DConditionModel
import torch

# Load the UNet model from SDXL
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    torch_dtype=torch.float32,  # use float32 for export
)
unet.eval()

import torch.nn as nn

class UNetWrapper(nn.Module):
    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, sample, timestep, encoder_hidden_states, text_embeds, time_ids):
        added_cond_kwargs = {
            "text_embeds": text_embeds,
            "time_ids": time_ids,
        }
        return self.unet(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

# Dummy inputs for tracing
sample = torch.randn(1, 4, 64, 64)                # latent input
timestep = torch.tensor([10], dtype=torch.int64)  # denoising step
encoder_hidden_states = torch.randn(1, 77, 2048)  # text encoder output
text_embeds = torch.randn(1, 1280)                # pooled text embeddings
time_ids = torch.randn(1, 6)                      # time conditioning vector
```
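Before exporting, a quick eager-mode smoke test of the wrapper (a minimal sketch) to confirm the dummy shapes are accepted:

```python
# Sketch: run the wrapper eagerly once to confirm the dummy inputs
# have shapes the UNet accepts before attempting the ONNX export.
with torch.no_grad():
    out = UNetWrapper(unet)(sample, timestep, encoder_hidden_states, text_embeds, time_ids)
print(out.shape)  # expected: torch.Size([1, 4, 64, 64])
```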
```python
from pathlib import Path

# Define output path
onnx_path = "/content/drive/MyDrive/Umair/SkinCare/Quantization/Dynamic Quantization unet/unet_sdxl.onnx"
Path(onnx_path).parent.mkdir(parents=True, exist_ok=True)

# Export the model
torch.onnx.export(
    UNetWrapper(unet),
    (sample, timestep, encoder_hidden_states, text_embeds, time_ids),
    onnx_path,
    input_names=["sample", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"],
    output_names=["out_sample"],
    dynamic_axes={
        "sample": {0: "batch"},
        "encoder_hidden_states": {0: "batch"},
        "text_embeds": {0: "batch"},
        "time_ids": {0: "batch"},
        "out_sample": {0: "batch"},
    },
    opset_version=17,  # torch.onnx.export takes opset_version, not opset
)
print("✅ UNet exported successfully to:", onnx_path)
```
```python
from onnxruntime.quantization import quantize_dynamic, QuantType

quantized_path = onnx_path.replace(".onnx", "_int8.onnx")
quantize_dynamic(
    model_input=onnx_path,
    model_output=quantized_path,
    weight_type=QuantType.QInt8,
)
print("✅ Quantized model saved to:", quantized_path)
```
Code to check the sizes of the `.onnx` and `_int8.onnx` files:

```python
import os

print(f".onnx file size: {os.path.getsize(onnx_path) / (1024 * 1024):.2f} MB")
print(f"_int8.onnx file size: {os.path.getsize(quantized_path) / (1024 * 1024):.2f} MB")
```

Output:

```
.onnx file size: 3.98 MB
_int8.onnx file size: 2458.93 MB
```
As mentioned above, I am also facing an issue when loading the quantized model for inference.
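For reference, this is roughly how I attempt to load and run the quantized model (minimal sketch with onnxruntime on CPU, reusing the dummy tensors defined above):

```python
# Sketch: create a CPU inference session for the quantized model and
# run the dummy inputs from above through it once.
import onnxruntime as ort

session = ort.InferenceSession(quantized_path, providers=["CPUExecutionProvider"])
(out_sample,) = session.run(
    ["out_sample"],
    {
        "sample": sample.numpy(),
        "timestep": timestep.numpy(),
        "encoder_hidden_states": encoder_hidden_states.numpy(),
        "text_embeds": text_embeds.numpy(),
        "time_ids": time_ids.numpy(),
    },
)
print(out_sample.shape)
```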