From 8a859554867a527c7ff795cb794eb87aa656c3e4 Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Fri, 30 Dec 2022 10:43:35 +0800 Subject: [PATCH] Fix default GetExpectedKernelType for ops supported tensor attrs (#49414) * Fix default GetExpectedKernelType for ops supported tensor attrs --- paddle/fluid/framework/operator.cc | 18 +++++++++++++++++- .../fluid/tests/unittests/test_pad_op.py | 16 ++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index ae216b1e499..91f8f869c91 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2715,7 +2715,23 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( proto::VarType::Type dafault_data_type = static_cast(-1); proto::VarType::Type data_type = dafault_data_type; - for (auto* name : ctx.InNameList()) { + + auto in_name_list = ctx.InNameList(); + if (Info().HasOpProtoAndChecker()) { + for (auto& attr : Info().Proto().attrs()) { + auto it = + std::find_if(in_name_list.begin(), + in_name_list.end(), + [&attr](const std::string* name) { + return attr.support_tensor() && *name == attr.name(); + }); + if (it != in_name_list.end()) { + in_name_list.erase(it); + } + } + } + + for (auto* name : in_name_list) { if (ctx.InputSize(*name) == 1UL) { ParseInputDataType(ctx.InputVar(*name), *name, &data_type); } else { diff --git a/python/paddle/fluid/tests/unittests/test_pad_op.py b/python/paddle/fluid/tests/unittests/test_pad_op.py index 04617274356..735f62c646b 100644 --- a/python/paddle/fluid/tests/unittests/test_pad_op.py +++ b/python/paddle/fluid/tests/unittests/test_pad_op.py @@ -178,6 +178,22 @@ class TestPaddingValueTensor2(TestPaddingValueTensor): return out +class TestPaddingValueTensor3(unittest.TestCase): + def test_static(self): + np_x = np.random.random((16, 16)).astype('float32') + main_prog = Program() + starup_prog = Program() + with program_guard(main_prog, starup_prog): + x = paddle.assign(np_x).astype('float32') + pad_value = paddle.assign([0.0]).astype('float64') + y = paddle.nn.functional.pad(x, [0, 1, 2, 3], value=pad_value) + + exe = paddle.static.Executor(paddle.CPUPlace()) + [pd_out] = exe.run(main_prog, fetch_list=[y]) + np_out = np.pad(np_x, [(0, 1), (2, 3)], constant_values=0.0) + np.testing.assert_allclose(pd_out, np_out) + + if __name__ == '__main__': paddle.enable_static() unittest.main() -- GitLab