Skip to content

Commit 2cc0e9c

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6a45287 commit 2cc0e9c

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

tests/jax/test_custom_call_compute.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -682,9 +682,7 @@ def test_grouped_qdq(
682682
x = jax.random.uniform(subkeys[1], input_shape, in_dtype)
683683

684684
grouped_quantizer = QuantizerFactory.create(
685-
q_params=QuantizerParams(
686-
q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
687-
),
685+
q_params=QuantizerParams(q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout),
688686
n_groups=n_groups,
689687
)
690688

transformer_engine/jax/quantize/quantizer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -804,15 +804,21 @@ def create(
804804
Returns:
805805
A single quantizer or tuple of quantizers
806806
"""
807-
assert all(arg_name not in kwargs for arg_name in ['scaling_mode', 'q_layout', 'q_dtype']), "Please use q_params instead of passing scaling_mode, q_layout, and q_dtype as separate args."
807+
assert all(
808+
arg_name not in kwargs for arg_name in ["scaling_mode", "q_layout", "q_dtype"]
809+
), (
810+
"Please use q_params instead of passing scaling_mode, q_layout, and q_dtype as separate"
811+
" args."
812+
)
808813
assert isinstance(q_params.scaling_mode, ScalingMode), "Invalid scaling_mode type"
809814
if q_params.scaling_mode == ScalingMode.NO_SCALING:
810815
quantizers = [None] * n_quantizers
811816
else:
812817
if n_groups:
813818
if n_quantizers != 1:
814819
warnings.warn(
815-
"Using more than one GroupedQuantizer for a grouped input is not recommended"
820+
"Using more than one GroupedQuantizer for a grouped input is not"
821+
" recommended"
816822
)
817823
quantizer_type = GroupedQuantizer
818824
kwargs["n_groups"] = n_groups

0 commit comments

Comments
 (0)