From 812e4b4756c95fdc4af0973626485b8d897ed0e7 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 23 Sep 2022 10:13:49 +0800 Subject: [PATCH] [BugFix]Fix reduce_mean/min/sum/prod, cumsum grad_op infershape bug (#46408) * [BugFix]Fix reduce_mean/min/sum/prod, cumsum grad_op infershape bug * fix typo * fix typo --- paddle/fluid/operators/cum_op.cc | 6 +--- paddle/fluid/operators/reduce_ops/reduce_op.h | 32 +++++++++++-------- .../fluid/tests/unittests/test_cumsum_op.py | 3 ++ .../fluid/tests/unittests/test_sum_op.py | 3 ++ 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/operators/cum_op.cc b/paddle/fluid/operators/cum_op.cc index 09d3f1dbe7..29bc83bd9a 100644 --- a/paddle/fluid/operators/cum_op.cc +++ b/paddle/fluid/operators/cum_op.cc @@ -72,13 +72,9 @@ class CumsumGradMaker : public framework::SingleGradOpMaker { grad_op->SetType("cumsum"); grad_op->SetInput("X", this->OutputGrad("Out")); grad_op->SetOutput("Out", this->InputGrad("X")); - grad_op->SetAttr("axis", PADDLE_GET_CONST(int, this->GetAttr("axis"))); - grad_op->SetAttr("flatten", - PADDLE_GET_CONST(bool, this->GetAttr("flatten"))); + grad_op->SetAttrMap(this->Attrs()); grad_op->SetAttr("reverse", !PADDLE_GET_CONST(bool, this->GetAttr("reverse"))); - grad_op->SetAttr("exclusive", - PADDLE_GET_CONST(bool, this->GetAttr("exclusive"))); } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index d305a65e0d..22a251706a 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -634,20 +634,26 @@ class ReduceGradOp : public framework::OperatorWithKernel { "ReduceOp"); auto x_dims = ctx->GetInputDim("X"); auto x_rank = x_dims.size(); - auto dims = ctx->Attrs().Get>("dim"); - for (size_t i = 0; i < dims.size(); ++i) { - PADDLE_ENFORCE_LT(dims[i], - x_rank, - platform::errors::InvalidArgument( - "The reduce dim index %d should be in the " - "range [-dimension(X), dimension(X)], " - "which dimesion = %d. But received dim index = %d.", - i, - x_rank, - dims[i])); - if (dims[i] < 0) dims[i] = x_rank + dims[i]; + // TODO(dev): We should delete Infershape and migrate it into + // UnchangeInferMeta.In case of 'dim' is Variable, it will + // not exist in Attrs but in Inputs. + if (ctx->HasAttr("dim")) { + auto dims = ctx->Attrs().Get>("dim"); + for (size_t i = 0; i < dims.size(); ++i) { + PADDLE_ENFORCE_LT( + dims[i], + x_rank, + platform::errors::InvalidArgument( + "The reduce dim index %d should be in the " + "range [-dimension(X), dimension(X)], " + "which dimesion = %d. But received dim index = %d.", + i, + x_rank, + dims[i])); + if (dims[i] < 0) dims[i] = x_rank + dims[i]; + } } - sort(dims.begin(), dims.end()); + auto x_grad_name = framework::GradVarName("X"); if (ctx->HasOutput(x_grad_name)) { ctx->SetOutputDim(x_grad_name, x_dims); diff --git a/python/paddle/fluid/tests/unittests/test_cumsum_op.py b/python/paddle/fluid/tests/unittests/test_cumsum_op.py index 42def391ac..cfef7ddcf4 100644 --- a/python/paddle/fluid/tests/unittests/test_cumsum_op.py +++ b/python/paddle/fluid/tests/unittests/test_cumsum_op.py @@ -356,6 +356,9 @@ class TestTensorAxis(unittest.TestCase): relu_out = paddle.nn.functional.relu(linear_out) axis = paddle.full([1], 2, dtype='int64') out = paddle.cumsum(relu_out, axis=axis) + loss = paddle.mean(out) + sgd = paddle.optimizer.SGD(learning_rate=0.) + sgd.minimize(paddle.mean(out)) exe = paddle.static.Executor(self.place) exe.run(starup_prog) diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py index c4d7bb7c2b..dc6d867321 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_op.py +++ b/python/paddle/fluid/tests/unittests/test_sum_op.py @@ -543,6 +543,9 @@ class TestReduceOPTensorAxisBase(unittest.TestCase): linear = paddle.nn.Linear(x.shape[-1], 5) linear_out = linear(x) out = self.pd_api(linear_out, axis, keepdim=self.keepdim) + + sgd = paddle.optimizer.SGD(learning_rate=0.) + sgd.minimize(paddle.mean(out)) exe = paddle.static.Executor(self.place) exe.run(starup_prog) static_out = exe.run(feed={'x': self.x.numpy().astype('float32')}, -- GitLab