From 8655904be21df4731c2c8f2fd380027d670cefda Mon Sep 17 00:00:00 2001 From: whs Date: Wed, 23 May 2018 11:20:30 +0800 Subject: [PATCH] Enhance reduce op (#10708) * Enhance reduce op for multi dims. * Uncomment some unitest. * Uncomment unitest. * Remove unused code. * Fix infershape and python wrapper. * Add more examples. * Fix l2_normalize. * Fix normalization_wrapper. * Polish code. 1. Rename unitest function. 2. Rename const variable. --- paddle/fluid/operators/reduce_op.cc | 51 ++++++---- paddle/fluid/operators/reduce_op.h | 99 +++++++++++-------- python/paddle/fluid/layers/nn.py | 83 ++++++++++++---- .../fluid/tests/unittests/test_reduce_op.py | 85 ++++++++++++++-- 4 files changed, 233 insertions(+), 85 deletions(-) diff --git a/paddle/fluid/operators/reduce_op.cc b/paddle/fluid/operators/reduce_op.cc index eb8c21179db..e293fd5e410 100644 --- a/paddle/fluid/operators/reduce_op.cc +++ b/paddle/fluid/operators/reduce_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/reduce_op.h" +#include #include #include @@ -34,11 +35,14 @@ class ReduceOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); auto x_rank = x_dims.size(); PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported."); - int dim = ctx->Attrs().Get("dim"); - if (dim < 0) dim = x_rank + dim; - PADDLE_ENFORCE_LT( - dim, x_rank, - "The dim should be in the range [-rank(input), rank(input))."); + auto dims = ctx->Attrs().Get>("dim"); + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] < 0) dims[i] = x_rank + dims[i]; + PADDLE_ENFORCE_LT( + dims[i], x_rank, + "The dim should be in the range [-rank(input), rank(input))."); + } + sort(dims.begin(), dims.end()); bool reduce_all = ctx->Attrs().Get("reduce_all"); bool keep_dim = ctx->Attrs().Get("keep_dim"); if (reduce_all) { @@ -49,14 +53,22 @@ class ReduceOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", {1}); } else { auto dims_vector = vectorize(x_dims); - if (keep_dim || x_rank == 1) { - dims_vector[dim] = 1; + if (keep_dim) { + for (size_t i = 0; i < dims.size(); ++i) { + dims_vector[dims[i]] = 1; + } } else { - dims_vector.erase(dims_vector.begin() + dim); + const int kDelFlag = -2; + for (size_t i = 0; i < dims.size(); ++i) { + dims_vector[dims[i]] = kDelFlag; + } + dims_vector.erase( + remove(dims_vector.begin(), dims_vector.end(), kDelFlag), + dims_vector.end()); } auto out_dims = framework::make_ddim(dims_vector); ctx->SetOutputDim("Out", out_dims); - if (dim != 0) { + if (dims[0] != 0) { // Only pass LoD when not reducing on the first dim. ctx->ShareLoD("X", /*->*/ "Out"); } @@ -75,11 +87,14 @@ class ReduceGradOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); auto x_rank = x_dims.size(); PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported."); - int dim = ctx->Attrs().Get("dim"); - if (dim < 0) dim = x_rank + dim; - PADDLE_ENFORCE_LT( - dim, x_rank, - "The dim should be in the range [-rank(input), rank(input))."); + auto dims = ctx->Attrs().Get>("dim"); + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] < 0) dims[i] = x_rank + dims[i]; + PADDLE_ENFORCE_LT( + dims[i], x_rank, + "The dim should be in the range [-rank(input), rank(input))."); + } + sort(dims.begin(), dims.end()); auto x_grad_name = framework::GradVarName("X"); if (ctx->HasOutput(x_grad_name)) { ctx->SetOutputDim(x_grad_name, x_dims); @@ -95,13 +110,13 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor) The input tensor. Tensors with rank at most 6 are " "supported."); AddOutput("Out", "(Tensor) The result tensor."); - AddAttr( + AddAttr>( "dim", - "(int, default 0) The dimension to reduce. " + "(list, default {0}) The dimensions to reduce. " "Must be in the range [-rank(input), rank(input)). " - "If `dim < 0`, the dim to reduce is `rank + dim`. " + "If `dim[i] < 0`, the dims[i] to reduce is `rank + dims[i]`. " "Note that reducing on the first dim will make the LoD info lost.") - .SetDefault(0); + .SetDefault({0}); AddAttr("keep_dim", "(bool, default false) " "If true, retain the reduced dimension with length 1.") diff --git a/paddle/fluid/operators/reduce_op.h b/paddle/fluid/operators/reduce_op.h index e42b4bfe42d..cd19cc1460a 100644 --- a/paddle/fluid/operators/reduce_op.h +++ b/paddle/fluid/operators/reduce_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "glog/logging.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -109,6 +110,11 @@ struct ProdGradFunctor { } }; +#define HANDLE_DIM(NDIM, RDIM) \ + if (ndim == NDIM && rdim == RDIM) { \ + ReduceCompute(context); \ + } + template class ReduceKernel : public framework::OpKernel { public: @@ -127,32 +133,29 @@ class ReduceKernel : public framework::OpKernel { Functor functor; functor(place, &x, &out, reduce_dim); } else { - int rank = context.Input("X")->dims().size(); - switch (rank) { - case 1: - ReduceCompute<1>(context); - break; - case 2: - ReduceCompute<2>(context); - break; - case 3: - ReduceCompute<3>(context); - break; - case 4: - ReduceCompute<4>(context); - break; - case 5: - ReduceCompute<5>(context); - break; - case 6: - ReduceCompute<6>(context); - break; - } + int ndim = context.Input("X")->dims().size(); + int rdim = context.Attr>("dim").size(); + HANDLE_DIM(6, 5); + HANDLE_DIM(6, 4); + HANDLE_DIM(6, 3); + HANDLE_DIM(6, 2); + HANDLE_DIM(6, 1); + HANDLE_DIM(5, 4); + HANDLE_DIM(5, 3); + HANDLE_DIM(5, 2); + HANDLE_DIM(5, 1); + HANDLE_DIM(4, 3); + HANDLE_DIM(4, 2); + HANDLE_DIM(4, 1); + HANDLE_DIM(3, 2); + HANDLE_DIM(3, 1); + HANDLE_DIM(2, 1); + HANDLE_DIM(1, 1); } } private: - template + template void ReduceCompute(const framework::ExecutionContext& context) const { auto* input = context.Input("X"); auto* output = context.Output("Out"); @@ -160,18 +163,26 @@ class ReduceKernel : public framework::OpKernel { auto x = EigenTensor::From(*input); auto x_rank = static_cast(x.dimensions().size()); - int dim = static_cast(context.Attr("dim")); - if (dim < 0) dim = x_rank + dim; - auto reduce_dim = Eigen::array({{dim}}); + auto dims = context.Attr>("dim"); + auto reduce_dim = Eigen::array(); + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] < 0) dims[i] = x_rank + dims[i]; + reduce_dim[i] = dims[i]; + } // construct the squeezed output tensor bool keep_dim = context.Attr("keep_dim"); - DDim dims = output->dims(); - auto dims_vector = vectorize(dims); + DDim out_dims = output->dims(); if (keep_dim && x_rank > 1) { - dims_vector.erase(dims_vector.begin() + dim); - dims = framework::make_ddim(dims_vector); + const int kDelFlag = -2; + auto dims_vector = vectorize(out_dims); + for (size_t i = 0; i < dims.size(); ++i) { + dims_vector[dims[i]] = kDelFlag; + } + dims_vector.erase( + remove(dims_vector.begin(), dims_vector.end(), kDelFlag), + dims_vector.end()); + out_dims = framework::make_ddim(dims_vector); } - auto& place = *context.template device_context().eigen_device(); Functor functor; @@ -180,7 +191,7 @@ class ReduceKernel : public framework::OpKernel { auto out = EigenScalar::From(*output); functor(place, &x, &out, reduce_dim); } else { - auto out = EigenTensor::From(*output, dims); + auto out = EigenTensor::From(*output, out_dims); functor(place, &x, &out, reduce_dim); } } @@ -245,21 +256,29 @@ class ReduceGradKernel : public framework::OpKernel { auto x = EigenTensor::From(*input0); auto x_grad = EigenTensor::From(*output); auto x_rank = static_cast(x.dimensions().size()); - int dim = static_cast(context.Attr("dim")); - if (dim < 0) dim = x_rank + dim; - DDim dims = input0->dims(); - dims[dim] = 1; - auto x_reduce = EigenTensor::From(*input1, dims); - auto x_reduce_grad = EigenTensor::From(*input2, dims); - + auto dims = context.Attr>("dim"); + auto x_dims = input0->dims(); + auto reduced_dims_v = vectorize(x_dims); Eigen::array broadcast_dim; for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1; - broadcast_dim[dim] = input0->dims()[dim]; + + int broad_cats_times = 1; + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] < 0) dims[i] = x_rank + dims[i]; + reduced_dims_v[dims[i]] = 1; + broadcast_dim[dims[i]] = x_dims[dims[i]]; + broad_cats_times *= x_dims[dims[i]]; + } + auto reduced_dims = framework::make_ddim(reduced_dims_v); + auto x_reduce = EigenTensor::From(*input1, reduced_dims); + auto x_reduce_grad = EigenTensor::From(*input2, reduced_dims); + auto& place = *context.template device_context().eigen_device(); + Functor functor; functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim, - broadcast_dim[dim]); + broad_cats_times); } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 75f7ec2f853..1f2e483a096 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2082,11 +2082,11 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None): Args: input (Variable): The input variable which is a Tensor or LoDTensor. - dim (int|None): The dimension along which the sum is performed. If + dim (list|int|None): The dimensions along which the sum is performed. If :attr:`None`, sum all elements of :attr:`input` and return a Tensor variable with a single element, otherwise must be in the - range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`, - the dimension to reduce is :math:`rank + dim`. + range :math:`[-rank(input), rank(input))`. If :math:`dim[i] < 0`, + the dimension to reduce is :math:`rank + dim[i]`. keep_dim (bool|False): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true. @@ -2107,15 +2107,25 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None): fluid.layers.reduce_sum(x, dim=0) # [0.3, 0.5, 1.1, 1.6] fluid.layers.reduce_sum(x, dim=-1) # [1.9, 1.6] fluid.layers.reduce_sum(x, dim=1, keep_dim=True) # [[1.9], [1.6]] + + # x is a Tensor variable with shape [2, 2, 2] and elements as below: + # [[[1, 2], [3, 4]], + # [[5, 6], [7, 8]]] + # Each example is followed by the correspending output tensor. + fluid.layers.reduce_sum(x, dim=[1, 2]) # [10, 26] + fluid.layers.reduce_sum(x, dim=[0, 1]) # [16, 20] + """ helper = LayerHelper('reduce_sum', **locals()) out = helper.create_tmp_variable(dtype=helper.input_dtype()) + if dim is not None and not isinstance(dim, list): + dim = [dim] helper.append_op( type='reduce_sum', inputs={'X': input}, outputs={'Out': out}, attrs={ - 'dim': dim if dim != None else 0, + 'dim': dim if dim != None else [0], 'keep_dim': keep_dim, 'reduce_all': True if dim == None else False }) @@ -2128,11 +2138,11 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None): Args: input (Variable): The input variable which is a Tensor or LoDTensor. - dim (int|None): The dimension along which the mean is computed. If + dim (list|int|None): The dimensions along which the mean is computed. If :attr:`None`, compute the mean over all elements of :attr:`input` and return a Tensor variable with a single element, otherwise must be in the range :math:`[-rank(input), rank(input))`. If - :math:`dim < 0`, the dimension to reduce is :math:`rank + dim`. + :math:`dim[i] < 0`, the dimension to reduce is :math:`rank + dim[i]`. keep_dim (bool): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true. @@ -2153,15 +2163,24 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None): fluid.layers.reduce_mean(x, dim=0) # [0.15, 0.25, 0.55, 0.8] fluid.layers.reduce_mean(x, dim=-1) # [0.475, 0.4] fluid.layers.reduce_mean(x, dim=1, keep_dim=True) # [[0.475], [0.4]] + + # x is a Tensor variable with shape [2, 2, 2] and elements as below: + # [[[1.0, 2.0], [3.0, 4.0]], + # [[5.0, 6.0], [7.0, 8.0]]] + # Each example is followed by the correspending output tensor. + fluid.layers.reduce_mean(x, dim=[1, 2]) # [2.5, 6.5] + fluid.layers.reduce_mean(x, dim=[0, 1]) # [4.0, 5.0] """ helper = LayerHelper('reduce_mean', **locals()) out = helper.create_tmp_variable(dtype=helper.input_dtype()) + if dim is not None and not isinstance(dim, list): + dim = [dim] helper.append_op( type='reduce_mean', inputs={'X': input}, outputs={'Out': out}, attrs={ - 'dim': dim if dim != None else 0, + 'dim': dim if dim != None else [0], 'keep_dim': keep_dim, 'reduce_all': True if dim == None else False }) @@ -2174,11 +2193,11 @@ def reduce_max(input, dim=None, keep_dim=False, name=None): Args: input (Variable): The input variable which is a Tensor or LoDTensor. - dim (int|None): The dimension along which the maximum is computed. + dim (list|int|None): The dimension along which the maximum is computed. If :attr:`None`, compute the maximum over all elements of :attr:`input` and return a Tensor variable with a single element, otherwise must be in the range :math:`[-rank(input), rank(input))`. - If :math:`dim < 0`, the dimension to reduce is :math:`rank + dim`. + If :math:`dim[i] < 0`, the dimension to reduce is :math:`rank + dim[i]`. keep_dim (bool): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true. @@ -2199,15 +2218,24 @@ def reduce_max(input, dim=None, keep_dim=False, name=None): fluid.layers.reduce_max(x, dim=0) # [0.2, 0.3, 0.6, 0.9] fluid.layers.reduce_max(x, dim=-1) # [0.9, 0.7] fluid.layers.reduce_max(x, dim=1, keep_dim=True) # [[0.9], [0.7]] + + # x is a Tensor variable with shape [2, 2, 2] and elements as below: + # [[[1.0, 2.0], [3.0, 4.0]], + # [[5.0, 6.0], [7.0, 8.0]]] + # Each example is followed by the correspending output tensor. + fluid.layers.reduce_max(x, dim=[1, 2]) # [4.0, 8.0] + fluid.layers.reduce_max(x, dim=[0, 1]) # [7.0, 8.0] """ helper = LayerHelper('reduce_max', **locals()) out = helper.create_tmp_variable(dtype=helper.input_dtype()) + if dim is not None and not isinstance(dim, list): + dim = [dim] helper.append_op( type='reduce_max', inputs={'X': input}, outputs={'Out': out}, attrs={ - 'dim': dim if dim != None else 0, + 'dim': dim if dim != None else [0], 'keep_dim': keep_dim, 'reduce_all': True if dim == None else False }) @@ -2220,11 +2248,11 @@ def reduce_min(input, dim=None, keep_dim=False, name=None): Args: input (Variable): The input variable which is a Tensor or LoDTensor. - dim (int|None): The dimension along which the minimum is computed. + dim (list|int|None): The dimensions along which the minimum is computed. If :attr:`None`, compute the minimum over all elements of :attr:`input` and return a Tensor variable with a single element, otherwise must be in the range :math:`[-rank(input), rank(input))`. - If :math:`dim < 0`, the dimension to reduce is :math:`rank + dim`. + If :math:`dim[i] < 0`, the dimension to reduce is :math:`rank + dim[i]`. keep_dim (bool): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true. @@ -2245,15 +2273,24 @@ def reduce_min(input, dim=None, keep_dim=False, name=None): fluid.layers.reduce_min(x, dim=0) # [0.1, 0.2, 0.5, 0.7] fluid.layers.reduce_min(x, dim=-1) # [0.2, 0.1] fluid.layers.reduce_min(x, dim=1, keep_dim=True) # [[0.2], [0.1]] + + # x is a Tensor variable with shape [2, 2, 2] and elements as below: + # [[[1.0, 2.0], [3.0, 4.0]], + # [[5.0, 6.0], [7.0, 8.0]]] + # Each example is followed by the correspending output tensor. + fluid.layers.reduce_min(x, dim=[1, 2]) # [1.0, 5.0] + fluid.layers.reduce_min(x, dim=[0, 1]) # [1.0, 2.0] """ helper = LayerHelper('reduce_min', **locals()) out = helper.create_tmp_variable(dtype=helper.input_dtype()) + if dim is not None and not isinstance(dim, list): + dim = [dim] helper.append_op( type='reduce_min', inputs={'X': input}, outputs={'Out': out}, attrs={ - 'dim': dim if dim != None else 0, + 'dim': dim if dim != None else [0], 'keep_dim': keep_dim, 'reduce_all': True if dim == None else False }) @@ -2266,11 +2303,11 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None): Args: input (Variable): The input variable which is a Tensor or LoDTensor. - dim (int|None): The dimension along which the product is performed. If + dim (list|int|None): The dimensions along which the product is performed. If :attr:`None`, multipy all elements of :attr:`input` and return a Tensor variable with a single element, otherwise must be in the - range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`, - the dimension to reduce is :math:`rank + dim`. + range :math:`[-rank(input), rank(input))`. If :math:`dim[i] < 0`, + the dimension to reduce is :math:`rank + dim[i]`. keep_dim (bool|False): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true. @@ -2292,15 +2329,24 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None): fluid.layers.reduce_prod(x, dim=-1) # [0.027, 0.0084] fluid.layers.reduce_prod(x, dim=1, keep_dim=True) # [[0.027], [0.0084]] + + # x is a Tensor variable with shape [2, 2, 2] and elements as below: + # [[[1.0, 2.0], [3.0, 4.0]], + # [[5.0, 6.0], [7.0, 8.0]]] + # Each example is followed by the correspending output tensor. + fluid.layers.reduce_prod(x, dim=[1, 2]) # [24.0, 1680.0] + fluid.layers.reduce_prod(x, dim=[0, 1]) # [105.0, 384.0] """ helper = LayerHelper('reduce_prod', **locals()) out = helper.create_tmp_variable(dtype=helper.input_dtype()) + if dim is not None and not isinstance(dim, list): + dim = [dim] helper.append_op( type='reduce_prod', inputs={'X': input}, outputs={'Out': out}, attrs={ - 'dim': dim if dim != None else 0, + 'dim': dim if dim != None else [0], 'keep_dim': keep_dim, 'reduce_all': True if dim == None else False }) @@ -2403,7 +2449,6 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None): if len(x.shape) == 1: axis = 0 - helper = LayerHelper("l2_normalize", **locals()) square = helper.create_tmp_variable(dtype=x.dtype) @@ -2415,7 +2460,7 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None): inputs={"X": square}, outputs={"Out": reduced_sum}, attrs={ - "dim": 1 if axis is None else axis, + "dim": [1] if axis is None else [axis], "keep_dim": True, "reduce_all": False }) diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 9b0cc3534dc..865c2b7df08 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -34,8 +34,10 @@ class TestMeanOp(OpTest): def setUp(self): self.op_type = "reduce_mean" self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float64")} - self.attrs = {'dim': 1} - self.outputs = {'Out': self.inputs['X'].mean(axis=self.attrs['dim'])} + self.attrs = {'dim': [1]} + self.outputs = { + 'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim'])) + } def test_check_output(self): self.check_output() @@ -50,8 +52,10 @@ class TestMaxOp(OpTest): def setUp(self): self.op_type = "reduce_max" self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} - self.attrs = {'dim': -1} - self.outputs = {'Out': self.inputs['X'].max(axis=self.attrs['dim'])} + self.attrs = {'dim': [-1]} + self.outputs = { + 'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim'])) + } def test_check_output(self): self.check_output() @@ -63,8 +67,10 @@ class TestMinOp(OpTest): def setUp(self): self.op_type = "reduce_min" self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} - self.attrs = {'dim': 2} - self.outputs = {'Out': self.inputs['X'].min(axis=self.attrs['dim'])} + self.attrs = {'dim': [2]} + self.outputs = { + 'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim'])) + } def test_check_output(self): self.check_output() @@ -87,9 +93,10 @@ class TestKeepDimReduce(OpTest): def setUp(self): self.op_type = "reduce_sum" self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} - self.attrs = {'dim': -2, 'keep_dim': True} + self.attrs = {'dim': [-2], 'keep_dim': True} self.outputs = { - 'Out': self.inputs['X'].sum(axis=self.attrs['dim'], keepdims=True) + 'Out': + self.inputs['X'].sum(axis=tuple(self.attrs['dim']), keepdims=True) } def test_check_output(self): @@ -126,5 +133,67 @@ class TestReduceAll(OpTest): self.check_grad(['X'], 'Out') +## reduction in multi dims +class TestReduceMeanOpMultiAxises(OpTest): + def setUp(self): + self.op_type = "reduce_mean" + self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float64")} + self.attrs = {'dim': [1, 2]} + self.outputs = {'Out': self.inputs['X'].mean(axis=(1, 2))} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestReduceMaxOpMultiAxises(OpTest): + """Remove Max with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_max" + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} + self.attrs = {'dim': [-2, -1]} + self.outputs = { + 'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim'])) + } + + def test_check_output(self): + self.check_output() + + +class TestReduceMinOpMultiAxises(OpTest): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_min" + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} + self.attrs = {'dim': [1, 2]} + self.outputs = { + 'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim'])) + } + + def test_check_output(self): + self.check_output() + + +class TestKeepDimReduceSumMultiAxises(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} + self.attrs = {'dim': [-2, -1], 'keep_dim': True} + self.outputs = { + 'Out': + self.inputs['X'].sum(axis=tuple(self.attrs['dim']), keepdims=True) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + if __name__ == '__main__': unittest.main() -- GitLab