未验证 提交 2d0f849e 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2Stat] Add assert for ProgramTranslator (#24492)

Add assert grammar for ProgramTranslator
上级 53e3c534
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
class AssertTransformer(gast.NodeTransformer):
"""
A class transforms python assert to fluid.layers.Assert.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of AssertTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
def transform(self):
self.visit(self.root)
def visit_Assert(self, node):
if not self.static_analysis_visitor.is_tensor_node(node.test):
return node
cast_node = gast.Call(
func=gast.parse("fluid.layers.cast").body[0].value,
args=[node.test, gast.Constant(
value="bool", kind=None)],
keywords=[])
assert_node = gast.Call(
func=gast.parse("fluid.layers.Assert").body[0].value,
args=[cast_node],
keywords=[])
return gast.Expr(value=assert_node)
......@@ -22,14 +22,15 @@ import gast
import inspect
import textwrap
from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
......@@ -80,6 +81,9 @@ class DygraphToStaticAst(gast.NodeTransformer):
# Transform all if/else statement of Dygraph into Static Graph.
IfElseTransformer(node_wrapper).transform()
# Transform python assert statement
AssertTransformer(node_wrapper).transform()
# Transform all python print statement
PrintTransformer(node_wrapper).transform()
......
......@@ -256,6 +256,14 @@ class StaticAnalysisVisitor(object):
def get_var_env(self):
return self.var_env
def is_tensor_node(self, node):
tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
node_wrapper = self.node_to_wrapper_map.get(node, None)
if node_wrapper is None:
return False
if node_wrapper.node_var_type & tensor_types:
return True
def _get_constant_node_type(self, node):
assert isinstance(node, gast.Constant), \
"Type of input node should be gast.Constant, but received %s" % type(node)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy
import unittest
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import declarative
@declarative
def dyfunc_assert_variable(x):
x_v = fluid.dygraph.to_variable(x)
assert x_v
@declarative
def dyfunc_assert_non_variable(x=True):
assert x
class TestAssertVariable(unittest.TestCase):
def _run(self, func, x, with_exception, to_static):
ProgramTranslator().enable(to_static)
if with_exception:
with self.assertRaises(BaseException):
with fluid.dygraph.guard():
func(x)
else:
with fluid.dygraph.guard():
func(x)
def _run_dy_static(self, func, x, with_exception):
self._run(func, x, with_exception, True)
self._run(func, x, with_exception, False)
def test_non_variable(self):
self._run_dy_static(
dyfunc_assert_non_variable, x=False, with_exception=True)
self._run_dy_static(
dyfunc_assert_non_variable, x=True, with_exception=False)
def test_bool_variable(self):
self._run_dy_static(
dyfunc_assert_variable, x=numpy.array([False]), with_exception=True)
self._run_dy_static(
dyfunc_assert_variable, x=numpy.array([True]), with_exception=False)
def test_int_variable(self):
self._run_dy_static(
dyfunc_assert_variable, x=numpy.array([0]), with_exception=True)
self._run_dy_static(
dyfunc_assert_variable, x=numpy.array([1]), with_exception=False)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册