/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sum_op.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/detail/safe_ref.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

27 28 29 30 31 32 33 34
namespace paddle {
namespace operators {
using framework::Tensor;

class SumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

35
  void InferShape(framework::InferShapeContext* ctx) const override {
36 37
    PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
                      "Inputs(X) should not be null");
38

39 40
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      "Output(Out) of SumOp should not be null.");
41 42
    if (ctx->IsRuntime() &&
        ctx->GetOutputsVarType("Out")[0] ==
43
            framework::proto::VarType::LOD_TENSOR_ARRAY) {
44 45
      return;  // skip runtime infershape when is tensor array;
    }
46

47
    auto x_var_types = ctx->GetInputsVarType("X");
48
    auto x_dims = ctx->GetInputsDim("X");
49

50 51 52 53 54 55 56
    auto N = x_dims.size();
    PADDLE_ENFORCE_GT(
        N, 0,
        "ShapeError: The input tensor X's dimensions of SumOp "
        "should be larger than 0. But received X's dimensions %d, "
        "X's shape = [%s].",
        N, &x_dims);
57
    if (N == 1) {
58
      VLOG(3) << "Warning: SumOp have only one input, may waste memory";
59
    }
Q
qiaolongfei 已提交
60

61
    framework::DDim in_dim({0});
62
    for (size_t i = 0; i < x_dims.size(); ++i) {
63 64 65 66
      auto& x_dim = x_dims[i];
      // x_dim.size() == 1 means the real dim of selected rows is [0]
      if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS &&
          x_dim.size() == 1) {
67 68
        continue;
      }
69 70 71 72 73 74
      if (framework::product(x_dim) == 0) {
        continue;
      }
      if (framework::product(in_dim) == 0) {
        in_dim = x_dim;
      } else {
Z
zhaoyuchen 已提交
75
        if (ctx->IsRuntime()) {
76 77 78 79 80
          PADDLE_ENFORCE_EQ(
              in_dim, x_dim,
              "ShapeError: The input tensor X of SumOp must have same shape."
              "But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
              in_dim, i, x_dim);
Z
zhaoyuchen 已提交
81
        } else {
82 83 84 85 86 87
          PADDLE_ENFORCE_EQ(
              in_dim.size(), x_dim.size(),
              "ShapeError: The input tensor X of SumOp must have same "
              "dimensions. But received X[0]'s dimensions = %d, X[0]'s shape = "
              "[%s], X[%d]'s dimensions = %d, X[%d]'s shape = [%s].",
              in_dim.size(), in_dim, i, x_dim.size(), i, x_dim);
Z
zhaoyuchen 已提交
88
          // if in_dim or x_dim has -1, not check equal
89 90
          for (int j = 0; j < x_dim.size(); ++j) {
            if (x_dim[j] == -1 || in_dim[j] == -1) {
Z
zhaoyuchen 已提交
91 92
              continue;
            }
93 94 95 96 97 98
            PADDLE_ENFORCE_EQ(
                in_dim[j], x_dim[j],
                "ShapeError: The input tensor X of SumOp must have same shape "
                "if not -1."
                "But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
                in_dim, i, x_dim);
Z
zhaoyuchen 已提交
99 100
          }
        }
101
      }
Q
qijun 已提交
102
    }
Q
Qiao Longfei 已提交
103 104
    ctx->SetOutputDim("Out", in_dim);
    ctx->ShareLoD("X", /*->*/ "Out");
105
  }
106 107

 protected:
108
  framework::OpKernelType GetExpectedKernelType(
109 110
      const framework::ExecutionContext& ctx) const override {
    auto x_vars = ctx.MultiInputVar("X");
H
hong 已提交
111
    auto x_vars_name = ctx.InputNames("X");
112 113 114 115

    framework::LibraryType library{framework::LibraryType::kPlain};
    framework::DataLayout layout{framework::DataLayout::kAnyLayout};

116
    if (x_vars[0]->IsType<framework::LoDTensor>()) {
117
      int dtype = -1;
C
chengduo 已提交
118
      for (size_t idx = 0; idx < x_vars.size(); ++idx) {
119 120 121
        PADDLE_ENFORCE_NOT_NULL(x_vars[idx],
                                "Input var[%s] should not be nullptr",
                                x_vars_name[idx]);
C
chengduo 已提交
122 123
        auto tensor =
            framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_vars[idx]);
124
        if (tensor->numel() <= 0 || (!tensor->IsInitialized())) {
125 126 127
          continue;
        }
        if (dtype == -1) {
Y
Yu Yang 已提交
128
          dtype = tensor->type();
129
        } else {
Y
Yu Yang 已提交
130
          PADDLE_ENFORCE_EQ(dtype, tensor->type());
131 132 133 134 135
        }
      }
      PADDLE_ENFORCE_NE(dtype, -1,
                        "Sum operator should have at least one tensor");

136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
#ifdef PADDLE_WITH_MKLDNN
      if (library == framework::LibraryType::kPlain &&
          platform::CanMKLDNNBeUsed(ctx) &&
          static_cast<framework::proto::VarType::Type>(dtype) ==
              framework::proto::VarType::FP32 &&
          ctx.OutputVar("Out")->IsType<framework::LoDTensor>()) {
        if (std::all_of(x_vars.begin(), x_vars.end(),
                        [](const framework::Variable* v) {
                          return v->IsType<framework::LoDTensor>();
                        })) {
          return framework::OpKernelType(
              framework::proto::VarType::FP32, ctx.GetPlace(),
              framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN);
        }
      }
#endif

153
      return framework::OpKernelType(
154 155
          static_cast<framework::proto::VarType::Type>(dtype), ctx.GetPlace(),
          layout, library);
156
    } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
157 158 159
      for (auto& var : x_vars) {
        auto& value = var->Get<framework::SelectedRows>().value();
        if (value.IsInitialized()) {
Y
Yu Yang 已提交
160 161
          return framework::OpKernelType(value.type(), ctx.device_context(),
                                         layout, library);
162 163 164 165
        }
      }
      // if input sparse vars are not initialized, use an default kernel type.
      return framework::OpKernelType(framework::proto::VarType::FP32,
166
                                     ctx.device_context(), layout, library);
167
    } else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
Y
Yang Yang(Tony) 已提交
168 169 170
      for (auto& x_var : x_vars) {
        auto& array = x_var->Get<framework::LoDTensorArray>();
        for (auto& each : array) {
171
          if (each.numel() != 0 && each.IsInitialized()) {
Y
Yu Yang 已提交
172 173
            return framework::OpKernelType(each.type(), ctx.device_context(),
                                           layout, library);
Y
Yang Yang(Tony) 已提交
174
          }
175 176
        }
      }
Y
Yang Yang(Tony) 已提交
177
      PADDLE_THROW("Cannot find the input data type by all input data");
178 179
    }
    PADDLE_THROW("Unexpected branch. Input type is %s",
S
sneaxiy 已提交
180
                 framework::ToTypeName(x_vars[0]->Type()));
181
  }
182 183 184 185
};

class SumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the sum op's proto: the duplicable input X, the output Out,
  // the `use_mkldnn` attribute (default false) and the op documentation.
  void Make() override {
    // Each adjacent literal ends with a space so the concatenated text
    // reads correctly.
    AddInput("X",
             "A Variable list. The shape and data type of the list elements "
             "should be consistent. Variable can be multi-dimensional Tensor "
             "or LoDTensor, and data types can be: float32, float64, int32, "
             "int64.")
        .AsDuplicable();
    AddOutput("Out",
              "the sum of input :code:`x`. its shape and data types are "
              "consistent with :code:`x`.");
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddComment(R"DOC(This OP is used to sum one or more Tensor or LoDTensor
                    of the input. If the input is LoDTensor, the output only
                    shares LoD information with the first input.)DOC");
  }
};

Q
QI JUN 已提交
205 206
class SumOpVarTypeInference : public framework::VarTypeInference {
 public:
  // Sets Out's variable type and data type from the inputs X.
  //
  // Out becomes LOD_TENSOR_ARRAY if any input is a LOD_TENSOR_ARRAY (in
  // which case all inputs must be), otherwise LOD_TENSOR if any input is a
  // LOD_TENSOR, otherwise SELECTED_ROWS. Out's data type is copied from X[0].
  void operator()(framework::InferVarTypeContext* ctx) const override {
    auto& inputs = ctx->Input("X");
    // X[0]'s data type is read at the end; fail clearly instead of calling
    // front() on an empty list.
    PADDLE_ENFORCE_EQ(inputs.empty(), false,
                      "Input(X) of SumOp should not be empty.");
    auto var_type = framework::proto::VarType::SELECTED_ROWS;
    for (auto& name : inputs) {
      VLOG(10) << name << " " << ctx->GetType(name);
    }

    bool any_input_is_lod_tensor = std::any_of(
        inputs.begin(), inputs.end(), [ctx](const std::string& name) {
          return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR;
        });

    auto is_tensor_array = [ctx](const std::string& name) {
      return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
    };

    bool any_input_is_tensor_array =
        std::any_of(inputs.begin(), inputs.end(), is_tensor_array);
    bool all_inputs_are_tensor_array =
        std::all_of(inputs.begin(), inputs.end(), is_tensor_array);

    if (any_input_is_tensor_array) {
      if (!all_inputs_are_tensor_array) {
        // List every input's type so the error shows which ones differ.
        std::ostringstream os;
        for (auto& each : inputs) {
          os << "    " << each << " type is " << ctx->GetType(each) << "\n";
        }
        PADDLE_ENFORCE_EQ(all_inputs_are_tensor_array, true,
                          "Not all inputs are tensor array:\n%s", os.str());
      }
      var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
    } else if (any_input_is_lod_tensor) {
      var_type = framework::proto::VarType::LOD_TENSOR;
    }

    auto out_var_name = ctx->Output("Out").front();
    ctx->SetType(out_var_name, var_type);
    ctx->SetDataType(out_var_name, ctx->GetDataType(inputs.front()));
  }
};

H
hong 已提交
248
class SumGradDescMaker : public framework::GradOpDescMakerBase {
249
 public:
250
  using framework::GradOpDescMakerBase::GradOpDescMakerBase;
251

Y
Yu Yang 已提交
252
  std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
253
    auto x_grads = InputGrad("X", false);
Y
Yu Yang 已提交
254
    std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
255 256 257 258
    grad_ops.reserve(x_grads.size());
    auto og = OutputGrad("Out");
    std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
                   [&og](const std::string& x_grad) {
Y
Yu Yang 已提交
259
                     auto* grad_op = new framework::OpDesc();
Y
Yu Yang 已提交
260 261 262 263
                     grad_op->SetType("scale");
                     grad_op->SetInput("X", og);
                     grad_op->SetOutput("Out", {x_grad});
                     grad_op->SetAttr("scale", 1.0f);
Y
Yu Yang 已提交
264
                     return std::unique_ptr<framework::OpDesc>(grad_op);
265
                   });
H
hong 已提交
266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289

    return grad_ops;
  }
};

class SumGradOpBaseMaker : public imperative::GradOpBaseMakerBase {
 public:
  using imperative::GradOpBaseMakerBase::GradOpBaseMakerBase;

  // Imperative-mode backward ops of sum: for every entry returned by
  // InputGrad("X", false), a `scale` op (scale = 1.0) reading
  // OutputGrad("Out") as X and writing that entry as Out, since the gradient
  // of a sum with respect to each addend is the identity.
  std::vector<std::unique_ptr<imperative::OpBase>> operator()() const override {
    auto x_grads = InputGrad("X", false);
    std::vector<std::unique_ptr<imperative::OpBase>> grad_ops;
    grad_ops.reserve(x_grads.size());
    auto og = OutputGrad("Out");
    for (const auto& x_grad : x_grads) {
      // Own the op immediately so it is released if any setter throws.
      std::unique_ptr<imperative::OpBase> grad_op(new imperative::OpBase());
      grad_op->SetType("scale");
      grad_op->SetInput("X", og);
      grad_op->SetOutput("Out", {x_grad});
      grad_op->SetAttr("scale", 1.0f);
      grad_ops.emplace_back(std::move(grad_op));
    }
    return grad_ops;
  }
};

294
// Declares SumInplace (passed to REGISTER_OPERATOR for `sum`), pairing input
// "X" with output "Out" for inplace inference. NOTE(review): presumably this
// lets Out reuse X's buffer — confirm against the macro's definition.
DECLARE_INPLACE_OP_INFERER(SumInplace, {"X", "Out"});
295

296 297 298 299
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
300

H
hong 已提交
301 302 303
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradDescMaker,
                  ops::SumGradOpBaseMaker, ops::SumOpVarTypeInference,
                  ops::SumInplace);
304

Q
QI JUN 已提交
305 306 307 308 309
REGISTER_OP_CPU_KERNEL(
    sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, int64_t>);