diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op_xpu.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op_xpu.cc index 108bba3b4522a1fee4fba769e844d08d6040bcfd..c86ebbc20c3359819968cbc2623a9d3b09cb8c3f 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op_xpu.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op_xpu.cc @@ -90,6 +90,7 @@ class ReduceMeanGradXPUKernel : public framework::OpKernel { bool reduce_all = ctx.Attr("reduce_all"); auto reduce_dims = ctx.Attr>("dim"); + bool keep_dim = ctx.Attr("keep_dim"); std::vector xdims; for (int i = 0; i < input->dims().size(); i++) { @@ -112,7 +113,13 @@ class ReduceMeanGradXPUKernel : public framework::OpKernel { d = d + xdims.size(); } reduce_numel *= xdims[d]; - ydims.insert(ydims.begin() + d, 1); + } + + if (keep_dim != true) { + sort(reduce_dims.begin(), reduce_dims.end()); + for (auto& d : reduce_dims) { + ydims.insert(ydims.begin() + d, 1); + } } float val = 1.0f / static_cast(reduce_numel); diff --git a/python/paddle/fluid/tests/unittests/xpu/test_reduce_mean_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_reduce_mean_op_xpu.py index 90fe474e09cd1d2b077bf1a15e6e8cddd550dbab..73fa3fba3c7600a1a9b57291d404cdc582b3f06a 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_reduce_mean_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_reduce_mean_op_xpu.py @@ -103,6 +103,9 @@ class XPUTestReduce(XPUOpTestWrapper): # def test_check_grad(self): # self.check_output_with_place(self.place, ['X'], 'Out') + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + class Test2DReduce0(Test1DReduce): def setUp(self): @@ -161,6 +164,18 @@ class XPUTestReduce(XPUOpTestWrapper): 'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim'])) } + class Test6DReduce(Test1DReduce): + + def setUp(self): + super().setUp() + self.attrs = {'dim': [1, -1], 'use_xpu': True} + self.inputs = { + 'X': np.random.random((5, 6, 7, 8, 9, 10)).astype(self.dtype) + } + self.outputs = { + 'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim'])) + } + class TestKeepDimReduce(Test1DReduce): def setUp(self):