未验证 提交 0f8dc611 编写于 作者: L liym27 提交者: GitHub

[Dy2Static] Convert assert stmt with new function `convert_assert`. (#25551)

上级 cf3c51a6
...@@ -17,12 +17,12 @@ from __future__ import print_function ...@@ -17,12 +17,12 @@ from __future__ import print_function
import gast import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
class AssertTransformer(gast.NodeTransformer): class AssertTransformer(gast.NodeTransformer):
""" """
A class transforms python assert to fluid.layers.Assert. A class transforms python assert to convert_assert.
""" """
def __init__(self, wrapper_root): def __init__(self, wrapper_root):
...@@ -32,21 +32,15 @@ class AssertTransformer(gast.NodeTransformer): ...@@ -32,21 +32,15 @@ class AssertTransformer(gast.NodeTransformer):
self.wrapper_root = wrapper_root self.wrapper_root = wrapper_root
self.root = wrapper_root.node self.root = wrapper_root.node
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
def transform(self): def transform(self):
self.visit(self.root) self.visit(self.root)
def visit_Assert(self, node): def visit_Assert(self, node):
if not self.static_analysis_visitor.is_tensor_node(node.test): convert_assert_node = gast.parse(
return node 'fluid.dygraph.dygraph_to_static.convert_operators.convert_assert({test}, {msg})'.
cast_node = gast.Call( format(
func=gast.parse("fluid.layers.cast").body[0].value, test=ast_to_source_code(node.test),
args=[node.test, gast.Constant( msg=ast_to_source_code(node.msg)
value="bool", kind=None)], if node.msg else "")).body[0].value
keywords=[])
assert_node = gast.Call( return gast.Expr(value=convert_assert_node)
func=gast.parse("fluid.layers.Assert").body[0].value,
args=[cast_node],
keywords=[])
return gast.Expr(value=assert_node)
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.framework import Variable, core from paddle.fluid.framework import Variable, core
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn from paddle.fluid.layers import Assert, cast, control_flow, logical_and, logical_not, logical_or, nn
def convert_while_loop(cond, body, loop_vars): def convert_while_loop(cond, body, loop_vars):
...@@ -259,3 +259,15 @@ def convert_var_dtype(var, dtype): ...@@ -259,3 +259,15 @@ def convert_var_dtype(var, dtype):
return cast(var, dtype=cast_map[dtype]) return cast(var, dtype=cast_map[dtype])
else: else:
return eval('{}(var)'.format(dtype)) return eval('{}(var)'.format(dtype))
def convert_assert(cond, message=""):
"""
A function representation of a Python ``assert`` statement.
"""
if isinstance(cond, Variable):
cond = cast(cond, "bool")
# NOTE: message is not used because Paddle Assert has no corresponding parameter to use.
return Assert(cond)
else:
assert cond, message
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册