Unverified commit 2989c012, authored by Zhen Wang, committed by GitHub

[DygraphToStatic]Add cast transform for dygraph_to_static. (#25325)

* add cast transform and its UT for dygraph_to_static. 
Parent bdad383c
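Before the diff itself, a minimal sketch of what this change enables (based on the new tests added below; it assumes a Paddle build containing this commit, and the helper name `to_int` is illustrative): a Python cast inside a `@declarative` function is rewritten so that it works on framework Variables.

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import declarative

@declarative
def to_int(x):
    x = fluid.dygraph.to_variable(x)
    # CastTransformer rewrites this into convert_var_dtype(x, 'int'),
    # which emits a cast op (Python 'int' maps to int32).
    return int(x)

with fluid.dygraph.guard():
    print(to_int(np.array([3.7], dtype='float32')).numpy())  # [3]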
@@ -21,9 +21,10 @@ from __future__ import print_function
 import gast
 from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
 from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
 from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
 from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer
 from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
+from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer
 from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
 from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
 from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTransformer
@@ -93,6 +94,9 @@ class DygraphToStaticAst(gast.NodeTransformer):
         # Transform call recursively
         CallTransformer(node_wrapper).transform()
 
+        # Transform python type casting statement
+        CastTransformer(node_wrapper).transform()
+
     def visit_FunctionDef(self, node):
         if self.decorate_func_name is None:
             self.decorate_func_name = node.name
New file: cast_transformer.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import gast

from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code


class CastTransformer(gast.NodeTransformer):
    """
    This class transforms Python type casting (`bool`, `int` and `float` calls)
    into convert_var_dtype calls that work in the static graph.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of CastTransformer."
        self._root = wrapper_root.node
        self._castable_type = {'bool', 'int', 'float'}

    def transform(self):
        self.visit(self._root)

    def visit_Call(self, node):
        self.generic_visit(node)
        func_str = ast_to_source_code(node.func).strip()
        # Rewrite bool(x)/int(x)/float(x) into a convert_var_dtype call, which
        # emits a cast op for Variables and falls back to the builtin otherwise.
        if func_str in self._castable_type and len(node.args) > 0:
            args_str = ast_to_source_code(node.args[0]).strip()
            new_func_str = "fluid.dygraph.dygraph_to_static.convert_operators.convert_var_dtype({}, '{}')".format(
                args_str, func_str)
            new_node = gast.parse(new_func_str).body[0].value
            return new_node
        return node
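To make the rewrite concrete, the following self-contained sketch reproduces the visit_Call logic with the standard-library ast module instead of gast, so it runs without Paddle (Python 3.9+ is needed for ast.unparse, and DemoCastRewriter is an illustrative name, not part of this commit).

import ast

class DemoCastRewriter(ast.NodeTransformer):
    """Mimics CastTransformer.visit_Call on plain ast nodes."""
    CASTABLE = {'bool', 'int', 'float'}

    def visit_Call(self, node):
        self.generic_visit(node)
        if isinstance(node.func, ast.Name) and node.func.id in self.CASTABLE and node.args:
            arg_src = ast.unparse(node.args[0])
            new_src = ("fluid.dygraph.dygraph_to_static.convert_operators."
                       "convert_var_dtype({}, '{}')".format(arg_src, node.func.id))
            # Parse the replacement expression and splice it in place of the Call node.
            return ast.parse(new_src, mode='eval').body
        return node

tree = ast.parse("y = float(x)")
tree = ast.fix_missing_locations(DemoCastRewriter().visit(tree))
print(ast.unparse(tree))
# y = fluid.dygraph.dygraph_to_static.convert_operators.convert_var_dtype(x, 'float')

The generated call targets convert_var_dtype, which this commit adds to convert_operators.py in the next hunk: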
@@ -238,3 +238,24 @@ def cast_bool_if_necessary(var):
     if convert_dtype(var.dtype) not in ['bool']:
         var = cast(var, dtype="bool")
     return var
+
+
+def convert_var_dtype(var, dtype):
+    if isinstance(var, Variable):
+        src_dtype = convert_dtype(var.dtype)
+        assert src_dtype in [
+            'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'
+        ], "The dtype of var {} is {}, which is not supported in the cast op.".format(
+            var.name, src_dtype)
+        assert dtype in [
+            'bool', 'int', 'float'
+        ], "The casted target dtype is {}, which is not supported in type casting.".format(
+            dtype)
+        cast_map = {
+            'bool': 'bool',
+            'int': 'int32',
+            'float': 'float32',
+        }
+        return cast(var, dtype=cast_map[dtype])
+    else:
+        return eval('{}(var)'.format(dtype))
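So convert_var_dtype has two paths: a framework Variable goes through the cast op with Python's bool/int/float mapped to bool/int32/float32, while any other value falls back to the Python builtin via eval. A hedged usage sketch (assuming the fluid dygraph API of this era):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.array([1.7, 0.0], dtype='float32'))
    y = convert_var_dtype(x, 'int')     # Variable path: cast op, 'int' -> int32
    print(y.numpy().dtype)              # int32
print(convert_var_dtype(3.14, 'int'))   # non-Variable path: builtin int() -> 3

The commit also adds a new unit-test file covering each cast, a mixed chain of casts, and the non-Variable fallback: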
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import declarative

SEED = 2020
np.random.seed(SEED)


@declarative
def test_bool_cast(x):
    x = fluid.dygraph.to_variable(x)
    x = bool(x)
    return x


@declarative
def test_int_cast(x):
    x = fluid.dygraph.to_variable(x)
    x = int(x)
    return x


@declarative
def test_float_cast(x):
    x = fluid.dygraph.to_variable(x)
    x = float(x)
    return x


@declarative
def test_not_var_cast(x):
    x = int(x)
    return x


@declarative
def test_mix_cast(x):
    x = fluid.dygraph.to_variable(x)
    x = int(x)
    x = float(x)
    x = bool(x)
    x = float(x)
    return x


class TestCastBase(unittest.TestCase):
    def setUp(self):
        self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
        ) else fluid.CPUPlace()
        self.prepare()
        self.set_func()

    def prepare(self):
        self.input_shape = (16, 32)
        self.input_dtype = 'float32'
        self.input = np.random.binomial(
            4, 0.3, size=np.product(self.input_shape)).reshape(
                self.input_shape).astype(self.input_dtype)
        self.cast_dtype = 'bool'

    def set_func(self):
        self.func = test_bool_cast

    def do_test(self):
        with fluid.dygraph.guard():
            res = self.func(self.input)
            return res

    def test_cast_result(self):
        res = self.do_test().numpy()
        self.assertTrue(
            res.dtype == self.cast_dtype,
            msg='The target dtype is {}, but the casted dtype is {}.'.format(
                self.cast_dtype, res.dtype))
        ref_val = self.input.astype(self.cast_dtype)
        self.assertTrue(
            np.allclose(res, ref_val),
            msg='The casted value is {}.\nThe correct value is {}.'.format(
                res, ref_val))


class TestIntCast(TestCastBase):
    def prepare(self):
        self.input_shape = (1, )
        self.input_dtype = 'float32'
        self.input = np.random.normal(
            loc=6, scale=10, size=np.product(self.input_shape)).reshape(
                self.input_shape).astype(self.input_dtype)
        self.cast_dtype = 'int32'

    def set_func(self):
        self.func = test_int_cast


class TestFloatCast(TestCastBase):
    def prepare(self):
        self.input_shape = (8, 16)
        self.input_dtype = 'bool'
        self.input = np.random.binomial(
            2, 0.5, size=np.product(self.input_shape)).reshape(
                self.input_shape).astype(self.input_dtype)
        self.cast_dtype = 'float32'

    def set_func(self):
        self.func = test_float_cast


class TestMixCast(TestCastBase):
    def prepare(self):
        self.input_shape = (8, 32)
        self.input_dtype = 'float32'
        self.input = np.random.normal(
            loc=6, scale=10, size=np.product(self.input_shape)).reshape(
                self.input_shape).astype(self.input_dtype)
        self.cast_int = 'int'
        self.cast_float = 'float32'
        self.cast_bool = 'bool'
        self.cast_dtype = 'float32'

    def set_func(self):
        self.func = test_mix_cast

    def test_cast_result(self):
        res = self.do_test().numpy()
        self.assertTrue(
            res.dtype == self.cast_dtype,
            msg='The target dtype is {}, but the casted dtype is {}.'.format(
                self.cast_dtype, res.dtype))
        ref_val = self.input.astype(self.cast_int).astype(
            self.cast_float).astype(self.cast_bool).astype(self.cast_dtype)
        self.assertTrue(
            np.allclose(res, ref_val),
            msg='The casted value is {}.\nThe correct value is {}.'.format(
                res, ref_val))


class TestNotVarCast(TestCastBase):
    def prepare(self):
        self.input = 3.14
        self.cast_dtype = 'int'

    def set_func(self):
        self.func = test_not_var_cast

    def test_cast_result(self):
        res = self.do_test()
        self.assertTrue(type(res) == int, msg='The casted dtype is not int.')
        ref_val = int(self.input)
        self.assertTrue(
            res == ref_val,
            msg='The casted value is {}.\nThe correct value is {}.'.format(
                res, ref_val))


if __name__ == '__main__':
    unittest.main()