Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 57 additions & 17 deletions examples/helloworld/tf_example1/test.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,75 @@
"""Example: post-training quantization with neural-compressor.

Simplified version that keeps tf.flags usage for consistency
with the original script. Uses print for logging to stay minimal.
"""

from pathlib import Path
import tensorflow as tf

from neural_compressor.data import TensorflowImageRecord
from neural_compressor.data import BilinearImagenetTransform
from neural_compressor.data import ComposeTransform
from neural_compressor.data import DefaultDataLoader
from neural_compressor.quantization import fit
from neural_compressor.config import PostTrainingQuantConfig
from neural_compressor import Metric
from neural_compressor.config import PostTrainingQuantConfig
from neural_compressor.data import (
BilinearImagenetTransform,
ComposeTransform,
DefaultDataLoader,
TensorflowImageRecord,
)
from neural_compressor.quantization import fit

flags = tf.compat.v1.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('dataset_location', None, 'location of calibration dataset and evaluate dataset')
# keep the same flag as original code
flags.DEFINE_string(
"dataset_location",
None,
"location of calibration dataset and evaluate dataset",
)


calib_dataset = TensorflowImageRecord(root=FLAGS.dataset_location, transform= \
ComposeTransform(transform_list= [BilinearImagenetTransform(height=224, width=224)]))
calib_dataloader = DefaultDataLoader(dataset=calib_dataset, batch_size=10)
def build_dataloader(root: str, batch_size: int) -> DefaultDataLoader:
"""Create a DefaultDataLoader for given root and batch size."""
transform = ComposeTransform(
transform_list=[BilinearImagenetTransform(height=224, width=224)]
)
dataset = TensorflowImageRecord(root=root, transform=transform)
return DefaultDataLoader(dataset=dataset, batch_size=batch_size)

eval_dataset = TensorflowImageRecord(root=FLAGS.dataset_location, transform=ComposeTransform(transform_list= \
[BilinearImagenetTransform(height=224, width=224)]))
eval_dataloader = DefaultDataLoader(dataset=eval_dataset, batch_size=1)

def main():
def main() -> None:
"""Run post-training quantization with predefined configuration."""
dataset_location = FLAGS.dataset_location or "./dataset"
model_path = "./mobilenet_v1_1.0_224_frozen.pb"
calib_batch_size = 10
eval_batch_size = 1
calib_size = 20

# basic checks
ds_path = Path(dataset_location)
model_file = Path(model_path)

if not ds_path.exists():
raise FileNotFoundError(f"Dataset path not found: {ds_path}")
if not model_file.exists():
raise FileNotFoundError(f"Model file not found: {model_file}")

# build dataloaders
calib_dataloader = build_dataloader(root=str(ds_path), batch_size=calib_batch_size)
eval_dataloader = build_dataloader(root=str(ds_path), batch_size=eval_batch_size)

# metric and config
top1 = Metric(name="topk", k=1)
config = PostTrainingQuantConfig(calibration_sampling_size=[20])
config = PostTrainingQuantConfig(calibration_sampling_size=[calib_size])

q_model = fit(
model="./mobilenet_v1_1.0_224_frozen.pb",
model=str(model_file),
conf=config,
calib_dataloader=calib_dataloader,
eval_dataloader=eval_dataloader,
eval_metric=top1)
eval_metric=top1,
)


if __name__ == "__main__":
main()