From 0f8dc611c87309defdf0cfe1a8e6efc17edb6376 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Thu, 16 Jul 2020 10:21:17 +0800 Subject: [PATCH] [Dy2Static] Convert assert stmt with new function `convert_assert`. (#25551) --- .../dygraph_to_static/assert_transformer.py | 26 +++++++------------ .../dygraph_to_static/convert_operators.py | 14 +++++++++- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/assert_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/assert_transformer.py index 61ff82f5be..73dba66d3f 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/assert_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/assert_transformer.py @@ -17,12 +17,12 @@ from __future__ import print_function import gast from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper -from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code class AssertTransformer(gast.NodeTransformer): """ - A class transforms python assert to fluid.layers.Assert. + A class transforms python assert to convert_assert. """ def __init__(self, wrapper_root): @@ -32,21 +32,15 @@ class AssertTransformer(gast.NodeTransformer): self.wrapper_root = wrapper_root self.root = wrapper_root.node - self.static_analysis_visitor = StaticAnalysisVisitor(self.root) - def transform(self): self.visit(self.root) def visit_Assert(self, node): - if not self.static_analysis_visitor.is_tensor_node(node.test): - return node - cast_node = gast.Call( - func=gast.parse("fluid.layers.cast").body[0].value, - args=[node.test, gast.Constant( - value="bool", kind=None)], - keywords=[]) - assert_node = gast.Call( - func=gast.parse("fluid.layers.Assert").body[0].value, - args=[cast_node], - keywords=[]) - return gast.Expr(value=assert_node) + convert_assert_node = gast.parse( + 'fluid.dygraph.dygraph_to_static.convert_operators.convert_assert({test}, {msg})'. + format( + test=ast_to_source_code(node.test), + msg=ast_to_source_code(node.msg) + if node.msg else "")).body[0].value + + return gast.Expr(value=convert_assert_node) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index 78031a5b38..1291be60c6 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -15,7 +15,7 @@ from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable from paddle.fluid.framework import Variable, core -from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn +from paddle.fluid.layers import Assert, cast, control_flow, logical_and, logical_not, logical_or, nn def convert_while_loop(cond, body, loop_vars): @@ -259,3 +259,15 @@ def convert_var_dtype(var, dtype): return cast(var, dtype=cast_map[dtype]) else: return eval('{}(var)'.format(dtype)) + + +def convert_assert(cond, message=""): + """ + A function representation of a Python ``assert`` statement. + """ + if isinstance(cond, Variable): + cond = cast(cond, "bool") + # NOTE: message is not used because Paddle Assert has no corresponding parameter to use. + return Assert(cond) + else: + assert cond, message -- GitLab