未验证 提交 682cc17f 编写于 作者: L liym27 提交者: GitHub

[Dynamic-to-Static] Fix bug: support pop from a dict and polish code of convert_pop (#29023)

* Support pop for dict in dy2stat

* Move convert_pop to convert_operators.py and polish convert_pop
上级 8ca0a8a8
...@@ -16,7 +16,10 @@ from paddle.fluid.data_feeder import convert_dtype ...@@ -16,7 +16,10 @@ from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.framework import core, Variable from paddle.fluid.framework import core, Variable
from paddle.fluid.layers import Assert, Print from paddle.fluid.layers import Assert, Print
from paddle.fluid.layers import array_length, array_read, array_write, create_array
from paddle.fluid.layers import assign, fill_constant, slice
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
def convert_while_loop(cond, body, loop_vars): def convert_while_loop(cond, body, loop_vars):
...@@ -24,12 +27,12 @@ def convert_while_loop(cond, body, loop_vars): ...@@ -24,12 +27,12 @@ def convert_while_loop(cond, body, loop_vars):
A function representation of a Python ``while`` statement. A function representation of a Python ``while`` statement.
Args: Args:
cond(Callable): A callable object that returns a boolean variable to control whether to execute the loop body. It takes ``loop_vars`` as arguments. cond(Callable): A callable object that returns a boolean variable to control whether to execute the loop body. It takes ``loop_vars`` as arguments.
body(Callable): A callable object that returns a tuple or list of variables with the same arguments ``loops_vars`` as ``cond`` . body(Callable): A callable object that returns a tuple or list of variables with the same arguments ``loops_vars`` as ``cond`` .
loop_vars(list|tuple): A list or tuple of variables passed to ``cond`` and ``body`` . loop_vars(list|tuple): A list or tuple of variables passed to ``cond`` and ``body`` .
Returns: Returns:
A list or tuple of variables which returned by ``body`` . A list or tuple of variables which returned by ``body``.
""" """
# NOTE: It may be slower if cond is very expensive, but usually cond is just O(1). # NOTE: It may be slower if cond is very expensive, but usually cond is just O(1).
...@@ -320,3 +323,85 @@ def convert_print(*args): ...@@ -320,3 +323,85 @@ def convert_print(*args):
var = Print(var) var = Print(var)
else: else:
print(var) print(var)
def convert_pop(target, *args):
"""
A function representation of a Python pop statement for a list or dict.
Args:
target(list|dict|Tensor): A variable to pop item from.
*args(tuple): index or default value to parse.
Returns:
A item poped from target.
"""
is_variable = isinstance(target, Variable)
if is_variable:
is_tensor_array = target.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
if is_variable and is_tensor_array:
return _run_paddle_pop(target, *args)
else:
return _run_python_pop(target, *args)
def _run_paddle_pop(array, *args):
if len(args) == 0:
idx = -1
else:
idx = args[0]
assert isinstance(idx, int)
def cond(i, new_array):
return less_than(i, arr_len)
def body(i, new_array):
item = array_read(array=array, i=i)
array_write(item, array_length(new_array), new_array)
i = increment(i)
return i, new_array
arr_len = array_length(array)
if idx < 0:
idx = idx + arr_len
else:
idx = fill_constant(shape=[1], dtype="int64", value=idx)
pop_item = array_read(array, idx)
new_array = _slice_tensor_array(array, 0, idx)
i = idx + 1
_, new_array = while_loop(cond, body, [i, new_array])
assign(input=new_array, output=array)
return pop_item
# TODO(liym27): A better way to slice tensor array.
# Maybe support start == end for slice op.
def _slice_tensor_array(array, start, end):
def true_fn():
null_array = create_array("float32")
return null_array
def false_fn(array, start, end):
new_array = slice(array, starts=[start], ends=[end], axes=[0])
return new_array
new_array = cond(start == end, true_fn, lambda: false_fn(array, start, end))
return new_array
def _run_python_pop(target, *args):
# 1. pop for a dict
if len(args) == 2:
idx, default = args
return target.pop(idx, default)
# 2. pop for a list or dict
else:
idx = args[0] if args else -1
return target.pop(idx)
...@@ -17,74 +17,9 @@ from __future__ import print_function ...@@ -17,74 +17,9 @@ from __future__ import print_function
import astor import astor
import gast import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code, is_control_flow_to_transform from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code, is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.framework import core, Variable
from paddle.fluid.layers import array_length, array_read, array_write, create_array
from paddle.fluid.layers import assign, fill_constant, slice
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
# TODO(liym27): A better way to slice tensor array.
# Maybe support start == end for slice op.
def slice_tensor_array(array, start, end):
def true_fn():
null_array = create_array("float32")
return null_array
def false_fn(array, start, end):
new_array = slice(array, starts=[start], ends=[end], axes=[0])
return new_array
new_array = cond(start == end, true_fn, lambda: false_fn(array, start, end))
return new_array
def tensor_array_pop(array, idx):
assert isinstance(idx, int)
def cond(i, new_array):
return less_than(i, arr_len)
def body(i, new_array):
item = array_read(array=array, i=i)
array_write(item, array_length(new_array), new_array)
i = increment(i)
return i, new_array
arr_len = array_length(array)
if idx < 0:
idx = idx + arr_len
else:
idx = fill_constant(shape=[1], dtype="int64", value=idx)
pop_item = array_read(array, idx)
new_array = slice_tensor_array(array, 0, idx)
i = idx + 1
_, new_array = while_loop(cond, body, [i, new_array])
assign(input=new_array, output=array)
return pop_item
def convert_list_pop(target, idx=None):
"""
Convert list pop.
"""
if idx is None:
idx = -1
is_variable = isinstance(target, Variable)
if is_variable:
is_tensor_array = target.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
if is_variable and is_tensor_array:
result = tensor_array_pop(target, idx)
else:
result = target.pop(idx)
return result
class ListTransformer(gast.NodeTransformer): class ListTransformer(gast.NodeTransformer):
...@@ -117,7 +52,7 @@ class ListTransformer(gast.NodeTransformer): ...@@ -117,7 +52,7 @@ class ListTransformer(gast.NodeTransformer):
if isinstance(node.func, gast.Attribute): if isinstance(node.func, gast.Attribute):
func_name = node.func.attr func_name = node.func.attr
if func_name == "pop": if func_name == "pop":
node = self._replace_list_pop(node) node = self._replace_pop(node)
return node return node
def visit_Assign(self, node): def visit_Assign(self, node):
...@@ -283,20 +218,36 @@ class ListTransformer(gast.NodeTransformer): ...@@ -283,20 +218,36 @@ class ListTransformer(gast.NodeTransformer):
del self.list_name_to_updated[target_id] del self.list_name_to_updated[target_id]
return False return False
def _replace_list_pop(self, node): def _replace_pop(self, node):
"""
Replace a pop statement for a list or dict.
For example:
list_a = [0,1,2,3,4]
x = list_a.pop() # --> convert_pop(list_a)
y = list_a.pop(1) # --> convert_pop(list_a, 1)
dict_a = {"red":0, "blue":1, "yellow":2}
m = dict_a.pop("red") # --> convert_pop(dict_a, "red")
n = dict_a.pop("black", 3) # --> convert_pop(dict_a, "black", 3)
"""
assert isinstance(node, gast.Call) assert isinstance(node, gast.Call)
assert isinstance(node.func, gast.Attribute) assert isinstance(node.func, gast.Attribute)
target_node = node.func.value target_node = node.func.value
target_str = ast_to_source_code(target_node).strip() target_str = ast_to_source_code(target_node).strip()
if node.args: args_str = [ast_to_source_code(arg).strip() for arg in node.args]
idx_node = node.args[0]
idx_str = ast_to_source_code(idx_node).strip() # NOTE(liym27):
# 1. pop stmt for a list if len(args_str) == 0
# 2. pop stmt for a list or dict if len(args_str) == 1
# 3. pop stmt for a dict if len(args_str) == 2
if len(args_str) <= 2:
new_pop_str = "paddle.jit.dy2static.convert_pop({}, {})"\
.format(target_str, ",".join(args_str))
new_pop_node = gast.parse(new_pop_str).body[0].value
return new_pop_node
else: else:
idx_str = "None" return node
new_call_str = "fluid.dygraph.dygraph_to_static.list_transformer.convert_list_pop({}, {})".format(
target_str, idx_str)
new_call_node = gast.parse(new_call_str).body[0].value
return new_call_node
...@@ -18,6 +18,7 @@ import six ...@@ -18,6 +18,7 @@ import six
import numpy as np import numpy as np
import unittest import unittest
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.jit import to_static from paddle.jit import to_static
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
...@@ -139,5 +140,68 @@ class TestNetWithDict(unittest.TestCase): ...@@ -139,5 +140,68 @@ class TestNetWithDict(unittest.TestCase):
self.assertTrue((self._run_dygraph() == self._run_static()).all()) self.assertTrue((self._run_dygraph() == self._run_static()).all())
# Tests for dict pop
@paddle.jit.to_static
def test_dic_pop(x):
x = paddle.to_tensor(x)
dict_a = {"red": 0, "green": 1, "blue": 2}
m = dict_a.pop("red")
n = dict_a.pop("black", 3)
out = x + m + n
return out
@paddle.jit.to_static
def test_dic_pop_2(x):
x = paddle.to_tensor(x)
dict_a = {"red": x, "green": x + 1, "blue": x + 3}
m = dict_a.pop("red")
n = dict_a.pop("black", 3)
out = x + m + n
return out
class TestDictPop(unittest.TestCase):
def setUp(self):
self.input = np.random.random((3)).astype('int32')
self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
self._set_test_func()
def _set_test_func(self):
self.dygraph_func = test_dic_pop
def _run_static(self):
return self._run(to_static=True)
def _run_dygraph(self):
return self._run(to_static=False)
def _run(self, to_static):
prog_trans = ProgramTranslator()
prog_trans.enable(to_static)
result = self.dygraph_func(self.input)
return result.numpy()
def test_transformed_result(self):
dygraph_res = self._run_dygraph()
static_res = self._run_static()
self.assertTrue(
np.allclose(dygraph_res, static_res),
msg='dygraph result is {}\nstatic result is {}'.format(dygraph_res,
static_res))
class TestDictPop2(TestDictPop):
def _set_test_func(self):
self.dygraph_func = test_dic_pop_2
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -20,6 +20,7 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_len #D ...@@ -20,6 +20,7 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_len #D
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical_and #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical_and #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical_not #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical_not #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical_or #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical_or #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_pop #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape #DEFINE_ALIAS
...@@ -28,6 +29,6 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_l ...@@ -28,6 +29,6 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_l
__all__ = [ __all__ = [
'cast_bool_if_necessary', 'convert_assert', 'convert_ifelse', 'convert_len', 'cast_bool_if_necessary', 'convert_assert', 'convert_ifelse', 'convert_len',
'convert_logical_and', 'convert_logical_not', 'convert_logical_or', 'convert_logical_and', 'convert_logical_not', 'convert_logical_or',
'convert_print', 'convert_var_dtype', 'convert_var_shape', 'convert_pop', 'convert_print', 'convert_var_dtype', 'convert_var_shape',
'convert_while_loop' 'convert_while_loop'
] ]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册