6 changes: 6 additions & 0 deletions auto_round/compressors/base.py
@@ -66,6 +66,7 @@
    estimate_tuning_block_mem,
    find_matching_blocks,
    flatten_list,
    get_avg_bits,
    get_block_names,
    get_device_memory,
    get_fp_layer_names,
@@ -1710,6 +1711,11 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
        # because it may cause the gguf format to not be exported normally.
        self.model = _handle_moe_model(self.model, formats=formats)
        self.has_qlayer_outside_block = self._set_layerwise_config(self.layer_config)
        average_bits = get_avg_bits(self.model)
        average_bits_w_lm_head = get_avg_bits(self.model, with_lm_head=True)
        logger.info(f"The target average bits: {average_bits:.3f} bits")
        if average_bits_w_lm_head != average_bits:
            logger.debug(f"The target average bits (including lm_head): {average_bits_w_lm_head:.3f} bits")
        if not hasattr(self, "formats"):
            logger.warning("this API is deprecated, please use `quantize_and_save` instead")
        else:
56 changes: 56 additions & 0 deletions auto_round/utils.py
@@ -2755,3 +2755,59 @@ def is_mllm_model(model_or_path: Union[str, torch.nn.Module]):
            return True

    return False


def get_avg_bits(module, with_lm_head=False):
    """
    Calculates the average number of bits per weight element for supported layers in a given module.

    Iterates through all named sub-modules, accumulating the total number of weight elements
    and the corresponding bit usage, including scale bits for quantized layers.

    Args:
        module: A neural network module containing layers to be analyzed.
        with_lm_head: Whether to include the lm_head layer in the calculation; treated as
            False when no lm_head can be found. Defaults to False.

    Returns:
        float: The average number of bits per weight element across all supported layers,
            rounded to 6 decimal places.

    Note:
        - Only layers of types listed in SUPPORTED_LAYER_TYPES are considered.
        - For quantized layers (bits < 16), scale bits are added: 16-bit scales for the "int"
          data type and 8-bit scales for the non-"int" data types (e.g. "fp4_v2", "nv_fp",
          and the mx_fp variants).
        - For the "fp4_v2" and "nv_fp" data types, an additional 32-bit global scale is included.
    """

    def _get_scale_num(bits, group_size, input_features, weight_numel):
        # Layers kept at >= 16 bits are not quantized and carry no scales.
        if bits >= 16:
            return 0
        if group_size == 0:
            return 1  # per-tensor quantization: a single scale
        elif group_size == -1:
            return input_features  # per-channel quantization: one scale per input feature
        else:
            return weight_numel // group_size  # grouped quantization: one scale per group

    all_numel = 0
    all_bits = 0

    lm_head_name = get_lm_head_name(module)
    if lm_head_name is None:
        with_lm_head = False
    for n, m in module.named_modules():
        if n == lm_head_name and not with_lm_head:
            continue
        if isinstance(m, SUPPORTED_LAYER_TYPES):
            # get weight bits
            m_numel = m.weight.numel()
            all_numel += m_numel
            w_bits = m.bits * m_numel
            all_bits += w_bits
            # get scale bits
            scale_num = _get_scale_num(m.bits, m.group_size, m.weight.shape[-1], m_numel)
            bits_per_scale = 16 if m.data_type == "int" else 8
            scale_bits = bits_per_scale * scale_num
            if m.data_type in ("fp4_v2", "nv_fp"):
                scale_bits += 32  # global scale bits
            all_bits += scale_bits

    avg_bits = all_bits / all_numel if all_numel > 0 else 0
    return round(avg_bits, 6)
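
As a quick, informal sanity check of the accounting above (not part of this PR): for a 4-bit "int" layer with group_size=128, the function counts one 16-bit scale per group of 128 weights, so the expected average is 4 + 16/128 = 4.125 bits. A minimal standalone sketch of that arithmetic, with made-up layer dimensions:

# Back-of-the-envelope check of the per-weight bit accounting used by get_avg_bits.
# Illustrative only; the dimensions below are hypothetical, not taken from the PR.
out_features, in_features = 4096, 4096
bits, group_size = 4, 128
weight_numel = out_features * in_features
scale_num = weight_numel // group_size            # one 16-bit scale per group of 128 weights
total_bits = bits * weight_numel + 16 * scale_num
print(total_bits / weight_numel)                  # 4.125, i.e. bits + 16 / group_size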