Unverified commit 7ab48aec, authored by Guo Sheng, committed by GitHub

Merge pull request #6586 from guoshengCS/enhance-ReduceOp

Enhance ReduceOp to support reducing over all elements
@@ -37,18 +37,23 @@ class ReduceOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_LT(
         dim, x_rank,
         "The dim should be in the range [-rank(input), rank(input)).");
-    bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
-    auto dims_vector = vectorize(x_dims);
-    if (keep_dim || x_rank == 1) {
-      dims_vector[dim] = 1;
-    } else {
-      dims_vector.erase(dims_vector.begin() + dim);
-    }
-    auto out_dims = framework::make_ddim(dims_vector);
-    ctx->SetOutputDim("Out", out_dims);
-    if (dim != 0) {
-      // Only pass LoD when not reducing on the first dim.
-      ctx->ShareLoD("X", /*->*/ "Out");
+    bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
+    if (reduce_all) {
+      ctx->SetOutputDim("Out", {1});
+    } else {
+      bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
+      auto dims_vector = vectorize(x_dims);
+      if (keep_dim || x_rank == 1) {
+        dims_vector[dim] = 1;
+      } else {
+        dims_vector.erase(dims_vector.begin() + dim);
+      }
+      auto out_dims = framework::make_ddim(dims_vector);
+      ctx->SetOutputDim("Out", out_dims);
+      if (dim != 0) {
+        // Only pass LoD when not reducing on the first dim.
+        ctx->ShareLoD("X", /*->*/ "Out");
+      }
     }
   }
 };
@@ -95,11 +100,16 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(bool, default false) "
                   "If true, retain the reduced dimension with length 1.")
         .SetDefault(false);
+    AddAttr<bool>("reduce_all",
+                  "(bool, default false) "
+                  "If true, output a scalar reduced along all dimensions.")
+        .SetDefault(false);
     comment_ = R"DOC(
 {ReduceOp} Operator.
 
 This operator computes the {reduce} of input tensor along the given dimension.
 The result tensor has 1 fewer dimension than the input unless keep_dim is true.
+If reduce_all is true, just reduce along all dimensions and output a scalar.
 
 )DOC";
     AddComment(comment_);
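To make the attribute's effect on shapes concrete, here is a small self-contained sketch of the output-shape rule that the InferShape change above implements (the helper InferShapeOutDims and its name are illustrative only, not part of this patch): with reduce_all the output collapses to the single-element shape {1}; otherwise the reduced dimension is kept with length 1 when keep_dim is set, or dropped.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper illustrating the shape rule only; not code from the patch.
std::vector<int64_t> InferShapeOutDims(std::vector<int64_t> dims, int dim,
                                       bool keep_dim, bool reduce_all) {
  if (reduce_all) return {1};  // scalar result, stored as a 1-element tensor
  if (dim < 0) dim += static_cast<int>(dims.size());
  assert(dim >= 0 && dim < static_cast<int>(dims.size()));
  if (keep_dim || dims.size() == 1) {
    dims[dim] = 1;  // keep the reduced dimension with length 1
  } else {
    dims.erase(dims.begin() + dim);  // drop the reduced dimension
  }
  return dims;
}

int main() {
  auto a = InferShapeOutDims({5, 6, 10}, 1, false, false);  // {5, 10}
  auto b = InferShapeOutDims({5, 6, 10}, 1, true, false);   // {5, 1, 10}
  auto c = InferShapeOutDims({5, 6, 10}, 0, false, true);   // {1}
  for (const auto& v : {a, b, c}) {
    for (auto d : v) std::cout << d << ' ';
    std::cout << '\n';
  }
  return 0;
}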
...
@@ -26,10 +26,12 @@ using DDim = framework::DDim;
 template <typename T, size_t D, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
 
 struct SumFunctor {
   template <typename DeviceContext, typename X, typename Y, typename Dim>
@@ -95,26 +97,41 @@ template <typename DeviceContext, typename T, typename Functor>
 class ReduceKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    int rank = context.Input<Tensor>("X")->dims().size();
-    switch (rank) {
-      case 1:
-        ReduceCompute<1>(context);
-        break;
-      case 2:
-        ReduceCompute<2>(context);
-        break;
-      case 3:
-        ReduceCompute<3>(context);
-        break;
-      case 4:
-        ReduceCompute<4>(context);
-        break;
-      case 5:
-        ReduceCompute<5>(context);
-        break;
-      case 6:
-        ReduceCompute<6>(context);
-        break;
-    }
+    bool reduce_all = context.Attr<bool>("reduce_all");
+    if (reduce_all) {
+      // Flatten and reduce 1-D tensor
+      auto* input = context.Input<Tensor>("X");
+      auto* output = context.Output<Tensor>("Out");
+      output->mutable_data<T>(context.GetPlace());
+      auto x = EigenVector<T>::Flatten(*input);
+      auto out = EigenScalar<T>::From(*output);
+      auto& place =
+          *context.template device_context<DeviceContext>().eigen_device();
+      auto reduce_dim = Eigen::array<int, 1>({{0}});
+      Functor functor;
+      functor(place, x, out, reduce_dim);
+    } else {
+      int rank = context.Input<Tensor>("X")->dims().size();
+      switch (rank) {
+        case 1:
+          ReduceCompute<1>(context);
+          break;
+        case 2:
+          ReduceCompute<2>(context);
+          break;
+        case 3:
+          ReduceCompute<3>(context);
+          break;
+        case 4:
+          ReduceCompute<4>(context);
+          break;
+        case 5:
+          ReduceCompute<5>(context);
+          break;
+        case 6:
+          ReduceCompute<6>(context);
+          break;
+      }
+    }
   }
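The reduce_all forward branch above views the input as a flat 1-D Eigen vector and reduces it along dimension 0 into a rank-0 (scalar) Eigen tensor. A minimal standalone sketch of the same idea, assuming only Eigen's unsupported Tensor module rather than the framework's EigenVector/EigenScalar wrappers:

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  // Stand-in for an arbitrary-rank input tensor.
  Eigen::Tensor<float, 2> x(2, 3);
  x.setValues({{1, 2, 3}, {4, 5, 6}});

  // "Flatten and reduce 1-D tensor": reshape to 1-D, then reduce dimension 0,
  // mirroring what the reduce_all branch does with EigenVector<T>::Flatten
  // and EigenScalar<T>::From.
  Eigen::array<Eigen::Index, 1> flat_shape{{x.size()}};
  Eigen::Tensor<float, 1> flat = x.reshape(flat_shape);
  Eigen::array<int, 1> reduce_dim{{0}};
  Eigen::Tensor<float, 0> out = flat.sum(reduce_dim);

  std::cout << out() << "\n";  // 21
  return 0;
}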
@@ -157,26 +174,46 @@ template <typename DeviceContext, typename T, typename Functor>
 class ReduceGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    int rank = context.Input<Tensor>("X")->dims().size();
-    switch (rank) {
-      case 1:
-        ReduceGradCompute<1>(context);
-        break;
-      case 2:
-        ReduceGradCompute<2>(context);
-        break;
-      case 3:
-        ReduceGradCompute<3>(context);
-        break;
-      case 4:
-        ReduceGradCompute<4>(context);
-        break;
-      case 5:
-        ReduceGradCompute<5>(context);
-        break;
-      case 6:
-        ReduceGradCompute<6>(context);
-        break;
-    }
+    bool reduce_all = context.Attr<bool>("reduce_all");
+    if (reduce_all) {
+      auto* input0 = context.Input<Tensor>("X");
+      auto* input1 = context.Input<Tensor>("Out");
+      auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
+      auto* output = context.Output<Tensor>(framework::GradVarName("X"));
+      output->mutable_data<T>(context.GetPlace());
+      auto x = EigenVector<T>::Flatten(*input0);
+      auto x_reduce = EigenVector<T>::From(*input1);
+      auto x_reduce_grad = EigenVector<T>::From(*input2);
+      auto x_grad = EigenVector<T>::Flatten(*output);
+      auto& place =
+          *context.template device_context<DeviceContext>().eigen_device();
+      auto broadcast_dim =
+          Eigen::array<int, 1>({{static_cast<int>(input0->numel())}});
+      Functor functor;
+      functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim,
+              broadcast_dim[0]);
+    } else {
+      int rank = context.Input<Tensor>("X")->dims().size();
+      switch (rank) {
+        case 1:
+          ReduceGradCompute<1>(context);
+          break;
+        case 2:
+          ReduceGradCompute<2>(context);
+          break;
+        case 3:
+          ReduceGradCompute<3>(context);
+          break;
+        case 4:
+          ReduceGradCompute<4>(context);
+          break;
+        case 5:
+          ReduceGradCompute<5>(context);
+          break;
+        case 6:
+          ReduceGradCompute<6>(context);
+          break;
+      }
+    }
   }
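In the gradient path, the reduce_all branch broadcasts the one-element output gradient across input0->numel() elements, which for the sum functor is exactly the gradient of a full sum. A minimal Eigen-only sketch under that assumption (the mean and max/min functors apply additional scaling or masking, which is not shown here):

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  const int numel = 6;  // plays the role of input0->numel()

  // Upstream gradient of the scalar output, viewed as a 1-element vector,
  // mirroring EigenVector<T>::From on the Out gradient.
  Eigen::Tensor<float, 1> dout(1);
  dout.setConstant(2.0f);

  // For a sum over all elements, d(out)/d(x_i) == 1, so the input gradient is
  // the output gradient broadcast to all numel elements.
  Eigen::array<int, 1> broadcast_dim{{numel}};
  Eigen::Tensor<float, 1> dx = dout.broadcast(broadcast_dim);

  for (int i = 0; i < numel; ++i) std::cout << dx(i) << ' ';  // 2 2 2 2 2 2
  std::cout << "\n";
  return 0;
}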
...
@@ -85,5 +85,19 @@ class Test1DReduce(OpTest):
         self.check_grad(['X'], 'Out')
 
 
+class TestReduceAll(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")}
+        self.attrs = {'reduce_all': True}
+        self.outputs = {'Out': self.inputs['X'].sum()}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
 if __name__ == '__main__':
     unittest.main()