@@ -1938,6 +1938,98 @@ def fn(x):
1938
1938
self .assertEqual (cnt .frame_count , 1 )
1939
1939
self .assertEqual (cnt .op_count , 1 )
1940
1940
1941
+ @patch .object (torch ._dynamo .config , "rewrite_assert_with_torch_assert" , True )
1942
+ def test_rewrite_assert_with_msg (self ):
1943
+ def f (x ):
1944
+ b = x .sin ()
1945
+ assert x [0 ] == 3 , "First dim need to be 3"
1946
+ return x .cos () + b
1947
+
1948
+ args = (torch .Tensor ([3 , 4 , 5 ]),)
1949
+ cnt = torch ._dynamo .testing .CompileCounter ()
1950
+
1951
+ opt_f = torch ._dynamo .optimize (cnt , nopython = True )(f )
1952
+ self .assertTrue (same (f (* args ), opt_f (* args )))
1953
+ self .assertEqual (cnt .op_count , 6 )
1954
+ self .assertEqual (cnt .frame_count , 1 )
1955
+
1956
+ exported , _ = torch ._dynamo .export (f , torch .Tensor ([3 , 4 , 5 ]))
1957
+ self .assertTrue (same (exported (* args ), f (* args )))
1958
+
1959
+ with self .assertRaisesRegex (AssertionError , "" ):
1960
+ exported , _ = torch ._dynamo .export (f , torch .Tensor ([4 , 4 , 5 ]))
1961
+
1962
+ @patch .object (torch ._dynamo .config , "rewrite_assert_with_torch_assert" , True )
1963
+ def test_not_rewrite_assert_for_other_errors (self ):
1964
+ def f (x ):
1965
+ b = x .sin ()
1966
+ if not x .sum () <= 3 :
1967
+ raise ValueError ("input sum needs to be 3" )
1968
+ return x .cos () + b
1969
+
1970
+ args = (torch .Tensor ([3 , 4 , 5 ]),)
1971
+ opt_fn = torch ._dynamo .optimize ("eager" )(f )
1972
+ with self .assertRaisesRegex (ValueError , "input sum needs to be 3" ):
1973
+ opt_fn (* args )
1974
+
1975
+ # TODO (tmanlaibaatar) handle data-dependent fstring in assert statement.
1976
+ @patch .object (torch ._dynamo .config , "rewrite_assert_with_torch_assert" , True )
1977
+ def test_rewrite_assert_with_fstring_msg (self ):
1978
+ def f (x ):
1979
+ b = x .sin ()
1980
+ assert x [0 ] == 3 , f"First dim need to be { x [0 ]} "
1981
+ return x .cos () + b
1982
+
1983
+ args = (torch .Tensor ([3 , 4 , 5 ]),)
1984
+ with self .assertRaisesRegex (torch ._dynamo .exc .Unsupported , "generic_jump" ):
1985
+ exported , _ = torch ._dynamo .export (f , torch .Tensor ([3 , 4 , 5 ]))
1986
+
1987
+ @patch .object (torch ._dynamo .config , "rewrite_assert_with_torch_assert" , True )
1988
+ def test_rewrite_assert_without_msg (self ):
1989
+ def f (x ):
1990
+ b = x .sin ()
1991
+ assert x [0 ] == 3
1992
+ return x .cos () + b
1993
+
1994
+ args = (torch .Tensor ([3 , 4 , 5 ]),)
1995
+ exported , _ = torch ._dynamo .export (f , torch .Tensor ([3 , 4 , 5 ]))
1996
+ self .assertTrue (same (exported (* args ), f (* args )))
1997
+
1998
+ with self .assertRaisesRegex (AssertionError , "" ):
1999
+ exported , _ = torch ._dynamo .export (f , torch .Tensor ([4 , 4 , 5 ]))
2000
+
2001
+ @patch .object (torch ._dynamo .config , "rewrite_assert_with_torch_assert" , True )
2002
+ def test_rewrite_assert_noop (self ):
2003
+ def f (x ):
2004
+ b = x .sin ()
2005
+ assert True
2006
+ assert x .dtype == torch .float32
2007
+ return x .cos () + b
2008
+
2009
+ args = (torch .Tensor ([3 , 4 , 5 ]),)
2010
+ exported , _ = torch ._dynamo .export (f , torch .Tensor ([3 , 4 , 5 ]))
2011
+ self .assertTrue (same (exported (* args ), f (* args )))
2012
+
2013
+ cnt = torch ._dynamo .testing .CompileCounter ()
2014
+ opt_f = torch ._dynamo .optimize (cnt , nopython = True )(f )
2015
+ self .assertTrue (same (f (* args ), opt_f (* args )))
2016
+ # torch._assert shouldn't be in the graph
2017
+ self .assertEqual (cnt .op_count , 3 )
2018
+ self .assertEqual (cnt .frame_count , 1 )
2019
+
2020
+ exported , _ = torch ._dynamo .export (f , torch .Tensor ([4 , 4 , 5 ]))
2021
+ self .assertTrue (same (exported (* args ), f (* args )))
2022
+
2023
+ @patch .object (torch ._dynamo .config , "rewrite_assert_with_torch_assert" , False )
2024
+ def test_not_rewrite_assert (self ):
2025
+ def f (x ):
2026
+ b = x .sin ()
2027
+ assert x [0 ] == 3
2028
+ return x .cos () + b
2029
+
2030
+ with self .assertRaisesRegex (torch ._dynamo .exc .Unsupported , "generic_jump" ):
2031
+ torch ._dynamo .export (f , torch .Tensor ([3 , 4 , 5 ]))
2032
+
1941
2033
1942
2034
if __name__ == "__main__" :
1943
2035
from torch ._dynamo .test_case import run_tests
0 commit comments