Skip to content

Commit 62ba15e

Browse files
tugsbayasgalan authored
and pytorchmergebot committed
Rewrite assert statement with torch._assert under config (pytorch#88246)
This diff rewrites assert statement in python with torch._assert under config. The resulting graph looks something like: ``` SOURCE CODE: def f(x): assert x[0] == 3 return x.cos() CAPTURED GRAPH: graph(): %arg0 : [#users=2] = placeholder[target=arg0] %getitem : [#users=1] = call_function[target=operator.getitem](args = (%arg0, 0), kwargs = {}) %eq : [#users=1] = call_function[target=operator.eq](args = (%getitem, 3), kwargs = {}) %_assert : [#users=0] = call_function[target=torch._assert](args = (%eq, "assertion_error"), kwargs = {}) %cos : [#users=1] = call_method[target=cos](args = (%arg0,), kwargs = {}) return cos ``` Note that this introduces side-effect as it could error out while executing graph, but the assertion can eliminated via DCE if we choose to ignore it. Pull Request resolved: pytorch#88246 Approved by: https://github.com/jansel
1 parent b815f1f commit 62ba15e

File tree

3 files changed

+189
-0
lines changed

3 files changed

+189
-0
lines changed

test/dynamo/test_repros.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1938,6 +1938,98 @@ def fn(x):
19381938
self.assertEqual(cnt.frame_count, 1)
19391939
self.assertEqual(cnt.op_count, 1)
19401940

1941+
@patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
def test_rewrite_assert_with_msg(self):
    """An assert with a plain-string message is captured in the graph
    (as torch._assert) instead of causing a graph break, and a failing
    assert still surfaces as AssertionError."""

    def fn(x):
        sine = x.sin()
        assert x[0] == 3, "First dim need to be 3"
        return x.cos() + sine

    inputs = (torch.Tensor([3, 4, 5]),)
    counter = torch._dynamo.testing.CompileCounter()

    compiled = torch._dynamo.optimize(counter, nopython=True)(fn)
    self.assertTrue(same(fn(*inputs), compiled(*inputs)))
    # The rewritten assert contributes extra nodes: a single frame with 6 ops.
    self.assertEqual(counter.op_count, 6)
    self.assertEqual(counter.frame_count, 1)

    exported, _ = torch._dynamo.export(fn, torch.Tensor([3, 4, 5]))
    self.assertTrue(same(exported(*inputs), fn(*inputs)))

    # Tracing with an input that violates the assert must raise.
    with self.assertRaisesRegex(AssertionError, ""):
        exported, _ = torch._dynamo.export(fn, torch.Tensor([4, 4, 5]))
1961+
1962+
@patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
def test_not_rewrite_assert_for_other_errors(self):
    """Only `assert` statements get rewritten; an explicit raise of a
    non-AssertionError exception must propagate unchanged."""

    def fn(x):
        sine = x.sin()
        if not x.sum() <= 3:
            raise ValueError("input sum needs to be 3")
        return x.cos() + sine

    inputs = (torch.Tensor([3, 4, 5]),)
    compiled = torch._dynamo.optimize("eager")(fn)
    with self.assertRaisesRegex(ValueError, "input sum needs to be 3"):
        compiled(*inputs)
1974+
1975+
# TODO (tmanlaibaatar) handle data-dependent fstring in assert statement.
@patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
def test_rewrite_assert_with_fstring_msg(self):
    """A data-dependent f-string assert message is not yet rewritable, so
    export is expected to fail with the generic_jump graph-break error.

    Fix: removed an unused ``args`` tuple the original built but never used.
    """

    def f(x):
        b = x.sin()
        assert x[0] == 3, f"First dim need to be {x[0]}"
        return x.cos() + b

    with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"):
        exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
1986+
1987+
@patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
def test_rewrite_assert_without_msg(self):
    """A bare assert (no message) is normalized with a placeholder message,
    exports cleanly, and still raises AssertionError for bad input."""

    def fn(x):
        sine = x.sin()
        assert x[0] == 3
        return x.cos() + sine

    inputs = (torch.Tensor([3, 4, 5]),)
    exported, _ = torch._dynamo.export(fn, torch.Tensor([3, 4, 5]))
    self.assertTrue(same(exported(*inputs), fn(*inputs)))

    with self.assertRaisesRegex(AssertionError, ""):
        exported, _ = torch._dynamo.export(fn, torch.Tensor([4, 4, 5]))
2000+
2001+
@patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
def test_rewrite_assert_noop(self):
    """Asserts whose condition is a true constant at trace time are dropped
    entirely — no torch._assert node should appear in the graph."""

    def fn(x):
        sine = x.sin()
        assert True
        assert x.dtype == torch.float32
        return x.cos() + sine

    inputs = (torch.Tensor([3, 4, 5]),)
    exported, _ = torch._dynamo.export(fn, torch.Tensor([3, 4, 5]))
    self.assertTrue(same(exported(*inputs), fn(*inputs)))

    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter, nopython=True)(fn)
    self.assertTrue(same(fn(*inputs), compiled(*inputs)))
    # torch._assert shouldn't be in the graph
    self.assertEqual(counter.op_count, 3)
    self.assertEqual(counter.frame_count, 1)

    # Even inputs that would fail the (eliminated) assert export fine.
    exported, _ = torch._dynamo.export(fn, torch.Tensor([4, 4, 5]))
    self.assertTrue(same(exported(*inputs), fn(*inputs)))
2022+
2023+
@patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", False)
def test_not_rewrite_assert(self):
    """With the rewrite config disabled, an assert on tensor data keeps the
    historical behavior: a generic_jump graph-break error during export."""

    def fn(x):
        sine = x.sin()
        assert x[0] == 3
        return x.cos() + sine

    with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"):
        torch._dynamo.export(fn, torch.Tensor([3, 4, 5]))
2032+
19412033

19422034
if __name__ == "__main__":
19432035
from torch._dynamo.test_case import run_tests

torch/_dynamo/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@
8787
# if an exception is encountered
8888
replay_record_enabled = False
8989

90+
# Rewrite assert statement in python with torch._assert
91+
rewrite_assert_with_torch_assert = True
92+
9093
# Show a warning on every graph break
9194
print_graph_breaks = False
9295

torch/_dynamo/symbolic_convert.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
fake_tensors_available,
5454
graph_break_dup_warning_checker,
5555
istype,
56+
proxy_args_kwargs,
5657
)
5758
from .variables.base import MutableLocal, typestr, VariableTracker
5859
from .variables.builder import VariableBuilder, wrap_fx_proxy
@@ -121,10 +122,103 @@ def impl(self: "InstructionTranslatorBase", inst: Instruction):
121122
return impl
122123

123124

125+
def _detect_and_normalize_assert_statement(
126+
self: "InstructionTranslatorBase", truth_fn: typing.Callable, push: bool
127+
):
128+
# Detect if this jump instruction is assert and normalize the assert
129+
# by pushing dummy error message when nothing is given.
130+
#
131+
# Python 3.9 assertion is in following format:
132+
# 18 POP_JUMP_IF_TRUE 28
133+
# 20 LOAD_ASSERTION_ERROR
134+
# 22 LOAD_CONST 3 ('Assert message') -> optional instruction
135+
# 24 CALL_FUNCTION 1 -> optional instruction
136+
# 26 RAISE_VARARGS
137+
#
138+
# Python 3.8 assertion is in following format:
139+
# 18 POP_JUMP_IF_TRUE 28
140+
# 20 LOAD_GLOBAL 0 (Assertion type)
141+
# 22 LOAD_CONST 3 ('Assert message') -> optional instruction
142+
# 24 CALL_FUNCTION 1 -> optional instruction
143+
# 26 RAISE_VARARGS 1
144+
145+
if (truth_fn is not operator.truth) or push:
146+
return False
147+
148+
current_instruction_pointer = self.instruction_pointer
149+
inst = self.instructions[current_instruction_pointer]
150+
# Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0
151+
if sys.version_info < (3, 9):
152+
if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError":
153+
return False
154+
else:
155+
if inst.opname != "LOAD_ASSERTION_ERROR":
156+
return False
157+
158+
current_instruction_pointer += 1
159+
160+
if current_instruction_pointer >= len(self.instructions):
161+
return False
162+
163+
inst = self.instructions[current_instruction_pointer]
164+
has_error_msg = False
165+
# DETECT RAISE_VARARGS or LOAD CONST
166+
if inst.opname == "LOAD_CONST":
167+
if not isinstance(inst.argval, str):
168+
return False
169+
self.LOAD_CONST(inst)
170+
has_error_msg = True
171+
172+
# if it is LOAD_CONSTANT, it must be followed by CALL_FUNCTION
173+
current_instruction_pointer += 1
174+
if current_instruction_pointer >= len(self.instructions):
175+
return False
176+
inst = self.instructions[current_instruction_pointer]
177+
if inst.opname != "CALL_FUNCTION":
178+
return False
179+
180+
# CALL_FUNCTION should be followed by RAISE_VARARGS
181+
current_instruction_pointer += 1
182+
if current_instruction_pointer >= len(self.instructions):
183+
return False
184+
inst = self.instructions[current_instruction_pointer]
185+
186+
if inst.opname != "RAISE_VARARGS":
187+
return False
188+
189+
if not has_error_msg:
190+
# Push dummy value instead of error message
191+
self.push(ConstantVariable("assertion error"))
192+
193+
return True
194+
195+
124196
def generic_jump(truth_fn: typing.Callable, push: bool):
125197
def inner(self: "InstructionTranslatorBase", inst: Instruction):
126198
value: VariableTracker = self.pop()
127199
self.output.guards.update(value.guards)
200+
if (
201+
config.rewrite_assert_with_torch_assert
202+
and _detect_and_normalize_assert_statement(self, truth_fn, push)
203+
):
204+
error_msg: VariableTracker = self.pop()
205+
self.output.guards.update(error_msg.guards)
206+
# Skip over things like `assert True`
207+
if value.is_python_constant() and bool(value.as_python_constant()):
208+
self.jump(inst)
209+
return
210+
211+
# Manually insert torch._assert instead of python assert and jump over
212+
# assert related instructions as we don't need them anymore.
213+
self.output.create_proxy(
214+
"call_function",
215+
torch._assert,
216+
*proxy_args_kwargs((value, error_msg), {}),
217+
current_tx=self,
218+
)
219+
self.jump(inst)
220+
return
221+
128222
if value.is_python_constant():
129223
if truth_fn(value.as_python_constant()):
130224
push and self.push(value)

0 commit comments

Comments
 (0)