From 2989c012f27b6a5e14e6ec40ea352ce525de6905 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Sun, 5 Jul 2020 17:33:15 +0800 Subject: [PATCH] [DygraphToStatic]Add cast transform for dygraph_to_static. (#25325) * add cast transform and its UT for dygraph_to_static. --- .../dygraph_to_static/ast_transformer.py | 6 +- .../dygraph_to_static/cast_transformer.py | 47 +++++ .../dygraph_to_static/convert_operators.py | 21 +++ .../unittests/dygraph_to_static/test_cast.py | 173 ++++++++++++++++++ 4 files changed, 246 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/cast_transformer.py create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_cast.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 9d0eec7f779..f859d40050c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -21,9 +21,10 @@ from __future__ import print_function import gast from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer -from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer +from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer +from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTransformer @@ -93,6 +94,9 @@ class DygraphToStaticAst(gast.NodeTransformer): # Transform call recursively CallTransformer(node_wrapper).transform() + # Transform python type casting statement + CastTransformer(node_wrapper).transform() + def visit_FunctionDef(self, node): if self.decorate_func_name is None: self.decorate_func_name = node.name diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/cast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/cast_transformer.py new file mode 100644 index 00000000000..71cb999eab0 --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/cast_transformer.py @@ -0,0 +1,47 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import gast + +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code + + +class CastTransformer(gast.NodeTransformer): + """ + This class transforms type casting into Static Graph Ast. + """ + + def __init__(self, wrapper_root): + assert isinstance( + wrapper_root, AstNodeWrapper + ), "Input non-AstNodeWrapper node for the initialization of CastTransformer." + self._root = wrapper_root.node + self._castable_type = {'bool', 'int', 'float'} + + def transform(self): + self.visit(self._root) + + def visit_Call(self, node): + self.generic_visit(node) + func_str = ast_to_source_code(node.func).strip() + if func_str in self._castable_type and len(node.args) > 0: + args_str = ast_to_source_code(node.args[0]).strip() + new_func_str = "fluid.dygraph.dygraph_to_static.convert_operators.convert_var_dtype({}, '{}')".format( + args_str, func_str) + new_node = gast.parse(new_func_str).body[0].value + return new_node + + return node diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index c05173a28e2..78031a5b388 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -238,3 +238,24 @@ def cast_bool_if_necessary(var): if convert_dtype(var.dtype) not in ['bool']: var = cast(var, dtype="bool") return var + + +def convert_var_dtype(var, dtype): + if isinstance(var, Variable): + src_dtype = convert_dtype(var.dtype) + assert src_dtype in [ + 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8' + ], "The dtype of var {} is {}, which is not supported in the cast op.".format( + var.name, src_dtype) + assert dtype in [ + 'bool', 'int', 'float' + ], "The casted target dtype is {}, which is not supported in type casting.".format( + dtype) + cast_map = { + 'bool': 'bool', + 'int': 'int32', + 'float': 'float32', + } + return cast(var, dtype=cast_map[dtype]) + else: + return eval('{}(var)'.format(dtype)) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cast.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cast.py new file mode 100644 index 00000000000..b4cc38b3a60 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cast.py @@ -0,0 +1,173 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.dygraph import declarative + +SEED = 2020 +np.random.seed(SEED) + + +@declarative +def test_bool_cast(x): + x = fluid.dygraph.to_variable(x) + x = bool(x) + return x + + +@declarative +def test_int_cast(x): + x = fluid.dygraph.to_variable(x) + x = int(x) + return x + + +@declarative +def test_float_cast(x): + x = fluid.dygraph.to_variable(x) + x = float(x) + return x + + +@declarative +def test_not_var_cast(x): + x = int(x) + return x + + +@declarative +def test_mix_cast(x): + x = fluid.dygraph.to_variable(x) + x = int(x) + x = float(x) + x = bool(x) + x = float(x) + return x + + +class TestCastBase(unittest.TestCase): + def setUp(self): + self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( + ) else fluid.CPUPlace() + self.prepare() + self.set_func() + + def prepare(self): + self.input_shape = (16, 32) + self.input_dtype = 'float32' + self.input = np.random.binomial( + 4, 0.3, size=np.product(self.input_shape)).reshape( + self.input_shape).astype(self.input_dtype) + self.cast_dtype = 'bool' + + def set_func(self): + self.func = test_bool_cast + + def do_test(self): + with fluid.dygraph.guard(): + res = self.func(self.input) + return res + + def test_cast_result(self): + res = self.do_test().numpy() + self.assertTrue( + res.dtype == self.cast_dtype, + msg='The target dtype is {}, but the casted dtype is {}.'.format( + self.cast_dtype, res.dtype)) + ref_val = self.input.astype(self.cast_dtype) + self.assertTrue( + np.allclose(res, ref_val), + msg='The casted value is {}.\nThe correct value is {}.'.format( + res, ref_val)) + + +class TestIntCast(TestCastBase): + def prepare(self): + self.input_shape = (1, ) + self.input_dtype = 'float32' + self.input = np.random.normal( + loc=6, scale=10, size=np.product(self.input_shape)).reshape( + self.input_shape).astype(self.input_dtype) + self.cast_dtype = 'int32' + + def set_func(self): + self.func = test_int_cast + + +class TestFloatCast(TestCastBase): + def prepare(self): + self.input_shape = (8, 16) + self.input_dtype = 'bool' + self.input = np.random.binomial( + 2, 0.5, size=np.product(self.input_shape)).reshape( + self.input_shape).astype(self.input_dtype) + self.cast_dtype = 'float32' + + def set_func(self): + self.func = test_float_cast + + +class TestMixCast(TestCastBase): + def prepare(self): + self.input_shape = (8, 32) + self.input_dtype = 'float32' + self.input = np.random.normal( + loc=6, scale=10, size=np.product(self.input_shape)).reshape( + self.input_shape).astype(self.input_dtype) + self.cast_int = 'int' + self.cast_float = 'float32' + self.cast_bool = 'bool' + self.cast_dtype = 'float32' + + def set_func(self): + self.func = test_mix_cast + + def test_cast_result(self): + res = self.do_test().numpy() + self.assertTrue( + res.dtype == self.cast_dtype, + msg='The target dtype is {}, but the casted dtype is {}.'.format( + self.cast_dtype, res.dtype)) + ref_val = self.input.astype(self.cast_int).astype( + self.cast_float).astype(self.cast_bool).astype(self.cast_dtype) + self.assertTrue( + np.allclose(res, ref_val), + msg='The casted value is {}.\nThe correct value is {}.'.format( + res, ref_val)) + + +class TestNotVarCast(TestCastBase): + def prepare(self): + self.input = 3.14 + self.cast_dtype = 'int' + + def set_func(self): + self.func = test_not_var_cast + + def test_cast_result(self): + res = self.do_test() + self.assertTrue(type(res) == int, msg='The casted dtype is not int.') + ref_val = int(self.input) + self.assertTrue( + res == ref_val, + msg='The casted value is {}.\nThe correct value is {}.'.format( + res, ref_val)) + + +if __name__ == '__main__': + unittest.main() -- GitLab