From 1e5fec393856e9348393d0b2da39bdbd90234165 Mon Sep 17 00:00:00 2001 From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> Date: Fri, 11 Aug 2023 10:52:59 +0800 Subject: [PATCH] [Prim] Fix get var in prim when list of single tensor (#56114) * fix get var in prim * fix stack test case --- python/paddle/incubate/autograd/utils.py | 22 ++++++++++++++++------ test/legacy_test/test_stack_op.py | 15 +++++++++++++++ 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/python/paddle/incubate/autograd/utils.py b/python/paddle/incubate/autograd/utils.py index e79c27f30e1..0de52c68bb6 100644 --- a/python/paddle/incubate/autograd/utils.py +++ b/python/paddle/incubate/autograd/utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections import typing import paddle @@ -132,11 +133,13 @@ INT_DTYPE_2_STRING = { } -def get_var_block(block, names): +def get_var_block(block, names, is_tensor_list=None): assert isinstance(names, list) if len(names) == 0: return None elif len(names) == 1: + if is_tensor_list: + return [block.var(names[0])] return block.var(names[0]) else: return [block.var(name) for name in names] @@ -179,7 +182,7 @@ def _get_args_values(op, phi_name): "get attrs' values for api args' values" args = op_info[phi_name] args_list = args["args"].split(",") - inputs = [] + inputs = collections.OrderedDict() attrs = [] for item in args_list: @@ -212,9 +215,9 @@ def _get_args_values(op, phi_name): "inputs" in op_content.keys() and arg_name in op_content["inputs"].keys() ): - inputs.append(op_content["inputs"][arg_name]) + inputs[op_content["inputs"][arg_name]] = arg_type else: - inputs.append(arg_name) + inputs[arg_name] = arg_type else: attr_value = _get_attr_value(op, arg_type, arg_name) attrs.append(attr_value) @@ -237,9 +240,16 @@ def prepare_python_api_arguments(op): phi_name = op.type inputs, attrs = _get_args_values(op, phi_name) res = [] - for item in inputs: + for item, tensor_type in inputs.items(): if item in op.input_names: - res.append(get_var_block(op.block, op.input(item))) + if tensor_type == "Tensor[]": + res.append( + get_var_block( + op.block, op.input(item), is_tensor_list=True + ) + ) + else: + res.append(get_var_block(op.block, op.input(item))) else: # Note: in some cases, inputs may be optional, thus assign None. Such case must be recorded. res.append(None) diff --git a/test/legacy_test/test_stack_op.py b/test/legacy_test/test_stack_op.py index fea31835120..5c5e653dbae 100644 --- a/test/legacy_test/test_stack_op.py +++ b/test/legacy_test/test_stack_op.py @@ -375,5 +375,20 @@ class TestStackAPI_ZeroDim(unittest.TestCase): paddle.enable_static() +class TestStackListOfSingleTensor(unittest.TestCase): + def setUp(self): + paddle.disable_static() + paddle.seed(2022) + self.x = [paddle.randn((4, 2, 6), dtype="float32")] + + def test_list_single_tensor(self): + expect = paddle.stack(self.x) + paddle.fluid.core._set_prim_all_enabled(True) + st_model = paddle.jit.to_static(paddle.stack) + actual = st_model(self.x) + np.testing.assert_allclose(expect, actual) + paddle.enable_static() + + if __name__ == '__main__': unittest.main() -- GitLab