diff --git a/python/paddle/incubate/autograd/utils.py b/python/paddle/incubate/autograd/utils.py index e79c27f30e1d000aa5d1f69ed610fcfb005995d1..0de52c68bb61b2b42f055cc69ae9cb4a1c11b798 100644 --- a/python/paddle/incubate/autograd/utils.py +++ b/python/paddle/incubate/autograd/utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections import typing import paddle @@ -132,11 +133,13 @@ INT_DTYPE_2_STRING = { } -def get_var_block(block, names): +def get_var_block(block, names, is_tensor_list=None): assert isinstance(names, list) if len(names) == 0: return None elif len(names) == 1: + if is_tensor_list: + return [block.var(names[0])] return block.var(names[0]) else: return [block.var(name) for name in names] @@ -179,7 +182,7 @@ def _get_args_values(op, phi_name): "get attrs' values for api args' values" args = op_info[phi_name] args_list = args["args"].split(",") - inputs = [] + inputs = collections.OrderedDict() attrs = [] for item in args_list: @@ -212,9 +215,9 @@ def _get_args_values(op, phi_name): "inputs" in op_content.keys() and arg_name in op_content["inputs"].keys() ): - inputs.append(op_content["inputs"][arg_name]) + inputs[op_content["inputs"][arg_name]] = arg_type else: - inputs.append(arg_name) + inputs[arg_name] = arg_type else: attr_value = _get_attr_value(op, arg_type, arg_name) attrs.append(attr_value) @@ -237,9 +240,16 @@ def prepare_python_api_arguments(op): phi_name = op.type inputs, attrs = _get_args_values(op, phi_name) res = [] - for item in inputs: + for item, tensor_type in inputs.items(): if item in op.input_names: - res.append(get_var_block(op.block, op.input(item))) + if tensor_type == "Tensor[]": + res.append( + get_var_block( + op.block, op.input(item), is_tensor_list=True + ) + ) + else: + res.append(get_var_block(op.block, op.input(item))) else: # Note: in some cases, inputs may be optional, thus assign None. Such case must be recorded. res.append(None) diff --git a/test/legacy_test/test_stack_op.py b/test/legacy_test/test_stack_op.py index fea318351208339fb90526d380631ba615d9ac62..5c5e653dbaeb6abfc08ff599a60c12ba518b279c 100644 --- a/test/legacy_test/test_stack_op.py +++ b/test/legacy_test/test_stack_op.py @@ -375,5 +375,20 @@ class TestStackAPI_ZeroDim(unittest.TestCase): paddle.enable_static() +class TestStackListOfSingleTensor(unittest.TestCase): + def setUp(self): + paddle.disable_static() + paddle.seed(2022) + self.x = [paddle.randn((4, 2, 6), dtype="float32")] + + def test_list_single_tensor(self): + expect = paddle.stack(self.x) + paddle.fluid.core._set_prim_all_enabled(True) + st_model = paddle.jit.to_static(paddle.stack) + actual = st_model(self.x) + np.testing.assert_allclose(expect, actual) + paddle.enable_static() + + if __name__ == '__main__': unittest.main()