未验证 提交 5ac96468 编写于 作者: A Aurelius84 提交者: GitHub

[OpAttr]Fix Ignore AttriteTensor in IndicateDataType bug in grad_op (#49472)

* [OpAttr]Fix Ignore AttriteTensor in IndicateDataType bug in grad_op

* add GetExpectedKernelType
上级 c4604025
...@@ -2716,22 +2716,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -2716,22 +2716,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
static_cast<proto::VarType::Type>(-1); static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type; proto::VarType::Type data_type = dafault_data_type;
auto in_name_list = ctx.InNameList(); for (auto* name : ctx.InNameList()) {
if (Info().HasOpProtoAndChecker()) {
for (auto& attr : Info().Proto().attrs()) {
auto it =
std::find_if(in_name_list.begin(),
in_name_list.end(),
[&attr](const std::string* name) {
return attr.support_tensor() && *name == attr.name();
});
if (it != in_name_list.end()) {
in_name_list.erase(it);
}
}
}
for (auto* name : in_name_list) {
if (ctx.InputSize(*name) == 1UL) { if (ctx.InputSize(*name) == 1UL) {
ParseInputDataType(ctx.InputVar(*name), *name, &data_type); ParseInputDataType(ctx.InputVar(*name), *name, &data_type);
} else { } else {
......
...@@ -30,6 +30,13 @@ class PadOp : public framework::OperatorWithKernel { ...@@ -30,6 +30,13 @@ class PadOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Pad"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Pad");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Pad"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Pad");
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
}; };
class PadOpMaker : public framework::OpProtoAndCheckerMaker { class PadOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -98,6 +105,14 @@ class PadOpGrad : public framework::OperatorWithKernel { ...@@ -98,6 +105,14 @@ class PadOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(x_grad_name, dout_dims); ctx->SetOutputDim(x_grad_name, dout_dims);
} }
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
}; };
template <typename T> template <typename T>
......
...@@ -187,9 +187,14 @@ class TestPaddingValueTensor3(unittest.TestCase): ...@@ -187,9 +187,14 @@ class TestPaddingValueTensor3(unittest.TestCase):
x = paddle.assign(np_x).astype('float32') x = paddle.assign(np_x).astype('float32')
pad_value = paddle.assign([0.0]).astype('float64') pad_value = paddle.assign([0.0]).astype('float64')
y = paddle.nn.functional.pad(x, [0, 1, 2, 3], value=pad_value) y = paddle.nn.functional.pad(x, [0, 1, 2, 3], value=pad_value)
loss = y.sum()
optimize_ops, params_grads = paddle.optimizer.SGD(0.01).minimize(
loss
)
exe = paddle.static.Executor(paddle.CPUPlace()) exe = paddle.static.Executor(paddle.CPUPlace())
[pd_out] = exe.run(main_prog, fetch_list=[y]) res = exe.run(main_prog, fetch_list=[y] + [g for p, g in params_grads])
pd_out = res[0]
np_out = np.pad(np_x, [(0, 1), (2, 3)], constant_values=0.0) np_out = np.pad(np_x, [(0, 1), (2, 3)], constant_values=0.0)
np.testing.assert_allclose(pd_out, np_out) np.testing.assert_allclose(pd_out, np_out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册