Commit 842b485f authored by guosheng

Enhance ReduceOp to support reducing over all elements

Parent 0a8addf8
@@ -37,6 +37,10 @@ class ReduceOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_LT(
         dim, x_rank,
         "The dim should be in the range [-rank(input), rank(input)).");
+    bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
+    if (reduce_all) {
+      ctx->SetOutputDim("Out", {1});
+    } else {
     bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
     auto dims_vector = vectorize(x_dims);
     if (keep_dim || x_rank == 1) {
@@ -51,6 +55,7 @@ class ReduceOp : public framework::OperatorWithKernel {
       ctx->ShareLoD("X", /*->*/ "Out");
     }
   }
+  }
 };

 class ReduceGradOp : public framework::OperatorWithKernel {
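
A side note on the InferShape change above: once reduce_all is set, the output collapses to a one-element tensor no matter what dim and keep_dim say. A toy re-implementation of that shape rule for illustration (my own simplification, not framework code):

#include <cstdio>
#include <vector>

// reduce_all short-circuits the per-dimension logic and always yields {1}.
std::vector<int> InferReduceShape(std::vector<int> x_dims, int dim,
                                  bool keep_dim, bool reduce_all) {
  if (reduce_all) return {1};
  if (keep_dim || x_dims.size() == 1) {
    x_dims[dim] = 1;  // keep the reduced dimension with length 1
  } else {
    x_dims.erase(x_dims.begin() + dim);  // drop the reduced dimension
  }
  return x_dims;
}

int main() {
  auto out = InferReduceShape({5, 6, 2, 10}, /*dim=*/1,
                              /*keep_dim=*/false, /*reduce_all=*/true);
  std::printf("rank=%zu out[0]=%d\n", out.size(), out[0]);  // rank=1 out[0]=1
}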
@@ -95,11 +100,16 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
              "(bool, default false) "
              "If true, retain the reduced dimension with length 1.")
         .SetDefault(false);
+    AddAttr<bool>("reduce_all",
+                  "(bool, default false) "
+                  "If true, output a scalar reduced along all dimensions.")
+        .SetDefault(false);
     comment_ = R"DOC(
 {ReduceOp} Operator.

 This operator computes the {reduce} of input tensor along the given dimension.
 The result tensor has 1 fewer dimension than the input unless keep_dim is true.
+If reduce_all is true, just reduce along all dimensions and output a scalar.

 )DOC";
     AddComment(comment_);
......
@@ -26,10 +26,12 @@ using DDim = framework::DDim;
 template <typename T, size_t D, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

 struct SumFunctor {
   template <typename DeviceContext, typename X, typename Y, typename Dim>
@@ -95,6 +97,20 @@ template <typename DeviceContext, typename T, typename Functor>
 class ReduceKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
+    bool reduce_all = context.Attr<bool>("reduce_all");
+    if (reduce_all) {
+      // Flatten and reduce 1-D tensor
+      auto* input = context.Input<Tensor>("X");
+      auto* output = context.Output<Tensor>("Out");
+      output->mutable_data<T>(context.GetPlace());
+      auto x = EigenVector<T>::Flatten(*input);
+      auto out = EigenScalar<T>::From(*output);
+      auto& place =
+          *context.template device_context<DeviceContext>().eigen_device();
+      auto reduce_dim = Eigen::array<int, 1>({{0}});
+      Functor functor;
+      functor(place, x, out, reduce_dim);
+    } else {
     int rank = context.Input<Tensor>("X")->dims().size();
     switch (rank) {
       case 1:
@@ -117,6 +133,7 @@ class ReduceKernel : public framework::OpKernel<T> {
         break;
     }
   }
+  }

  private:
  template <size_t D>
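
The reduce_all path in Compute above flattens the input into a rank-1 Eigen tensor (EigenVector<T>::Flatten) and reduces its single dimension into a rank-0 scalar (EigenScalar<T>::From). A self-contained sketch of the same mechanics using raw Eigen tensors rather than the framework wrappers (illustrative only):

#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // A 2x3 input; reduce_all treats it as one flat run of 6 elements.
  Eigen::Tensor<float, 2> x(2, 3);
  x.setValues({{1, 2, 3}, {4, 5, 6}});

  // View the same storage as a rank-1 tensor, as EigenVector<T>::Flatten does.
  Eigen::TensorMap<Eigen::Tensor<float, 1>> flat(x.data(), x.size());

  // Reduce over dimension 0 into a rank-0 (scalar) tensor; this mirrors
  // SumFunctor applied with reduce_dim = {0}.
  Eigen::Tensor<float, 0> out = flat.sum(Eigen::array<int, 1>({{0}}));
  std::cout << out() << std::endl;  // prints 21
}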
@@ -157,6 +174,25 @@ template <typename DeviceContext, typename T, typename Functor>
 class ReduceGradKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
+    bool reduce_all = context.Attr<bool>("reduce_all");
+    if (reduce_all) {
+      auto* input0 = context.Input<Tensor>("X");
+      auto* input1 = context.Input<Tensor>("Out");
+      auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
+      auto* output = context.Output<Tensor>(framework::GradVarName("X"));
+      output->mutable_data<T>(context.GetPlace());
+      auto x = EigenVector<T>::Flatten(*input0);
+      auto x_reduce = EigenVector<T>::From(*input1);
+      auto x_reduce_grad = EigenVector<T>::From(*input2);
+      auto x_grad = EigenVector<T>::Flatten(*output);
+      auto& place =
+          *context.template device_context<DeviceContext>().eigen_device();
+      auto broadcast_dim =
+          Eigen::array<int, 1>({{static_cast<int>(input0->numel())}});
+      Functor functor;
+      functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim,
+              broadcast_dim[0]);
+    } else {
     int rank = context.Input<Tensor>("X")->dims().size();
     switch (rank) {
       case 1:
@@ -179,6 +215,7 @@ class ReduceGradKernel : public framework::OpKernel<T> {
         break;
     }
   }
+  }

  private:
  template <size_t D>
......
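
For the gradient path: with a full reduction, dOut has a single element, and (for sum) the backward rule is dX[i] = dOut for every i, which the kernel expresses as a broadcast over numel. A minimal Eigen sketch of that broadcast, assuming the sum functor (illustrative, not framework code):

#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  const int numel = 6;  // stands in for input0->numel() in the kernel

  // The flattened scalar gradient dOut, as a one-element vector.
  Eigen::Tensor<float, 1> dy(1);
  dy(0) = 2.5f;

  // Broadcasting by numel replicates the scalar gradient to every element:
  // the sum-reduction backward rule dX[i] = dOut.
  Eigen::array<int, 1> broadcast_dim({{numel}});
  Eigen::Tensor<float, 1> dx = dy.broadcast(broadcast_dim);

  for (int i = 0; i < numel; ++i) std::cout << dx(i) << " ";  // 2.5 six times
  std::cout << std::endl;
}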
@@ -85,5 +85,19 @@ class Test1DReduce(OpTest):
         self.check_grad(['X'], 'Out')


+class TestReduceAll(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")}
+        self.attrs = {'reduce_all': True}
+        self.outputs = {'Out': self.inputs['X'].sum()}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
 if __name__ == '__main__':
     unittest.main()