diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/assert_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/assert_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..edf5b85047049b04be9c6d60eb518aa8137f3a80 --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/assert_transformer.py @@ -0,0 +1,53 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import gast + +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor + + +class AssertTransformer(gast.NodeTransformer): + """ + A class transforms python assert to fluid.layers.Assert. + """ + + def __init__(self, wrapper_root): + assert isinstance( + wrapper_root, AstNodeWrapper + ), "Input non-AstNodeWrapper node for the initialization of AssertTransformer." + self.wrapper_root = wrapper_root + self.root = wrapper_root.node + + self.static_analysis_visitor = StaticAnalysisVisitor(self.root) + + def transform(self): + self.visit(self.root) + + def visit_Assert(self, node): + if not self.static_analysis_visitor.is_tensor_node(node.test): + return node + cast_node = gast.Call( + func=gast.parse("fluid.layers.cast").body[0].value, + args=[node.test, gast.Constant( + value="bool", kind=None)], + keywords=[]) + assert_node = gast.Call( + func=gast.parse("fluid.layers.Assert").body[0].value, + args=[cast_node], + keywords=[]) + return gast.Expr(value=assert_node) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 50a6fe7a9cf026ee72372ffe2c4664146c9e5051..744fad485f037ae473655adf6c2df25ffb7c5325 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -22,14 +22,15 @@ import gast import inspect import textwrap +from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer +from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer -from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer -from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer +from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func @@ -80,6 +81,9 @@ class DygraphToStaticAst(gast.NodeTransformer): # Transform all if/else statement of Dygraph into Static Graph. IfElseTransformer(node_wrapper).transform() + # Transform python assert statement + AssertTransformer(node_wrapper).transform() + # Transform all python print statement PrintTransformer(node_wrapper).transform() diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py index 920af14e46939a4fc95703588a3e3507d6a98860..4b3b9fcf298855ae09856636e3e7af40ae8ae6da 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py @@ -256,6 +256,14 @@ class StaticAnalysisVisitor(object): def get_var_env(self): return self.var_env + def is_tensor_node(self, node): + tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES} + node_wrapper = self.node_to_wrapper_map.get(node, None) + if node_wrapper is None: + return False + if node_wrapper.node_var_type & tensor_types: + return True + def _get_constant_node_type(self, node): assert isinstance(node, gast.Constant), \ "Type of input node should be gast.Constant, but received %s" % type(node) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_assert.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_assert.py new file mode 100644 index 0000000000000000000000000000000000000000..68e6f328726f5b2664d31ac46394fa451631388c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_assert.py @@ -0,0 +1,71 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy +import unittest + +import paddle.fluid as fluid +from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator +from paddle.fluid.dygraph.jit import declarative + + +@declarative +def dyfunc_assert_variable(x): + x_v = fluid.dygraph.to_variable(x) + assert x_v + + +@declarative +def dyfunc_assert_non_variable(x=True): + assert x + + +class TestAssertVariable(unittest.TestCase): + def _run(self, func, x, with_exception, to_static): + ProgramTranslator().enable(to_static) + if with_exception: + with self.assertRaises(BaseException): + with fluid.dygraph.guard(): + func(x) + else: + with fluid.dygraph.guard(): + func(x) + + def _run_dy_static(self, func, x, with_exception): + self._run(func, x, with_exception, True) + self._run(func, x, with_exception, False) + + def test_non_variable(self): + self._run_dy_static( + dyfunc_assert_non_variable, x=False, with_exception=True) + self._run_dy_static( + dyfunc_assert_non_variable, x=True, with_exception=False) + + def test_bool_variable(self): + self._run_dy_static( + dyfunc_assert_variable, x=numpy.array([False]), with_exception=True) + self._run_dy_static( + dyfunc_assert_variable, x=numpy.array([True]), with_exception=False) + + def test_int_variable(self): + self._run_dy_static( + dyfunc_assert_variable, x=numpy.array([0]), with_exception=True) + self._run_dy_static( + dyfunc_assert_variable, x=numpy.array([1]), with_exception=False) + + +if __name__ == '__main__': + unittest.main()