未验证 提交 484377cd 编写于 作者: A Aurelius84 提交者: GitHub

[Cherry-Pick][BugFix]Fix reduce_mean/min/sum/prod, cumsum grad_op infershape bug (#46409)

* [BugFix]Fix reduce_mean/min/sum/prod, cumsum grad_op infershape bug

* fix typo

* fix typo
上级 7eb046c7
...@@ -72,13 +72,9 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -72,13 +72,9 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetType("cumsum"); grad_op->SetType("cumsum");
grad_op->SetInput("X", this->OutputGrad("Out")); grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetOutput("Out", this->InputGrad("X")); grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttr("axis", PADDLE_GET_CONST(int, this->GetAttr("axis"))); grad_op->SetAttrMap(this->Attrs());
grad_op->SetAttr("flatten",
PADDLE_GET_CONST(bool, this->GetAttr("flatten")));
grad_op->SetAttr("reverse", grad_op->SetAttr("reverse",
!PADDLE_GET_CONST(bool, this->GetAttr("reverse"))); !PADDLE_GET_CONST(bool, this->GetAttr("reverse")));
grad_op->SetAttr("exclusive",
PADDLE_GET_CONST(bool, this->GetAttr("exclusive")));
} }
}; };
......
...@@ -634,9 +634,14 @@ class ReduceGradOp : public framework::OperatorWithKernel { ...@@ -634,9 +634,14 @@ class ReduceGradOp : public framework::OperatorWithKernel {
"ReduceOp"); "ReduceOp");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size(); auto x_rank = x_dims.size();
// TODO(dev): We should delete Infershape and migrate it into
// UnchangeInferMeta.In case of 'dim' is Variable, it will
// not exist in Attrs but in Inputs.
if (ctx->HasAttr("dim")) {
auto dims = ctx->Attrs().Get<std::vector<int>>("dim"); auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
for (size_t i = 0; i < dims.size(); ++i) { for (size_t i = 0; i < dims.size(); ++i) {
PADDLE_ENFORCE_LT(dims[i], PADDLE_ENFORCE_LT(
dims[i],
x_rank, x_rank,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The reduce dim index %d should be in the " "The reduce dim index %d should be in the "
...@@ -647,7 +652,8 @@ class ReduceGradOp : public framework::OperatorWithKernel { ...@@ -647,7 +652,8 @@ class ReduceGradOp : public framework::OperatorWithKernel {
dims[i])); dims[i]));
if (dims[i] < 0) dims[i] = x_rank + dims[i]; if (dims[i] < 0) dims[i] = x_rank + dims[i];
} }
sort(dims.begin(), dims.end()); }
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) { if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims); ctx->SetOutputDim(x_grad_name, x_dims);
......
...@@ -356,6 +356,9 @@ class TestTensorAxis(unittest.TestCase): ...@@ -356,6 +356,9 @@ class TestTensorAxis(unittest.TestCase):
relu_out = paddle.nn.functional.relu(linear_out) relu_out = paddle.nn.functional.relu(linear_out)
axis = paddle.full([1], 2, dtype='int64') axis = paddle.full([1], 2, dtype='int64')
out = paddle.cumsum(relu_out, axis=axis) out = paddle.cumsum(relu_out, axis=axis)
loss = paddle.mean(out)
sgd = paddle.optimizer.SGD(learning_rate=0.)
sgd.minimize(paddle.mean(out))
exe = paddle.static.Executor(self.place) exe = paddle.static.Executor(self.place)
exe.run(starup_prog) exe.run(starup_prog)
......
...@@ -543,6 +543,9 @@ class TestReduceOPTensorAxisBase(unittest.TestCase): ...@@ -543,6 +543,9 @@ class TestReduceOPTensorAxisBase(unittest.TestCase):
linear = paddle.nn.Linear(x.shape[-1], 5) linear = paddle.nn.Linear(x.shape[-1], 5)
linear_out = linear(x) linear_out = linear(x)
out = self.pd_api(linear_out, axis, keepdim=self.keepdim) out = self.pd_api(linear_out, axis, keepdim=self.keepdim)
sgd = paddle.optimizer.SGD(learning_rate=0.)
sgd.minimize(paddle.mean(out))
exe = paddle.static.Executor(self.place) exe = paddle.static.Executor(self.place)
exe.run(starup_prog) exe.run(starup_prog)
static_out = exe.run(feed={'x': self.x.numpy().astype('float32')}, static_out = exe.run(feed={'x': self.x.numpy().astype('float32')},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册