/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/reduce_op.h"

#include <string>
#include <vector>

G
guosheng 已提交
20 21 22 23 24 25 26 27 28
namespace paddle {
namespace operators {

using framework::Tensor;

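// ReduceOp infers the output shape from the "dim", "keep_dim" and
// "reduce_all" attributes. For example, reducing an input of shape
// [2, 3, 4] along dim = 1 gives an output of shape [2, 4], or [2, 1, 4]
// when keep_dim is true; with reduce_all = true the output is a single
// scalar of shape [1].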
class ReduceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of ReduceOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of ReduceOp should not be null.");
    auto x_dims = ctx->GetInputDim("X");
    auto x_rank = x_dims.size();
    PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
    int dim = ctx->Attrs().Get<int>("dim");
    if (dim < 0) dim = x_rank + dim;
    PADDLE_ENFORCE_LT(
        dim, x_rank,
        "The dim should be in the range [-rank(input), rank(input)).");
    bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
    bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
    if (reduce_all) {
      if (keep_dim)
        ctx->SetOutputDim(
            "Out", framework::make_ddim(std::vector<int64_t>(x_rank, 1)));
      else
        ctx->SetOutputDim("Out", {1});
    } else {
      auto dims_vector = vectorize(x_dims);
      if (keep_dim || x_rank == 1) {
        dims_vector[dim] = 1;
      } else {
        dims_vector.erase(dims_vector.begin() + dim);
      }
      auto out_dims = framework::make_ddim(dims_vector);
      ctx->SetOutputDim("Out", out_dims);
      if (dim != 0) {
        // Only pass LoD when not reducing on the first dim.
        ctx->ShareLoD("X", /*->*/ "Out");
      }
    }
  }
};

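// ReduceGradOp restores the pre-reduction shape: X@GRAD always has the
// same dims (and LoD) as the forward input X, regardless of "dim",
// "keep_dim" or "reduce_all".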
class ReduceGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null.");
    auto x_dims = ctx->GetInputDim("X");
    auto x_rank = x_dims.size();
    PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
    int dim = ctx->Attrs().Get<int>("dim");
    if (dim < 0) dim = x_rank + dim;
    PADDLE_ENFORCE_LT(
        dim, x_rank,
        "The dim should be in the range [-rank(input), rank(input)).");
    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
      ctx->ShareLoD("X", /*->*/ x_grad_name);
    }
  }
};

class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(Tensor) The input tensor. Tensors with rank at most 6 are "
             "supported.");
    AddOutput("Out", "(Tensor) The result tensor.");
    AddAttr<int>(
        "dim",
        "(int, default 0) The dimension to reduce. "
        "Must be in the range [-rank(input), rank(input)). "
        "If `dim < 0`, the dim to reduce is `rank + dim`. "
        "Note that reducing along the first dim will cause the LoD info "
        "to be lost.")
        .SetDefault(0);
    AddAttr<bool>("keep_dim",
                  "(bool, default false) "
                  "If true, retain the reduced dimension with length 1.")
        .SetDefault(false);
    AddAttr<bool>("reduce_all",
                  "(bool, default false) "
                  "If true, output a scalar reduced along all dimensions.")
        .SetDefault(false);
    comment_ = R"DOC(
{ReduceOp} Operator.

This operator computes the {reduce} of the input tensor along the given
dimension. The result tensor has one fewer dimension than the input unless
keep_dim is true. If reduce_all is true, the operator reduces along all
dimensions and outputs a scalar.

)DOC";
    AddComment(comment_);
  }

 protected:
  std::string comment_;

  void Replace(std::string *src, std::string from, std::string to) {
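    // Replace every occurrence of `from` in *src with `to`. The search
    // resumes after the inserted text, so the loop terminates even when
    // `to` contains `from`.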
    std::size_t len_from = from.size();
    std::size_t len_to = to.size();
    for (std::size_t pos = src->find(from); pos != std::string::npos;
         pos = src->find(from, pos + len_to)) {
      src->replace(pos, len_from, to);
    }
  }

  void SetComment(std::string name, std::string op) {
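    // e.g. SetComment("ReduceSum", "sum") turns "{ReduceOp} Operator."
    // into "ReduceSum Operator." and "the {reduce} of" into "the sum of".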
    Replace(&comment_, "{ReduceOp}", name);
    Replace(&comment_, "{reduce}", op);
  }
};

class ReduceSumOpMaker : public ReduceOpMaker {
 public:
  ReduceSumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : ReduceOpMaker(proto, op_checker) {
    SetComment("ReduceSum", "sum");
    AddComment(comment_);
  }
};

class ReduceMeanOpMaker : public ReduceOpMaker {
 public:
  ReduceMeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : ReduceOpMaker(proto, op_checker) {
    SetComment("ReduceMean", "mean");
    AddComment(comment_);
  }
};

class ReduceMaxOpMaker : public ReduceOpMaker {
 public:
  ReduceMaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : ReduceOpMaker(proto, op_checker) {
    SetComment("ReduceMax", "max");
    AddComment(comment_);
  }
};

class ReduceMinOpMaker : public ReduceOpMaker {
 public:
  ReduceMinOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : ReduceOpMaker(proto, op_checker) {
    SetComment("ReduceMin", "min");
    AddComment(comment_);
  }
};

class ReduceProdOpMaker : public ReduceOpMaker {
 public:
  ReduceProdOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : ReduceOpMaker(proto, op_checker) {
    SetComment("ReduceProd", "product");
    AddComment(comment_);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

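// Each forward op shares ReduceOp for shape inference and is paired with a
// default grad-op descriptor maker; every *_grad op shares ReduceGradOp.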
REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ops::ReduceSumOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>)
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp)

REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, ops::ReduceMeanOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>)
REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp)

REGISTER_OPERATOR(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>)
REGISTER_OPERATOR(reduce_max_grad, ops::ReduceGradOp)

REGISTER_OPERATOR(reduce_min, ops::ReduceOp, ops::ReduceMinOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>)
REGISTER_OPERATOR(reduce_min_grad, ops::ReduceGradOp)

REGISTER_OPERATOR(reduce_prod, ops::ReduceOp, ops::ReduceProdOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>)
REGISTER_OPERATOR(reduce_prod_grad, ops::ReduceGradOp)

#define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor)         \
  REGISTER_OP_CPU_KERNEL(reduce_type,                                          \
                         ops::ReduceKernel<paddle::platform::CPUDeviceContext, \
                                           float, ops::functor>,               \
                         ops::ReduceKernel<paddle::platform::CPUDeviceContext, \
                                           double, ops::functor>,              \
                         ops::ReduceKernel<paddle::platform::CPUDeviceContext, \
                                           int, ops::functor>,                 \
                         ops::ReduceKernel<paddle::platform::CPUDeviceContext, \
                                           int64_t, ops::functor>);            \
  REGISTER_OP_CPU_KERNEL(                                                      \
      reduce_type##_grad,                                                      \
      ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, float,         \
                            ops::grad_functor>,                                \
      ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, double,        \
                            ops::grad_functor>,                                \
      ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, int,           \
                            ops::grad_functor>,                                \
      ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, int64_t,       \
                            ops::grad_functor>);

FOR_EACH_KERNEL_FUNCTOR(REGISTER_REDUCE_CPU_KERNEL);
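// FOR_EACH_KERNEL_FUNCTOR is defined in reduce_op.h. Presumably it invokes
// the macro above once per (reduce_type, functor, grad_functor) triple,
// e.g. something like REGISTER_REDUCE_CPU_KERNEL(reduce_sum, SumFunctor,
// SumGradFunctor), so each reduce op and its gradient get CPU kernels for
// float, double, int and int64_t.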