From 682cc17f53687f456e94e86e18ed79c6340b7a61 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Wed, 25 Nov 2020 17:33:42 +0800 Subject: [PATCH] [Dynamic-to-Static] Fix bug: support pop from a dict and polish code of convert_pop (#29023) * Support pop for dict in dy2stat * Move convert_pop to convert_operators.py and polish convert_pop --- .../dygraph_to_static/convert_operators.py | 89 ++++++++++++++- .../dygraph_to_static/list_transformer.py | 105 +++++------------- .../unittests/dygraph_to_static/test_dict.py | 64 +++++++++++ .../paddle/jit/dy2static/convert_operators.py | 3 +- 4 files changed, 181 insertions(+), 80 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index ea03d6143ad..dcb8b686eef 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -16,7 +16,10 @@ from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable from paddle.fluid.framework import core, Variable from paddle.fluid.layers import Assert, Print +from paddle.fluid.layers import array_length, array_read, array_write, create_array +from paddle.fluid.layers import assign, fill_constant, slice from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn +from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment def convert_while_loop(cond, body, loop_vars): @@ -24,12 +27,12 @@ def convert_while_loop(cond, body, loop_vars): A function representation of a Python ``while`` statement. Args: - cond(Callable): A callable object that returns a boolean variable to control whether to execute the loop body. It takes ``loop_vars`` as arguments. + cond(Callable): A callable object that returns a boolean variable to control whether to execute the loop body. It takes ``loop_vars`` as arguments. body(Callable): A callable object that returns a tuple or list of variables with the same arguments ``loops_vars`` as ``cond`` . loop_vars(list|tuple): A list or tuple of variables passed to ``cond`` and ``body`` . Returns: - A list or tuple of variables which returned by ``body`` . + A list or tuple of variables which returned by ``body``. """ # NOTE: It may be slower if cond is very expensive, but usually cond is just O(1). @@ -320,3 +323,85 @@ def convert_print(*args): var = Print(var) else: print(var) + + +def convert_pop(target, *args): + """ + A function representation of a Python pop statement for a list or dict. + + Args: + target(list|dict|Tensor): A variable to pop item from. + *args(tuple): index or default value to parse. + + Returns: + A item poped from target. + """ + + is_variable = isinstance(target, Variable) + if is_variable: + is_tensor_array = target.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY + + if is_variable and is_tensor_array: + return _run_paddle_pop(target, *args) + else: + return _run_python_pop(target, *args) + + +def _run_paddle_pop(array, *args): + if len(args) == 0: + idx = -1 + else: + idx = args[0] + + assert isinstance(idx, int) + + def cond(i, new_array): + return less_than(i, arr_len) + + def body(i, new_array): + item = array_read(array=array, i=i) + array_write(item, array_length(new_array), new_array) + i = increment(i) + return i, new_array + + arr_len = array_length(array) + if idx < 0: + idx = idx + arr_len + else: + idx = fill_constant(shape=[1], dtype="int64", value=idx) + + pop_item = array_read(array, idx) + + new_array = _slice_tensor_array(array, 0, idx) + i = idx + 1 + _, new_array = while_loop(cond, body, [i, new_array]) + assign(input=new_array, output=array) + + return pop_item + + +# TODO(liym27): A better way to slice tensor array. +# Maybe support start == end for slice op. +def _slice_tensor_array(array, start, end): + def true_fn(): + null_array = create_array("float32") + return null_array + + def false_fn(array, start, end): + new_array = slice(array, starts=[start], ends=[end], axes=[0]) + return new_array + + new_array = cond(start == end, true_fn, lambda: false_fn(array, start, end)) + return new_array + + +def _run_python_pop(target, *args): + # 1. pop for a dict + if len(args) == 2: + idx, default = args + return target.pop(idx, default) + + # 2. pop for a list or dict + else: + idx = args[0] if args else -1 + return target.pop(idx) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py index 9819f5fb72b..51d06a60fdf 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py @@ -17,74 +17,9 @@ from __future__ import print_function import astor import gast -from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code, is_control_flow_to_transform from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer -from paddle.fluid.framework import core, Variable -from paddle.fluid.layers import array_length, array_read, array_write, create_array -from paddle.fluid.layers import assign, fill_constant, slice -from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment - - -# TODO(liym27): A better way to slice tensor array. -# Maybe support start == end for slice op. -def slice_tensor_array(array, start, end): - def true_fn(): - null_array = create_array("float32") - return null_array - - def false_fn(array, start, end): - new_array = slice(array, starts=[start], ends=[end], axes=[0]) - return new_array - - new_array = cond(start == end, true_fn, lambda: false_fn(array, start, end)) - return new_array - - -def tensor_array_pop(array, idx): - assert isinstance(idx, int) - - def cond(i, new_array): - return less_than(i, arr_len) - - def body(i, new_array): - item = array_read(array=array, i=i) - array_write(item, array_length(new_array), new_array) - i = increment(i) - return i, new_array - - arr_len = array_length(array) - if idx < 0: - idx = idx + arr_len - else: - idx = fill_constant(shape=[1], dtype="int64", value=idx) - - pop_item = array_read(array, idx) - - new_array = slice_tensor_array(array, 0, idx) - i = idx + 1 - _, new_array = while_loop(cond, body, [i, new_array]) - assign(input=new_array, output=array) - - return pop_item - - -def convert_list_pop(target, idx=None): - """ - Convert list pop. - """ - - if idx is None: - idx = -1 - - is_variable = isinstance(target, Variable) - if is_variable: - is_tensor_array = target.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY - if is_variable and is_tensor_array: - result = tensor_array_pop(target, idx) - else: - result = target.pop(idx) - return result class ListTransformer(gast.NodeTransformer): @@ -117,7 +52,7 @@ class ListTransformer(gast.NodeTransformer): if isinstance(node.func, gast.Attribute): func_name = node.func.attr if func_name == "pop": - node = self._replace_list_pop(node) + node = self._replace_pop(node) return node def visit_Assign(self, node): @@ -283,20 +218,36 @@ class ListTransformer(gast.NodeTransformer): del self.list_name_to_updated[target_id] return False - def _replace_list_pop(self, node): + def _replace_pop(self, node): + """ + Replace a pop statement for a list or dict. + For example: + + list_a = [0,1,2,3,4] + x = list_a.pop() # --> convert_pop(list_a) + y = list_a.pop(1) # --> convert_pop(list_a, 1) + + dict_a = {"red":0, "blue":1, "yellow":2} + m = dict_a.pop("red") # --> convert_pop(dict_a, "red") + n = dict_a.pop("black", 3) # --> convert_pop(dict_a, "black", 3) + + """ assert isinstance(node, gast.Call) assert isinstance(node.func, gast.Attribute) target_node = node.func.value target_str = ast_to_source_code(target_node).strip() - if node.args: - idx_node = node.args[0] - idx_str = ast_to_source_code(idx_node).strip() + args_str = [ast_to_source_code(arg).strip() for arg in node.args] + + # NOTE(liym27): + # 1. pop stmt for a list if len(args_str) == 0 + # 2. pop stmt for a list or dict if len(args_str) == 1 + # 3. pop stmt for a dict if len(args_str) == 2 + if len(args_str) <= 2: + new_pop_str = "paddle.jit.dy2static.convert_pop({}, {})"\ + .format(target_str, ",".join(args_str)) + new_pop_node = gast.parse(new_pop_str).body[0].value + return new_pop_node else: - idx_str = "None" - - new_call_str = "fluid.dygraph.dygraph_to_static.list_transformer.convert_list_pop({}, {})".format( - target_str, idx_str) - new_call_node = gast.parse(new_call_str).body[0].value - return new_call_node + return node diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py index af1e44ffe21..4af955e774a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py @@ -18,6 +18,7 @@ import six import numpy as np import unittest +import paddle import paddle.fluid as fluid from paddle.jit import to_static from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator @@ -139,5 +140,68 @@ class TestNetWithDict(unittest.TestCase): self.assertTrue((self._run_dygraph() == self._run_static()).all()) +# Tests for dict pop +@paddle.jit.to_static +def test_dic_pop(x): + x = paddle.to_tensor(x) + dict_a = {"red": 0, "green": 1, "blue": 2} + + m = dict_a.pop("red") + n = dict_a.pop("black", 3) + + out = x + m + n + return out + + +@paddle.jit.to_static +def test_dic_pop_2(x): + x = paddle.to_tensor(x) + dict_a = {"red": x, "green": x + 1, "blue": x + 3} + + m = dict_a.pop("red") + n = dict_a.pop("black", 3) + + out = x + m + n + return out + + +class TestDictPop(unittest.TestCase): + def setUp(self): + self.input = np.random.random((3)).astype('int32') + self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda( + ) else paddle.CPUPlace() + self._set_test_func() + + def _set_test_func(self): + self.dygraph_func = test_dic_pop + + def _run_static(self): + return self._run(to_static=True) + + def _run_dygraph(self): + return self._run(to_static=False) + + def _run(self, to_static): + prog_trans = ProgramTranslator() + prog_trans.enable(to_static) + + result = self.dygraph_func(self.input) + + return result.numpy() + + def test_transformed_result(self): + dygraph_res = self._run_dygraph() + static_res = self._run_static() + self.assertTrue( + np.allclose(dygraph_res, static_res), + msg='dygraph result is {}\nstatic result is {}'.format(dygraph_res, + static_res)) + + +class TestDictPop2(TestDictPop): + def _set_test_func(self): + self.dygraph_func = test_dic_pop_2 + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py index 89df1d0aa77..443c7234454 100644 --- a/python/paddle/jit/dy2static/convert_operators.py +++ b/python/paddle/jit/dy2static/convert_operators.py @@ -20,6 +20,7 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_len #D from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical_and #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical_not #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical_or #DEFINE_ALIAS +from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_pop #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape #DEFINE_ALIAS @@ -28,6 +29,6 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_l __all__ = [ 'cast_bool_if_necessary', 'convert_assert', 'convert_ifelse', 'convert_len', 'convert_logical_and', 'convert_logical_not', 'convert_logical_or', - 'convert_print', 'convert_var_dtype', 'convert_var_shape', + 'convert_pop', 'convert_print', 'convert_var_dtype', 'convert_var_shape', 'convert_while_loop' ] -- GitLab