未验证 提交 8655904b 编写于 作者: W whs 提交者: GitHub

Enhance reduce op (#10708)

* Enhance reduce op for multi dims.

* Uncomment some unitest.

* Uncomment unitest.

* Remove unused code.

* Fix infershape and python wrapper.

* Add more examples.

* Fix l2_normalize.

* Fix normalization_wrapper.

* Polish code.
1. Rename unitest function.
2. Rename const variable.
上级 051a4b39
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/reduce_op.h"
#include <algorithm>
#include <string>
#include <vector>
......@@ -34,11 +35,14 @@ class ReduceOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
int dim = ctx->Attrs().Get<int>("dim");
if (dim < 0) dim = x_rank + dim;
PADDLE_ENFORCE_LT(
dim, x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) dims[i] = x_rank + dims[i];
PADDLE_ENFORCE_LT(
dims[i], x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
}
sort(dims.begin(), dims.end());
bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
if (reduce_all) {
......@@ -49,14 +53,22 @@ class ReduceOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", {1});
} else {
auto dims_vector = vectorize(x_dims);
if (keep_dim || x_rank == 1) {
dims_vector[dim] = 1;
if (keep_dim) {
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = 1;
}
} else {
dims_vector.erase(dims_vector.begin() + dim);
const int kDelFlag = -2;
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = kDelFlag;
}
dims_vector.erase(
remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
}
auto out_dims = framework::make_ddim(dims_vector);
ctx->SetOutputDim("Out", out_dims);
if (dim != 0) {
if (dims[0] != 0) {
// Only pass LoD when not reducing on the first dim.
ctx->ShareLoD("X", /*->*/ "Out");
}
......@@ -75,11 +87,14 @@ class ReduceGradOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
int dim = ctx->Attrs().Get<int>("dim");
if (dim < 0) dim = x_rank + dim;
PADDLE_ENFORCE_LT(
dim, x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) dims[i] = x_rank + dims[i];
PADDLE_ENFORCE_LT(
dims[i], x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
}
sort(dims.begin(), dims.end());
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
......@@ -95,13 +110,13 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor) The input tensor. Tensors with rank at most 6 are "
"supported.");
AddOutput("Out", "(Tensor) The result tensor.");
AddAttr<int>(
AddAttr<std::vector<int>>(
"dim",
"(int, default 0) The dimension to reduce. "
"(list<int>, default {0}) The dimensions to reduce. "
"Must be in the range [-rank(input), rank(input)). "
"If `dim < 0`, the dim to reduce is `rank + dim`. "
"If `dim[i] < 0`, the dims[i] to reduce is `rank + dims[i]`. "
"Note that reducing on the first dim will make the LoD info lost.")
.SetDefault(0);
.SetDefault({0});
AddAttr<bool>("keep_dim",
"(bool, default false) "
"If true, retain the reduced dimension with length 1.")
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -109,6 +110,11 @@ struct ProdGradFunctor {
}
};
#define HANDLE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
ReduceCompute<NDIM, RDIM>(context); \
}
template <typename DeviceContext, typename T, typename Functor>
class ReduceKernel : public framework::OpKernel<T> {
public:
......@@ -127,32 +133,29 @@ class ReduceKernel : public framework::OpKernel<T> {
Functor functor;
functor(place, &x, &out, reduce_dim);
} else {
int rank = context.Input<Tensor>("X")->dims().size();
switch (rank) {
case 1:
ReduceCompute<1>(context);
break;
case 2:
ReduceCompute<2>(context);
break;
case 3:
ReduceCompute<3>(context);
break;
case 4:
ReduceCompute<4>(context);
break;
case 5:
ReduceCompute<5>(context);
break;
case 6:
ReduceCompute<6>(context);
break;
}
int ndim = context.Input<Tensor>("X")->dims().size();
int rdim = context.Attr<std::vector<int>>("dim").size();
HANDLE_DIM(6, 5);
HANDLE_DIM(6, 4);
HANDLE_DIM(6, 3);
HANDLE_DIM(6, 2);
HANDLE_DIM(6, 1);
HANDLE_DIM(5, 4);
HANDLE_DIM(5, 3);
HANDLE_DIM(5, 2);
HANDLE_DIM(5, 1);
HANDLE_DIM(4, 3);
HANDLE_DIM(4, 2);
HANDLE_DIM(4, 1);
HANDLE_DIM(3, 2);
HANDLE_DIM(3, 1);
HANDLE_DIM(2, 1);
HANDLE_DIM(1, 1);
}
}
private:
template <size_t D>
template <size_t D, size_t R_D>
void ReduceCompute(const framework::ExecutionContext& context) const {
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
......@@ -160,18 +163,26 @@ class ReduceKernel : public framework::OpKernel<T> {
auto x = EigenTensor<T, D>::From(*input);
auto x_rank = static_cast<int>(x.dimensions().size());
int dim = static_cast<int>(context.Attr<int>("dim"));
if (dim < 0) dim = x_rank + dim;
auto reduce_dim = Eigen::array<int, 1>({{dim}});
auto dims = context.Attr<std::vector<int>>("dim");
auto reduce_dim = Eigen::array<int, R_D>();
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) dims[i] = x_rank + dims[i];
reduce_dim[i] = dims[i];
}
// construct the squeezed output tensor
bool keep_dim = context.Attr<bool>("keep_dim");
DDim dims = output->dims();
auto dims_vector = vectorize(dims);
DDim out_dims = output->dims();
if (keep_dim && x_rank > 1) {
dims_vector.erase(dims_vector.begin() + dim);
dims = framework::make_ddim(dims_vector);
const int kDelFlag = -2;
auto dims_vector = vectorize(out_dims);
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = kDelFlag;
}
dims_vector.erase(
remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
out_dims = framework::make_ddim(dims_vector);
}
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
Functor functor;
......@@ -180,7 +191,7 @@ class ReduceKernel : public framework::OpKernel<T> {
auto out = EigenScalar<T>::From(*output);
functor(place, &x, &out, reduce_dim);
} else {
auto out = EigenTensor<T, (D - 1)>::From(*output, dims);
auto out = EigenTensor<T, (D - R_D)>::From(*output, out_dims);
functor(place, &x, &out, reduce_dim);
}
}
......@@ -245,21 +256,29 @@ class ReduceGradKernel : public framework::OpKernel<T> {
auto x = EigenTensor<T, D>::From(*input0);
auto x_grad = EigenTensor<T, D>::From(*output);
auto x_rank = static_cast<int>(x.dimensions().size());
int dim = static_cast<int>(context.Attr<int>("dim"));
if (dim < 0) dim = x_rank + dim;
DDim dims = input0->dims();
dims[dim] = 1;
auto x_reduce = EigenTensor<T, D>::From(*input1, dims);
auto x_reduce_grad = EigenTensor<T, D>::From(*input2, dims);
auto dims = context.Attr<std::vector<int>>("dim");
auto x_dims = input0->dims();
auto reduced_dims_v = vectorize(x_dims);
Eigen::array<int, D> broadcast_dim;
for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1;
broadcast_dim[dim] = input0->dims()[dim];
int broad_cats_times = 1;
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) dims[i] = x_rank + dims[i];
reduced_dims_v[dims[i]] = 1;
broadcast_dim[dims[i]] = x_dims[dims[i]];
broad_cats_times *= x_dims[dims[i]];
}
auto reduced_dims = framework::make_ddim(reduced_dims_v);
auto x_reduce = EigenTensor<T, D>::From(*input1, reduced_dims);
auto x_reduce_grad = EigenTensor<T, D>::From(*input2, reduced_dims);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
Functor functor;
functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim,
broadcast_dim[dim]);
broad_cats_times);
}
};
......
......@@ -2082,11 +2082,11 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
Args:
input (Variable): The input variable which is a Tensor or LoDTensor.
dim (int|None): The dimension along which the sum is performed. If
dim (list|int|None): The dimensions along which the sum is performed. If
:attr:`None`, sum all elements of :attr:`input` and return a
Tensor variable with a single element, otherwise must be in the
range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`,
the dimension to reduce is :math:`rank + dim`.
range :math:`[-rank(input), rank(input))`. If :math:`dim[i] < 0`,
the dimension to reduce is :math:`rank + dim[i]`.
keep_dim (bool|False): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true.
......@@ -2107,15 +2107,25 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_sum(x, dim=0) # [0.3, 0.5, 1.1, 1.6]
fluid.layers.reduce_sum(x, dim=-1) # [1.9, 1.6]
fluid.layers.reduce_sum(x, dim=1, keep_dim=True) # [[1.9], [1.6]]
# x is a Tensor variable with shape [2, 2, 2] and elements as below:
# [[[1, 2], [3, 4]],
# [[5, 6], [7, 8]]]
# Each example is followed by the correspending output tensor.
fluid.layers.reduce_sum(x, dim=[1, 2]) # [10, 26]
fluid.layers.reduce_sum(x, dim=[0, 1]) # [16, 20]
"""
helper = LayerHelper('reduce_sum', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype())
if dim is not None and not isinstance(dim, list):
dim = [dim]
helper.append_op(
type='reduce_sum',
inputs={'X': input},
outputs={'Out': out},
attrs={
'dim': dim if dim != None else 0,
'dim': dim if dim != None else [0],
'keep_dim': keep_dim,
'reduce_all': True if dim == None else False
})
......@@ -2128,11 +2138,11 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None):
Args:
input (Variable): The input variable which is a Tensor or LoDTensor.
dim (int|None): The dimension along which the mean is computed. If
dim (list|int|None): The dimensions along which the mean is computed. If
:attr:`None`, compute the mean over all elements of :attr:`input`
and return a Tensor variable with a single element, otherwise
must be in the range :math:`[-rank(input), rank(input))`. If
:math:`dim < 0`, the dimension to reduce is :math:`rank + dim`.
:math:`dim[i] < 0`, the dimension to reduce is :math:`rank + dim[i]`.
keep_dim (bool): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true.
......@@ -2153,15 +2163,24 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_mean(x, dim=0) # [0.15, 0.25, 0.55, 0.8]
fluid.layers.reduce_mean(x, dim=-1) # [0.475, 0.4]
fluid.layers.reduce_mean(x, dim=1, keep_dim=True) # [[0.475], [0.4]]
# x is a Tensor variable with shape [2, 2, 2] and elements as below:
# [[[1.0, 2.0], [3.0, 4.0]],
# [[5.0, 6.0], [7.0, 8.0]]]
# Each example is followed by the correspending output tensor.
fluid.layers.reduce_mean(x, dim=[1, 2]) # [2.5, 6.5]
fluid.layers.reduce_mean(x, dim=[0, 1]) # [4.0, 5.0]
"""
helper = LayerHelper('reduce_mean', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype())
if dim is not None and not isinstance(dim, list):
dim = [dim]
helper.append_op(
type='reduce_mean',
inputs={'X': input},
outputs={'Out': out},
attrs={
'dim': dim if dim != None else 0,
'dim': dim if dim != None else [0],
'keep_dim': keep_dim,
'reduce_all': True if dim == None else False
})
......@@ -2174,11 +2193,11 @@ def reduce_max(input, dim=None, keep_dim=False, name=None):
Args:
input (Variable): The input variable which is a Tensor or LoDTensor.
dim (int|None): The dimension along which the maximum is computed.
dim (list|int|None): The dimension along which the maximum is computed.
If :attr:`None`, compute the maximum over all elements of
:attr:`input` and return a Tensor variable with a single element,
otherwise must be in the range :math:`[-rank(input), rank(input))`.
If :math:`dim < 0`, the dimension to reduce is :math:`rank + dim`.
If :math:`dim[i] < 0`, the dimension to reduce is :math:`rank + dim[i]`.
keep_dim (bool): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true.
......@@ -2199,15 +2218,24 @@ def reduce_max(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_max(x, dim=0) # [0.2, 0.3, 0.6, 0.9]
fluid.layers.reduce_max(x, dim=-1) # [0.9, 0.7]
fluid.layers.reduce_max(x, dim=1, keep_dim=True) # [[0.9], [0.7]]
# x is a Tensor variable with shape [2, 2, 2] and elements as below:
# [[[1.0, 2.0], [3.0, 4.0]],
# [[5.0, 6.0], [7.0, 8.0]]]
# Each example is followed by the correspending output tensor.
fluid.layers.reduce_max(x, dim=[1, 2]) # [4.0, 8.0]
fluid.layers.reduce_max(x, dim=[0, 1]) # [7.0, 8.0]
"""
helper = LayerHelper('reduce_max', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype())
if dim is not None and not isinstance(dim, list):
dim = [dim]
helper.append_op(
type='reduce_max',
inputs={'X': input},
outputs={'Out': out},
attrs={
'dim': dim if dim != None else 0,
'dim': dim if dim != None else [0],
'keep_dim': keep_dim,
'reduce_all': True if dim == None else False
})
......@@ -2220,11 +2248,11 @@ def reduce_min(input, dim=None, keep_dim=False, name=None):
Args:
input (Variable): The input variable which is a Tensor or LoDTensor.
dim (int|None): The dimension along which the minimum is computed.
dim (list|int|None): The dimensions along which the minimum is computed.
If :attr:`None`, compute the minimum over all elements of
:attr:`input` and return a Tensor variable with a single element,
otherwise must be in the range :math:`[-rank(input), rank(input))`.
If :math:`dim < 0`, the dimension to reduce is :math:`rank + dim`.
If :math:`dim[i] < 0`, the dimension to reduce is :math:`rank + dim[i]`.
keep_dim (bool): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true.
......@@ -2245,15 +2273,24 @@ def reduce_min(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_min(x, dim=0) # [0.1, 0.2, 0.5, 0.7]
fluid.layers.reduce_min(x, dim=-1) # [0.2, 0.1]
fluid.layers.reduce_min(x, dim=1, keep_dim=True) # [[0.2], [0.1]]
# x is a Tensor variable with shape [2, 2, 2] and elements as below:
# [[[1.0, 2.0], [3.0, 4.0]],
# [[5.0, 6.0], [7.0, 8.0]]]
# Each example is followed by the correspending output tensor.
fluid.layers.reduce_min(x, dim=[1, 2]) # [1.0, 5.0]
fluid.layers.reduce_min(x, dim=[0, 1]) # [1.0, 2.0]
"""
helper = LayerHelper('reduce_min', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype())
if dim is not None and not isinstance(dim, list):
dim = [dim]
helper.append_op(
type='reduce_min',
inputs={'X': input},
outputs={'Out': out},
attrs={
'dim': dim if dim != None else 0,
'dim': dim if dim != None else [0],
'keep_dim': keep_dim,
'reduce_all': True if dim == None else False
})
......@@ -2266,11 +2303,11 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None):
Args:
input (Variable): The input variable which is a Tensor or LoDTensor.
dim (int|None): The dimension along which the product is performed. If
dim (list|int|None): The dimensions along which the product is performed. If
:attr:`None`, multipy all elements of :attr:`input` and return a
Tensor variable with a single element, otherwise must be in the
range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`,
the dimension to reduce is :math:`rank + dim`.
range :math:`[-rank(input), rank(input))`. If :math:`dim[i] < 0`,
the dimension to reduce is :math:`rank + dim[i]`.
keep_dim (bool|False): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true.
......@@ -2292,15 +2329,24 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_prod(x, dim=-1) # [0.027, 0.0084]
fluid.layers.reduce_prod(x, dim=1,
keep_dim=True) # [[0.027], [0.0084]]
# x is a Tensor variable with shape [2, 2, 2] and elements as below:
# [[[1.0, 2.0], [3.0, 4.0]],
# [[5.0, 6.0], [7.0, 8.0]]]
# Each example is followed by the correspending output tensor.
fluid.layers.reduce_prod(x, dim=[1, 2]) # [24.0, 1680.0]
fluid.layers.reduce_prod(x, dim=[0, 1]) # [105.0, 384.0]
"""
helper = LayerHelper('reduce_prod', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype())
if dim is not None and not isinstance(dim, list):
dim = [dim]
helper.append_op(
type='reduce_prod',
inputs={'X': input},
outputs={'Out': out},
attrs={
'dim': dim if dim != None else 0,
'dim': dim if dim != None else [0],
'keep_dim': keep_dim,
'reduce_all': True if dim == None else False
})
......@@ -2403,7 +2449,6 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
if len(x.shape) == 1:
axis = 0
helper = LayerHelper("l2_normalize", **locals())
square = helper.create_tmp_variable(dtype=x.dtype)
......@@ -2415,7 +2460,7 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
inputs={"X": square},
outputs={"Out": reduced_sum},
attrs={
"dim": 1 if axis is None else axis,
"dim": [1] if axis is None else [axis],
"keep_dim": True,
"reduce_all": False
})
......
......@@ -34,8 +34,10 @@ class TestMeanOp(OpTest):
def setUp(self):
self.op_type = "reduce_mean"
self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float64")}
self.attrs = {'dim': 1}
self.outputs = {'Out': self.inputs['X'].mean(axis=self.attrs['dim'])}
self.attrs = {'dim': [1]}
self.outputs = {
'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim']))
}
def test_check_output(self):
self.check_output()
......@@ -50,8 +52,10 @@ class TestMaxOp(OpTest):
def setUp(self):
self.op_type = "reduce_max"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.attrs = {'dim': -1}
self.outputs = {'Out': self.inputs['X'].max(axis=self.attrs['dim'])}
self.attrs = {'dim': [-1]}
self.outputs = {
'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim']))
}
def test_check_output(self):
self.check_output()
......@@ -63,8 +67,10 @@ class TestMinOp(OpTest):
def setUp(self):
self.op_type = "reduce_min"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.attrs = {'dim': 2}
self.outputs = {'Out': self.inputs['X'].min(axis=self.attrs['dim'])}
self.attrs = {'dim': [2]}
self.outputs = {
'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim']))
}
def test_check_output(self):
self.check_output()
......@@ -87,9 +93,10 @@ class TestKeepDimReduce(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.attrs = {'dim': -2, 'keep_dim': True}
self.attrs = {'dim': [-2], 'keep_dim': True}
self.outputs = {
'Out': self.inputs['X'].sum(axis=self.attrs['dim'], keepdims=True)
'Out':
self.inputs['X'].sum(axis=tuple(self.attrs['dim']), keepdims=True)
}
def test_check_output(self):
......@@ -126,5 +133,67 @@ class TestReduceAll(OpTest):
self.check_grad(['X'], 'Out')
## reduction in multi dims
class TestReduceMeanOpMultiAxises(OpTest):
def setUp(self):
self.op_type = "reduce_mean"
self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float64")}
self.attrs = {'dim': [1, 2]}
self.outputs = {'Out': self.inputs['X'].mean(axis=(1, 2))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestReduceMaxOpMultiAxises(OpTest):
"""Remove Max with subgradient from gradient check to confirm the success of CI."""
def setUp(self):
self.op_type = "reduce_max"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.attrs = {'dim': [-2, -1]}
self.outputs = {
'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim']))
}
def test_check_output(self):
self.check_output()
class TestReduceMinOpMultiAxises(OpTest):
"""Remove Min with subgradient from gradient check to confirm the success of CI."""
def setUp(self):
self.op_type = "reduce_min"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.attrs = {'dim': [1, 2]}
self.outputs = {
'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim']))
}
def test_check_output(self):
self.check_output()
class TestKeepDimReduceSumMultiAxises(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.attrs = {'dim': [-2, -1], 'keep_dim': True}
self.outputs = {
'Out':
self.inputs['X'].sum(axis=tuple(self.attrs['dim']), keepdims=True)
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册