未验证 提交 0f8dc611 编写于 作者: L liym27 提交者: GitHub

[Dy2Static] Convert assert stmt with new function `convert_assert`. (#25551)

上级 cf3c51a6
......@@ -17,12 +17,12 @@ from __future__ import print_function
import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
class AssertTransformer(gast.NodeTransformer):
"""
A class transforms python assert to fluid.layers.Assert.
A class transforms python assert to convert_assert.
"""
def __init__(self, wrapper_root):
......@@ -32,21 +32,15 @@ class AssertTransformer(gast.NodeTransformer):
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
def transform(self):
self.visit(self.root)
def visit_Assert(self, node):
if not self.static_analysis_visitor.is_tensor_node(node.test):
return node
cast_node = gast.Call(
func=gast.parse("fluid.layers.cast").body[0].value,
args=[node.test, gast.Constant(
value="bool", kind=None)],
keywords=[])
assert_node = gast.Call(
func=gast.parse("fluid.layers.Assert").body[0].value,
args=[cast_node],
keywords=[])
return gast.Expr(value=assert_node)
convert_assert_node = gast.parse(
'fluid.dygraph.dygraph_to_static.convert_operators.convert_assert({test}, {msg})'.
format(
test=ast_to_source_code(node.test),
msg=ast_to_source_code(node.msg)
if node.msg else "")).body[0].value
return gast.Expr(value=convert_assert_node)
......@@ -15,7 +15,7 @@
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.framework import Variable, core
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
from paddle.fluid.layers import Assert, cast, control_flow, logical_and, logical_not, logical_or, nn
def convert_while_loop(cond, body, loop_vars):
......@@ -259,3 +259,15 @@ def convert_var_dtype(var, dtype):
return cast(var, dtype=cast_map[dtype])
else:
return eval('{}(var)'.format(dtype))
def convert_assert(cond, message=""):
"""
A function representation of a Python ``assert`` statement.
"""
if isinstance(cond, Variable):
cond = cast(cond, "bool")
# NOTE: message is not used because Paddle Assert has no corresponding parameter to use.
return Assert(cond)
else:
assert cond, message
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册