未验证 提交 484377cd 编写于 作者: A Aurelius84 提交者: GitHub

[Cherry-Pick][BugFix]Fix reduce_mean/min/sum/prod, cumsum grad_op infershape bug (#46409)

* [BugFix]Fix reduce_mean/min/sum/prod, cumsum grad_op infershape bug

* fix typo

* fix typo
上级 7eb046c7
......@@ -72,13 +72,9 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetType("cumsum");
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttr("axis", PADDLE_GET_CONST(int, this->GetAttr("axis")));
grad_op->SetAttr("flatten",
PADDLE_GET_CONST(bool, this->GetAttr("flatten")));
grad_op->SetAttrMap(this->Attrs());
grad_op->SetAttr("reverse",
!PADDLE_GET_CONST(bool, this->GetAttr("reverse")));
grad_op->SetAttr("exclusive",
PADDLE_GET_CONST(bool, this->GetAttr("exclusive")));
}
};
......
......@@ -634,9 +634,14 @@ class ReduceGradOp : public framework::OperatorWithKernel {
"ReduceOp");
auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size();
// TODO(dev): We should delete InferShape and migrate it into
// UnchangeInferMeta. When 'dim' is a Variable, it will
// not exist in Attrs but in Inputs.
if (ctx->HasAttr("dim")) {
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
for (size_t i = 0; i < dims.size(); ++i) {
PADDLE_ENFORCE_LT(dims[i],
PADDLE_ENFORCE_LT(
dims[i],
x_rank,
platform::errors::InvalidArgument(
"The reduce dim index %d should be in the "
......@@ -647,7 +652,8 @@ class ReduceGradOp : public framework::OperatorWithKernel {
dims[i]));
if (dims[i] < 0) dims[i] = x_rank + dims[i];
}
sort(dims.begin(), dims.end());
}
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
......
......@@ -356,6 +356,9 @@ class TestTensorAxis(unittest.TestCase):
relu_out = paddle.nn.functional.relu(linear_out)
axis = paddle.full([1], 2, dtype='int64')
out = paddle.cumsum(relu_out, axis=axis)
loss = paddle.mean(out)
sgd = paddle.optimizer.SGD(learning_rate=0.)
sgd.minimize(paddle.mean(out))
exe = paddle.static.Executor(self.place)
exe.run(starup_prog)
......
......@@ -543,6 +543,9 @@ class TestReduceOPTensorAxisBase(unittest.TestCase):
linear = paddle.nn.Linear(x.shape[-1], 5)
linear_out = linear(x)
out = self.pd_api(linear_out, axis, keepdim=self.keepdim)
sgd = paddle.optimizer.SGD(learning_rate=0.)
sgd.minimize(paddle.mean(out))
exe = paddle.static.Executor(self.place)
exe.run(starup_prog)
static_out = exe.run(feed={'x': self.x.numpy().astype('float32')},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册