未验证 提交 c4c30e6f 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Remove fluid.layers in transformer code and return_out_vars in convert_ifelse (#43372)

* [Dy2Stat]Remove fluid.layers in transformer code

* remove useless return_name_vars

* fix unittest

* fix unittest

* fix unittest
上级 2d96801a
...@@ -188,7 +188,7 @@ def _run_py_logical_not(x): ...@@ -188,7 +188,7 @@ def _run_py_logical_not(x):
return not x return not x
def convert_ifelse(pred, true_fn, false_fn, true_args, false_args, return_vars): def convert_ifelse(pred, true_fn, false_fn, true_args, false_args):
""" """
A function representation of a Python ``if/else`` statement. A function representation of a Python ``if/else`` statement.
...@@ -198,15 +198,13 @@ def convert_ifelse(pred, true_fn, false_fn, true_args, false_args, return_vars): ...@@ -198,15 +198,13 @@ def convert_ifelse(pred, true_fn, false_fn, true_args, false_args, return_vars):
false_fn(callable): A callable to be performed if ``pred`` is false. false_fn(callable): A callable to be performed if ``pred`` is false.
true_args(tuple): Parameters of ``true_fn``. true_args(tuple): Parameters of ``true_fn``.
false_args(tuple): Parameters of ``false_fn``. false_args(tuple): Parameters of ``false_fn``.
return_vars(tuple): Return variables of ``true_fn`` and ``false_fn``.
Returns: Returns:
``true_fn(true_args)`` if the predicate ``pred`` is true else ``false_fn(false_args)`` . ``true_fn(true_args)`` if the predicate ``pred`` is true else ``false_fn(false_args)`` .
""" """
if isinstance(pred, Variable): if isinstance(pred, Variable):
out = _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args, out = _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args)
return_vars)
else: else:
out = _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args) out = _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args)
...@@ -246,8 +244,7 @@ def _remove_no_value_return_var(out): ...@@ -246,8 +244,7 @@ def _remove_no_value_return_var(out):
return out return out
def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args, def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args):
return_vars):
pred = cast_bool_if_necessary(pred) pred = cast_bool_if_necessary(pred)
return control_flow.cond(pred, lambda: true_fn(*true_args), return control_flow.cond(pred, lambda: true_fn(*true_args),
lambda: false_fn(*false_args)) lambda: false_fn(*false_args))
......
...@@ -507,7 +507,7 @@ def create_convert_ifelse_node(return_name_ids, ...@@ -507,7 +507,7 @@ def create_convert_ifelse_node(return_name_ids,
is_if_expr=False): is_if_expr=False):
""" """
Create `paddle.jit.dy2static.convert_ifelse( Create `paddle.jit.dy2static.convert_ifelse(
pred, true_fn, false_fn, true_args, false_args, return_vars)` pred, true_fn, false_fn, true_args, false_args)`
to replace original `python if/else` statement. to replace original `python if/else` statement.
""" """
...@@ -535,17 +535,14 @@ def create_convert_ifelse_node(return_name_ids, ...@@ -535,17 +535,14 @@ def create_convert_ifelse_node(return_name_ids,
true_func_source = true_func.name true_func_source = true_func.name
false_func_source = false_func.name false_func_source = false_func.name
return_vars = create_name_nodes(return_name_ids)
convert_ifelse_layer = gast.parse( convert_ifelse_layer = gast.parse(
'_jst.convert_ifelse(' '_jst.convert_ifelse('
'{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})' '{pred}, {true_fn}, {false_fn}, {true_args}, {false_args})'.format(
.format(pred=ast_to_source_code(pred), pred=ast_to_source_code(pred),
true_fn=true_func_source, true_fn=true_func_source,
false_fn=false_func_source, false_fn=false_func_source,
true_args=ast_to_source_code(true_args), true_args=ast_to_source_code(true_args),
false_args=ast_to_source_code(false_args), false_args=ast_to_source_code(false_args))).body[0].value
return_vars=ast_to_source_code(return_vars))).body[0].value
if return_name_ids: if return_name_ids:
_, cond_node = create_assign_node(return_name_ids, convert_ifelse_layer) _, cond_node = create_assign_node(return_name_ids, convert_ifelse_layer)
......
...@@ -349,14 +349,14 @@ def create_api_shape_node(tensor_shape_node): ...@@ -349,14 +349,14 @@ def create_api_shape_node(tensor_shape_node):
if isinstance(tensor_shape_node, gast.Name): if isinstance(tensor_shape_node, gast.Name):
api_shape_node = gast.Call( api_shape_node = gast.Call(
func=gast.parse('fluid.layers.shape').body[0].value, func=gast.parse('paddle.shape').body[0].value,
args=[tensor_shape_node], args=[tensor_shape_node],
keywords=[]) keywords=[])
return api_shape_node return api_shape_node
if isinstance(tensor_shape_node, gast.Attribute): if isinstance(tensor_shape_node, gast.Attribute):
api_shape_node = gast.Call( api_shape_node = gast.Call(
func=gast.parse('fluid.layers.shape').body[0].value, func=gast.parse('paddle.shape').body[0].value,
args=[tensor_shape_node.value], args=[tensor_shape_node.value],
keywords=[]) keywords=[])
return api_shape_node return api_shape_node
...@@ -368,8 +368,8 @@ def create_api_shape_node(tensor_shape_node): ...@@ -368,8 +368,8 @@ def create_api_shape_node(tensor_shape_node):
def get_constant_variable_node(name, value, shape=[1], dtype='int64'): def get_constant_variable_node(name, value, shape=[1], dtype='int64'):
return gast.parse('%s = fluid.layers.fill_constant(%s, "%s", %s)' % return gast.parse('%s = paddle.full(%s, "%s", %s)' %
(name, str(shape), dtype, str(value))) (name, str(shape), str(value), dtype))
def get_attribute_full_name(node): def get_attribute_full_name(node):
......
...@@ -15,12 +15,11 @@ ...@@ -15,12 +15,11 @@
from __future__ import print_function from __future__ import print_function
import six import six
import paddle
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.framework import Variable from paddle.fluid.framework import Variable
from paddle.fluid.layers import fill_constant
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
__all__ = [ __all__ = [
...@@ -87,17 +86,19 @@ def create_static_variable_gast_node(name): ...@@ -87,17 +86,19 @@ def create_static_variable_gast_node(name):
def create_fill_constant_node(name, value): def create_fill_constant_node(name, value):
func_code = "{} = paddle.fluid.layers.fill_constant(shape=[1], ".format( func_code = "{} = paddle.full(shape=[1], ".format(name)
name)
if isinstance(value, bool): if isinstance(value, bool):
func_code += "dtype='bool', value={}, name='{}')".format(value, name) func_code += "dtype='bool', fill_value={}, name='{}')".format(
value, name)
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
if isinstance(value, float): if isinstance(value, float):
func_code += "dtype='float64', value={}, name='{}')".format(value, name) func_code += "dtype='float64', fill_value={}, name='{}')".format(
value, name)
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
if isinstance(value, int): if isinstance(value, int):
func_code += "dtype='int64', value={}, name='{}')".format(value, name) func_code += "dtype='int64', fill_value={}, name='{}')".format(
value, name)
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
...@@ -106,12 +107,12 @@ def to_static_variable(x): ...@@ -106,12 +107,12 @@ def to_static_variable(x):
Translate a Python Tensor to PaddlePaddle static graph Tensor Translate a Python Tensor to PaddlePaddle static graph Tensor
''' '''
if isinstance(x, bool): if isinstance(x, bool):
return fill_constant(shape=[1], dtype='bool', value=x) return paddle.full(shape=[1], dtype='bool', fill_value=x)
if isinstance(x, float): if isinstance(x, float):
return fill_constant(shape=[1], dtype='float64', value=x) return paddle.full(shape=[1], dtype='float64', fill_value=x)
if isinstance(x, six.integer_types): if isinstance(x, six.integer_types):
return fill_constant(shape=[1], dtype='int64', value=x) return paddle.full(shape=[1], dtype='int64', fill_value=x)
return x return x
...@@ -121,6 +122,6 @@ def create_bool_as_type(x, value=True): ...@@ -121,6 +122,6 @@ def create_bool_as_type(x, value=True):
Create a bool variable, which type is the same as x. Create a bool variable, which type is the same as x.
''' '''
if isinstance(x, Variable): if isinstance(x, Variable):
return fill_constant(shape=[1], value=value, dtype="bool") return paddle.full(shape=[1], fill_value=value, dtype="bool")
else: else:
return value return value
...@@ -66,8 +66,10 @@ def get_source_code(func): ...@@ -66,8 +66,10 @@ def get_source_code(func):
class StaticCode1(): class StaticCode1():
def dyfunc_with_if_else(x_v, label=None): def dyfunc_with_if_else(x_v, label=None):
__return_value_init_0 = paddle.fluid.layers.fill_constant( __return_value_init_0 = paddle.full(shape=[1],
shape=[1], dtype='float64', value=0.0, name='__return_value_init_0') dtype='float64',
fill_value=0.0,
name='__return_value_init_0')
__return_value_0 = __return_value_init_0 __return_value_0 = __return_value_init_0
def true_fn_0(x_v): def true_fn_0(x_v):
...@@ -80,7 +82,7 @@ class StaticCode1(): ...@@ -80,7 +82,7 @@ class StaticCode1():
x_v = _jst.convert_ifelse( x_v = _jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ), fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ),
(x_v, ), (x_v, )) (x_v, ))
__return_0 = _jst.create_bool_as_type(label is not None, False) __return_0 = _jst.create_bool_as_type(label is not None, False)
def true_fn_1(__return_0, __return_value_0, label, x_v): def true_fn_1(__return_0, __return_value_0, label, x_v):
...@@ -95,7 +97,7 @@ class StaticCode1(): ...@@ -95,7 +97,7 @@ class StaticCode1():
__return_0, __return_value_0 = _jst.convert_ifelse( __return_0, __return_value_0 = _jst.convert_ifelse(
label is not None, true_fn_1, false_fn_1, label is not None, true_fn_1, false_fn_1,
(__return_0, __return_value_0, label, x_v), (__return_0, __return_value_0, label, x_v),
(__return_0, __return_value_0), (__return_0, __return_value_0)) (__return_0, __return_value_0))
def true_fn_2(__return_0, __return_value_0, x_v): def true_fn_2(__return_0, __return_value_0, x_v):
__return_1 = _jst.create_bool_as_type( __return_1 = _jst.create_bool_as_type(
...@@ -108,16 +110,17 @@ class StaticCode1(): ...@@ -108,16 +110,17 @@ class StaticCode1():
__return_value_0 = _jst.convert_ifelse( __return_value_0 = _jst.convert_ifelse(
_jst.convert_logical_not(__return_0), true_fn_2, false_fn_2, _jst.convert_logical_not(__return_0), true_fn_2, false_fn_2,
(__return_0, __return_value_0, x_v), (__return_value_0, ), (__return_0, __return_value_0, x_v), (__return_value_0, ))
(__return_value_0, ))
return __return_value_0 return __return_value_0
class StaticCode2(): class StaticCode2():
# TODO: Transform return statement # TODO: Transform return statement
def dyfunc_with_if_else(x_v, label=None): def dyfunc_with_if_else(x_v, label=None):
__return_value_init_1 = paddle.fluid.layers.fill_constant( __return_value_init_1 = paddle.full(shape=[1],
shape=[1], dtype='float64', value=0.0, name='__return_value_init_1') dtype='float64',
fill_value=0.0,
name='__return_value_init_1')
__return_value_1 = __return_value_init_1 __return_value_1 = __return_value_init_1
def true_fn_3(x_v): def true_fn_3(x_v):
...@@ -130,7 +133,7 @@ class StaticCode2(): ...@@ -130,7 +133,7 @@ class StaticCode2():
x_v = _jst.convert_ifelse( x_v = _jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ), fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ),
(x_v, ), (x_v, )) (x_v, ))
__return_2 = _jst.create_bool_as_type(label is not None, False) __return_2 = _jst.create_bool_as_type(label is not None, False)
def true_fn_4(__return_2, __return_value_1, label, x_v): def true_fn_4(__return_2, __return_value_1, label, x_v):
...@@ -145,7 +148,7 @@ class StaticCode2(): ...@@ -145,7 +148,7 @@ class StaticCode2():
__return_2, __return_value_1 = _jst.convert_ifelse( __return_2, __return_value_1 = _jst.convert_ifelse(
label is not None, true_fn_4, false_fn_4, label is not None, true_fn_4, false_fn_4,
(__return_2, __return_value_1, label, x_v), (__return_2, __return_value_1, label, x_v),
(__return_2, __return_value_1), (__return_2, __return_value_1)) (__return_2, __return_value_1))
def true_fn_5(__return_2, __return_value_1, x_v): def true_fn_5(__return_2, __return_value_1, x_v):
__return_3 = _jst.create_bool_as_type( __return_3 = _jst.create_bool_as_type(
...@@ -158,8 +161,7 @@ class StaticCode2(): ...@@ -158,8 +161,7 @@ class StaticCode2():
__return_value_1 = _jst.convert_ifelse( __return_value_1 = _jst.convert_ifelse(
_jst.convert_logical_not(__return_2), true_fn_5, false_fn_5, _jst.convert_logical_not(__return_2), true_fn_5, false_fn_5,
(__return_2, __return_value_1, x_v), (__return_value_1, ), (__return_2, __return_value_1, x_v), (__return_value_1, ))
(__return_value_1, ))
return __return_value_1 return __return_value_1
......
...@@ -52,19 +52,19 @@ class TestVariableTransFunc(unittest.TestCase): ...@@ -52,19 +52,19 @@ class TestVariableTransFunc(unittest.TestCase):
def test_create_fill_constant_node(self): def test_create_fill_constant_node(self):
node = create_fill_constant_node("a", 1.0) node = create_fill_constant_node("a", 1.0)
source = "a = paddle.fluid.layers.fill_constant(shape=[1], dtype='float64', value=1.0, name='a')" source = "a = paddle.full(shape=[1], dtype='float64', fill_value=1.0, name='a')"
self.assertEqual( self.assertEqual(
ast_to_source_code(node).replace('\n', '').replace(' ', ''), ast_to_source_code(node).replace('\n', '').replace(' ', ''),
source.replace(' ', '')) source.replace(' ', ''))
node = create_fill_constant_node("b", True) node = create_fill_constant_node("b", True)
source = "b = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=True, name='b')" source = "b = paddle.full(shape=[1], dtype='bool', fill_value=True, name='b')"
self.assertEqual( self.assertEqual(
ast_to_source_code(node).replace('\n', '').replace(' ', ''), ast_to_source_code(node).replace('\n', '').replace(' ', ''),
source.replace(' ', '')) source.replace(' ', ''))
node = create_fill_constant_node("c", 4293) node = create_fill_constant_node("c", 4293)
source = "c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293, name='c')" source = "c = paddle.full(shape=[1], dtype='int64', fill_value=4293, name='c')"
self.assertEqual( self.assertEqual(
ast_to_source_code(node).replace('\n', '').replace(' ', ''), ast_to_source_code(node).replace('\n', '').replace(' ', ''),
source.replace(' ', '')) source.replace(' ', ''))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册