From c80fcf901ed432f91ff89acbe5b104205e65be74 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Tue, 25 Aug 2020 10:29:46 +0800 Subject: [PATCH] reduce_mean error if keepdim=True and reduce_all=True (#26614) --- .../reduce_ops/reduce_mean_op.part.cu | 4 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 4 +- .../fluid/tests/unittests/test_mean_op.py | 119 +++++++++++++++++- .../fluid/tests/unittests/test_reduce_op.py | 96 -------------- 4 files changed, 119 insertions(+), 104 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu index 12eceb33ec2..289f574719f 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu @@ -21,6 +21,4 @@ using CUDAReduceMeanGradKernel = ops::MeanGradFunctor, true>; REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel, - CUDAReduceMeanGradKernel, - CUDAReduceMeanGradKernel, - CUDAReduceMeanGradKernel); + CUDAReduceMeanGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 13814a16775..67a19cb83c3 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -236,8 +236,8 @@ class ReduceGradKernel : public framework::OpKernel { if (reduce_all) { auto x = EigenVector::Flatten(*input0); - auto x_reduce = EigenVector::From(*input1); - auto x_reduce_grad = EigenVector::From(*input2); + auto x_reduce = EigenVector::Flatten(*input1); + auto x_reduce_grad = EigenVector::Flatten(*input2); auto x_grad = EigenVector::Flatten(*output); auto& place = *context.template device_context().eigen_device(); diff --git a/python/paddle/fluid/tests/unittests/test_mean_op.py b/python/paddle/fluid/tests/unittests/test_mean_op.py index 1747dbffc18..29e79b096cf 100644 --- a/python/paddle/fluid/tests/unittests/test_mean_op.py +++ b/python/paddle/fluid/tests/unittests/test_mean_op.py @@ -22,6 +22,8 @@ import paddle.fluid.core as core import paddle.fluid as fluid from paddle.fluid import Program, program_guard +np.random.seed(10) + class TestMeanOp(OpTest): def setUp(self): @@ -74,10 +76,105 @@ class TestFP16MeanOp(TestMeanOp): place, ['X'], 'Out', max_relative_error=0.8) +def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False): + if isinstance(axis, list): + axis = tuple(axis) + if reduce_all: + axis = None + return np.mean(x, axis=axis, keepdims=keepdim) + + +class TestReduceMeanOp(OpTest): + def setUp(self): + self.op_type = 'reduce_mean' + self.dtype = 'float64' + self.shape = [2, 3, 4, 5] + self.axis = [0] + self.keepdim = False + self.reduce_all = False + self.set_attrs() + + np.random.seed(10) + x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype) + out_np = ref_reduce_mean(x_np, self.axis, self.keepdim, self.reduce_all) + self.inputs = {'X': x_np} + self.outputs = {'Out': out_np} + self.attrs = { + 'dim': self.axis, + 'keep_dim': self.keepdim, + 'reduce_all': self.reduce_all + } + + def set_attrs(self): + pass + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], ['Out']) + + +class TestReduceMeanOpDefaultAttrs(TestReduceMeanOp): + def setUp(self): + self.op_type = 'reduce_mean' + self.dtype = 'float64' + self.shape = [2, 3, 4, 5] + + x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype) + out_np = np.mean(x_np, axis=0) + self.inputs = {'X': x_np} + self.outputs = {'Out': out_np} + + +class TestReduceMeanOpFloat32(TestReduceMeanOp): + def set_attrs(self): + self.dtype = 'float32' + + +class TestReduceMeanOpShape1D(TestReduceMeanOp): + def set_attrs(self): + self.shape = [100] + + +class TestReduceMeanOpShape6D(TestReduceMeanOp): + def set_attrs(self): + self.shape = [2, 3, 4, 5, 6, 7] + + +class TestReduceMeanOpAxisAll(TestReduceMeanOp): + def set_attrs(self): + self.axis = [0, 1, 2, 3] + + +class TestReduceMeanOpAxisTuple(TestReduceMeanOp): + def set_attrs(self): + self.axis = (0, 1, 2) + + +class TestReduceMeanOpAxisNegative(TestReduceMeanOp): + def set_attrs(self): + self.axis = [-2, -1] + + +class TestReduceMeanOpKeepdimTrue1(TestReduceMeanOp): + def set_attrs(self): + self.keepdim = True + + +class TestReduceMeanOpKeepdimTrue2(TestReduceMeanOp): + def set_attrs(self): + self.axis = [0, 1, 2, 3] + self.keepdim = True + + +class TestReduceMeanOpReduceAllTrue(TestReduceMeanOp): + def set_attrs(self): + self.reduce_all = True + + class TestMeanAPI(unittest.TestCase): - """ - test paddle.tensor.stat.mean - """ + # test paddle.tensor.stat.mean def setUp(self): self.x_shape = [2, 3, 4, 5] @@ -128,6 +225,22 @@ class TestMeanAPI(unittest.TestCase): test_case(self.x, [0, 1, 2, 3]) paddle.enable_static() + def test_fluid_api(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data("x", shape=[10, 10], dtype="float32") + out = fluid.layers.reduce_mean(input=x, dim=1) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + x_np = np.random.rand(10, 10).astype(np.float32) + res = exe.run(feed={"x": x_np}, fetch_list=[out]) + self.assertEqual(np.allclose(res[0], np.mean(x_np, axis=1)), True) + + with fluid.dygraph.guard(): + x_np = np.random.rand(10, 10).astype(np.float32) + x = fluid.dygraph.to_variable(x_np) + out = fluid.layers.reduce_mean(input=x, dim=1) + self.assertEqual(np.allclose(out.numpy(), np.mean(x_np, axis=1)), True) + def test_errors(self): paddle.disable_static() x = np.random.uniform(-1, 1, [10, 12]).astype('float32') diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 6a6ce8d329c..cf35f9dbcda 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -67,22 +67,6 @@ class TestSumOp6D(OpTest): self.check_grad(['X'], 'Out') -class TestMeanOp(OpTest): - def setUp(self): - self.op_type = "reduce_mean" - self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float64")} - self.attrs = {'dim': [1]} - self.outputs = { - 'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim'])) - } - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - - @skip_check_grad_ci( reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") @@ -318,21 +302,6 @@ class TestReduceAll(Test1DReduce): self.outputs = {'Out': self.inputs['X'].sum()} -## reduction in multi dims -class TestReduceMeanOpMultiAxises(OpTest): - def setUp(self): - self.op_type = "reduce_mean" - self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float64")} - self.attrs = {'dim': [1, 2]} - self.outputs = {'Out': self.inputs['X'].mean(axis=(1, 2))} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - - @skip_check_grad_ci( reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") @@ -420,40 +389,6 @@ class TestReduceSumWithNumelOne(OpTest): self.check_grad(['X'], 'Out') -class TestReduceMeanWithDimOne(OpTest): - def setUp(self): - self.op_type = "reduce_mean" - self.inputs = {'X': np.random.random((100, 1, 1)).astype("float64")} - self.attrs = {'dim': [1], 'keep_dim': False} - self.outputs = { - 'Out': self.inputs['X'].mean( - axis=tuple(self.attrs['dim']), keepdims=False) - } - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - - -class TestReduceMeanWithNumelOne(OpTest): - def setUp(self): - self.op_type = "reduce_mean" - self.inputs = {'X': np.random.random((100, 1)).astype("float64")} - self.attrs = {'dim': [1], 'keep_dim': True} - self.outputs = { - 'Out': self.inputs['X'].mean( - axis=tuple(self.attrs['dim']), keepdims=True) - } - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - - class TestReduceAll(OpTest): def setUp(self): self.op_type = "reduce_sum" @@ -536,18 +471,6 @@ class TestReduceSumOpError(unittest.TestCase): self.assertRaises(TypeError, fluid.layers.reduce_sum, x2) -class TestReduceMeanOpError(unittest.TestCase): - def test_errors(self): - with program_guard(Program(), Program()): - # The input type of reduce_mean_op must be Variable. - x1 = fluid.create_lod_tensor( - np.array([[-1]]), [[1]], fluid.CPUPlace()) - self.assertRaises(TypeError, fluid.layers.reduce_mean, x1) - # The input dtype of reduce_mean_op must be float32 or float64 or int32 or int64. - x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8") - self.assertRaises(TypeError, fluid.layers.reduce_mean, x2) - - class API_TestSumOpError(unittest.TestCase): def test_errors(self): def test_dtype1(): @@ -649,24 +572,5 @@ class API_TestSumOp(unittest.TestCase): self.assertTrue((out3 == np.sum(np_x, axis=(0, 1, 2))).all()) -class API_TestReduceMeanOp(unittest.TestCase): - def test_static(self): - with fluid.program_guard(fluid.Program(), fluid.Program()): - x = fluid.data("x", shape=[10, 10], dtype="float32") - out = fluid.layers.reduce_mean(input=x, dim=1) - place = fluid.CPUPlace() - exe = fluid.Executor(place) - x_np = np.random.rand(10, 10).astype(np.float32) - res = exe.run(feed={"x": x_np}, fetch_list=[out]) - self.assertEqual(np.allclose(res[0], np.mean(x_np, axis=1)), True) - - def test_dygraph(self): - with fluid.dygraph.guard(): - x_np = np.random.rand(10, 10).astype(np.float32) - x = fluid.dygraph.to_variable(x_np) - out = fluid.layers.reduce_mean(input=x, dim=1) - self.assertEqual(np.allclose(out.numpy(), np.mean(x_np, axis=1)), True) - - if __name__ == '__main__': unittest.main() -- GitLab