From 5ac96468c910333210ec29328096b397c545e19d Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Tue, 3 Jan 2023 14:33:17 +0800
Subject: [PATCH] [OpAttr]Fix Ignore AttriteTensor in IndicateDataType bug in
 grad_op (#49472)

* [OpAttr]Fix Ignore AttriteTensor in IndicateDataType bug in grad_op

* add GetExpectedKernelType
---
 paddle/fluid/framework/operator.cc                  | 17 +----------------
 paddle/fluid/operators/pad_op.cc                    | 15 +++++++++++++++
 .../paddle/fluid/tests/unittests/test_pad_op.py     |  7 ++++++-
 3 files changed, 22 insertions(+), 17 deletions(-)

diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index eb7ad8ed94..dcb822afb4 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -2716,22 +2716,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
       static_cast<proto::VarType::Type>(-1);
   proto::VarType::Type data_type = dafault_data_type;
 
-  auto in_name_list = ctx.InNameList();
-  if (Info().HasOpProtoAndChecker()) {
-    for (auto& attr : Info().Proto().attrs()) {
-      auto it =
-          std::find_if(in_name_list.begin(),
-                       in_name_list.end(),
-                       [&attr](const std::string* name) {
-                         return attr.support_tensor() && *name == attr.name();
-                       });
-      if (it != in_name_list.end()) {
-        in_name_list.erase(it);
-      }
-    }
-  }
-
-  for (auto* name : in_name_list) {
+  for (auto* name : ctx.InNameList()) {
     if (ctx.InputSize(*name) == 1UL) {
       ParseInputDataType(ctx.InputVar(*name), *name, &data_type);
     } else {
diff --git a/paddle/fluid/operators/pad_op.cc b/paddle/fluid/operators/pad_op.cc
index 4e6a10a912..2951091508 100644
--- a/paddle/fluid/operators/pad_op.cc
+++ b/paddle/fluid/operators/pad_op.cc
@@ -30,6 +30,13 @@ class PadOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Pad");
     OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Pad");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
 };
 
 class PadOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -98,6 +105,14 @@ class PadOpGrad : public framework::OperatorWithKernel {
       ctx->SetOutputDim(x_grad_name, dout_dims);
     }
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
+  }
 };
 
 template <typename T>
diff --git a/python/paddle/fluid/tests/unittests/test_pad_op.py b/python/paddle/fluid/tests/unittests/test_pad_op.py
index 735f62c646..ee42ce1625 100644
--- a/python/paddle/fluid/tests/unittests/test_pad_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pad_op.py
@@ -187,9 +187,14 @@ class TestPaddingValueTensor3(unittest.TestCase):
             x = paddle.assign(np_x).astype('float32')
             pad_value = paddle.assign([0.0]).astype('float64')
             y = paddle.nn.functional.pad(x, [0, 1, 2, 3], value=pad_value)
+            loss = y.sum()
+            optimize_ops, params_grads = paddle.optimizer.SGD(0.01).minimize(
+                loss
+            )
 
         exe = paddle.static.Executor(paddle.CPUPlace())
-        [pd_out] = exe.run(main_prog, fetch_list=[y])
+        res = exe.run(main_prog, fetch_list=[y] + [g for p, g in params_grads])
+        pd_out = res[0]
         np_out = np.pad(np_x, [(0, 1), (2, 3)], constant_values=0.0)
         np.testing.assert_allclose(pd_out, np_out)
 
-- 
GitLab
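
For readers who want to verify the fix locally, below is a minimal standalone
repro distilled from the updated TestPaddingValueTensor3 test above. It is a
sketch, not part of the patch: the (16, 16) input shape and the 0.01 learning
rate are illustrative assumptions, and it presumes a Paddle build that already
contains this commit. Without the new GetExpectedKernelType overrides, building
the backward graph used to fail inside IndicateDataType, because pad_grad's
inputs mix the float32 X with the float64 pad_value attribute tensor.

    # Minimal repro sketch (assumes a Paddle build containing this patch).
    import numpy as np
    import paddle

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        np_x = np.random.random((16, 16)).astype('float32')
        x = paddle.assign(np_x).astype('float32')
        # pad_value deliberately has a different dtype (float64) than x;
        # it reaches the pad/pad_grad ops as an attribute tensor input.
        pad_value = paddle.assign([0.0]).astype('float64')
        y = paddle.nn.functional.pad(x, [0, 1, 2, 3], value=pad_value)
        loss = y.sum()
        # minimize() builds the backward graph, which instantiates pad_grad
        # and exercises the override keyed on GradVarName("Out").
        _, params_grads = paddle.optimizer.SGD(0.01).minimize(loss)

    exe = paddle.static.Executor(paddle.CPUPlace())
    res = exe.run(main_prog, fetch_list=[y] + [g for _, g in params_grads])
    np.testing.assert_allclose(
        res[0], np.pad(np_x, [(0, 1), (2, 3)], constant_values=0.0)
    )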