提交 842b485f 编写于 作者: G guosheng

Enhance ReduceOp to support reducing over all elements

上级 0a8addf8
......@@ -37,18 +37,23 @@ class ReduceOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LT(
dim, x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
auto dims_vector = vectorize(x_dims);
if (keep_dim || x_rank == 1) {
dims_vector[dim] = 1;
bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
if (reduce_all) {
ctx->SetOutputDim("Out", {1});
} else {
dims_vector.erase(dims_vector.begin() + dim);
}
auto out_dims = framework::make_ddim(dims_vector);
ctx->SetOutputDim("Out", out_dims);
if (dim != 0) {
// Only pass LoD when not reducing on the first dim.
ctx->ShareLoD("X", /*->*/ "Out");
bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
auto dims_vector = vectorize(x_dims);
if (keep_dim || x_rank == 1) {
dims_vector[dim] = 1;
} else {
dims_vector.erase(dims_vector.begin() + dim);
}
auto out_dims = framework::make_ddim(dims_vector);
ctx->SetOutputDim("Out", out_dims);
if (dim != 0) {
// Only pass LoD when not reducing on the first dim.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
}
};
......@@ -95,11 +100,16 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) "
"If true, retain the reduced dimension with length 1.")
.SetDefault(false);
AddAttr<bool>("reduce_all",
"(bool, default false) "
"If true, output a scalar reduced along all dimensions.")
.SetDefault(false);
comment_ = R"DOC(
{ReduceOp} Operator.
This operator computes the {reduce} of input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless keep_dim is true.
If reduce_all is true, just reduce along all dimensions and output a scalar.
)DOC";
AddComment(comment_);
......
......@@ -26,10 +26,12 @@ using DDim = framework::DDim;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
struct SumFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
......@@ -95,26 +97,41 @@ template <typename DeviceContext, typename T, typename Functor>
class ReduceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size();
switch (rank) {
case 1:
ReduceCompute<1>(context);
break;
case 2:
ReduceCompute<2>(context);
break;
case 3:
ReduceCompute<3>(context);
break;
case 4:
ReduceCompute<4>(context);
break;
case 5:
ReduceCompute<5>(context);
break;
case 6:
ReduceCompute<6>(context);
break;
bool reduce_all = context.Attr<bool>("reduce_all");
if (reduce_all) {
// Flatten and reduce 1-D tensor
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto x = EigenVector<T>::Flatten(*input);
auto out = EigenScalar<T>::From(*output);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto reduce_dim = Eigen::array<int, 1>({{0}});
Functor functor;
functor(place, x, out, reduce_dim);
} else {
int rank = context.Input<Tensor>("X")->dims().size();
switch (rank) {
case 1:
ReduceCompute<1>(context);
break;
case 2:
ReduceCompute<2>(context);
break;
case 3:
ReduceCompute<3>(context);
break;
case 4:
ReduceCompute<4>(context);
break;
case 5:
ReduceCompute<5>(context);
break;
case 6:
ReduceCompute<6>(context);
break;
}
}
}
......@@ -157,26 +174,46 @@ template <typename DeviceContext, typename T, typename Functor>
class ReduceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size();
switch (rank) {
case 1:
ReduceGradCompute<1>(context);
break;
case 2:
ReduceGradCompute<2>(context);
break;
case 3:
ReduceGradCompute<3>(context);
break;
case 4:
ReduceGradCompute<4>(context);
break;
case 5:
ReduceGradCompute<5>(context);
break;
case 6:
ReduceGradCompute<6>(context);
break;
bool reduce_all = context.Attr<bool>("reduce_all");
if (reduce_all) {
auto* input0 = context.Input<Tensor>("X");
auto* input1 = context.Input<Tensor>("Out");
auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* output = context.Output<Tensor>(framework::GradVarName("X"));
output->mutable_data<T>(context.GetPlace());
auto x = EigenVector<T>::Flatten(*input0);
auto x_reduce = EigenVector<T>::From(*input1);
auto x_reduce_grad = EigenVector<T>::From(*input2);
auto x_grad = EigenVector<T>::Flatten(*output);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto broadcast_dim =
Eigen::array<int, 1>({{static_cast<int>(input0->numel())}});
Functor functor;
functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim,
broadcast_dim[0]);
} else {
int rank = context.Input<Tensor>("X")->dims().size();
switch (rank) {
case 1:
ReduceGradCompute<1>(context);
break;
case 2:
ReduceGradCompute<2>(context);
break;
case 3:
ReduceGradCompute<3>(context);
break;
case 4:
ReduceGradCompute<4>(context);
break;
case 5:
ReduceGradCompute<5>(context);
break;
case 6:
ReduceGradCompute<6>(context);
break;
}
}
}
......
......@@ -85,5 +85,19 @@ class Test1DReduce(OpTest):
self.check_grad(['X'], 'Out')
class TestReduceAll(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")}
self.attrs = {'reduce_all': True}
self.outputs = {'Out': self.inputs['X'].sum()}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册