From 0eb1b0bc0110e8aa16d0d51198701e1885edc009 Mon Sep 17 00:00:00 2001 From: wawltor Date: Tue, 9 Jun 2020 17:48:52 +0800 Subject: [PATCH] Add support the 5d, 6d tensor support for the reduce ops Add the support the 5d,6d tensor support for the reduce ops; Add the same time, the compile time, it was 22 minutes, it was 21 minutes after fixed. --- paddle/fluid/operators/reduce_ops/reduce_op.h | 9 ++++++ .../fluid/tests/unittests/test_reduce_op.py | 30 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 383eea9d073..4673dc258d0 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -68,6 +68,15 @@ struct ReduceKernelFunctor { } else { int ndim = input->dims().size(); int rdim = dims.size(); + HANDLE_DIM(6, 5); + HANDLE_DIM(6, 4); + HANDLE_DIM(6, 3); + HANDLE_DIM(6, 2); + HANDLE_DIM(6, 1); + HANDLE_DIM(5, 4); + HANDLE_DIM(5, 3); + HANDLE_DIM(5, 2); + HANDLE_DIM(5, 1); HANDLE_DIM(4, 3); HANDLE_DIM(4, 2); HANDLE_DIM(4, 1); diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 99c519b6293..16874d80112 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -37,6 +37,36 @@ class TestSumOp(OpTest): self.check_grad(['X'], 'Out') +class TestSumOp5D(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = { + 'X': np.random.random((1, 2, 5, 6, 10)).astype("float64") + } + self.outputs = {'Out': self.inputs['X'].sum(axis=0)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestSumOp6D(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = { + 'X': np.random.random((1, 1, 2, 5, 6, 10)).astype("float64") + } + self.outputs = {'Out': self.inputs['X'].sum(axis=0)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + class TestMeanOp(OpTest): def setUp(self): self.op_type = "reduce_mean" -- GitLab