/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sum_op.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/detail/safe_ref.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {
using framework::Tensor;

class SumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

35
  void InferShape(framework::InferShapeContext* ctx) const override {
36 37
    PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
                      "Inputs(X) should not be null");
38

39 40
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      "Output(Out) of SumOp should not be null.");
41 42
    if (ctx->IsRuntime() &&
        ctx->GetOutputsVarType("Out")[0] ==
43
            framework::proto::VarType::LOD_TENSOR_ARRAY) {
44 45
      return;  // skip runtime infershape when is tensor array;
    }
46

47
    auto x_var_types = ctx->GetInputsVarType("X");
48
    auto x_dims = ctx->GetInputsDim("X");
49

50 51 52 53 54 55 56
    auto N = x_dims.size();
    PADDLE_ENFORCE_GT(
        N, 0,
        "ShapeError: The input tensor X's dimensions of SumOp "
        "should be larger than 0. But received X's dimensions %d, "
        "X's shape = [%s].",
        N, &x_dims);
57
    if (N == 1) {
58
      VLOG(3) << "Warning: SumOp have only one input, may waste memory";
59
    }
Q
qiaolongfei 已提交
60

61
    framework::DDim in_dim({0});
62
    for (size_t i = 0; i < x_dims.size(); ++i) {
63 64 65 66
      auto& x_dim = x_dims[i];
      // x_dim.size() == 1 means the real dim of selected rows is [0]
      if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS &&
          x_dim.size() == 1) {
67 68
        continue;
      }
69 70 71 72 73 74
      if (framework::product(x_dim) == 0) {
        continue;
      }
      if (framework::product(in_dim) == 0) {
        in_dim = x_dim;
      } else {
Z
zhaoyuchen 已提交
75
        if (ctx->IsRuntime()) {
76 77 78 79 80
          PADDLE_ENFORCE_EQ(
              in_dim, x_dim,
              "ShapeError: The input tensor X of SumOp must have same shape."
              "But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
              in_dim, i, x_dim);
Z
zhaoyuchen 已提交
81
        } else {
82 83 84 85 86 87
          PADDLE_ENFORCE_EQ(
              in_dim.size(), x_dim.size(),
              "ShapeError: The input tensor X of SumOp must have same "
              "dimensions. But received X[0]'s dimensions = %d, X[0]'s shape = "
              "[%s], X[%d]'s dimensions = %d, X[%d]'s shape = [%s].",
              in_dim.size(), in_dim, i, x_dim.size(), i, x_dim);
Z
zhaoyuchen 已提交
88
          // if in_dim or x_dim has -1, not check equal
89 90
          for (int j = 0; j < x_dim.size(); ++j) {
            if (x_dim[j] == -1 || in_dim[j] == -1) {
Z
zhaoyuchen 已提交
91 92
              continue;
            }
93 94 95 96 97 98
            PADDLE_ENFORCE_EQ(
                in_dim[j], x_dim[j],
                "ShapeError: The input tensor X of SumOp must have same shape "
                "if not -1."
                "But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
                in_dim, i, x_dim);
Z
zhaoyuchen 已提交
99 100
          }
        }
101
      }
Q
qijun 已提交
102
    }
Q
Qiao Longfei 已提交
103 104
    ctx->SetOutputDim("Out", in_dim);
    ctx->ShareLoD("X", /*->*/ "Out");
105
  }
106 107

 protected:
108
  framework::OpKernelType GetExpectedKernelType(
109 110
      const framework::ExecutionContext& ctx) const override {
    auto x_vars = ctx.MultiInputVar("X");
C
chengduo 已提交
111
    auto x_vars_name = ctx.Inputs("X");
112 113 114 115 116 117 118 119 120 121 122 123

    framework::LibraryType library{framework::LibraryType::kPlain};
    framework::DataLayout layout{framework::DataLayout::kAnyLayout};

#ifdef PADDLE_WITH_MKLDNN
    if (library == framework::LibraryType::kPlain &&
        platform::CanMKLDNNBeUsed(ctx)) {
      library = framework::LibraryType::kMKLDNN;
      layout = framework::DataLayout::kMKLDNN;
    }
#endif

124
    if (x_vars[0]->IsType<framework::LoDTensor>()) {
125
      int dtype = -1;
C
chengduo 已提交
126
      for (size_t idx = 0; idx < x_vars.size(); ++idx) {
127 128 129
        PADDLE_ENFORCE_NOT_NULL(x_vars[idx],
                                "Input var[%s] should not be nullptr",
                                x_vars_name[idx]);
C
chengduo 已提交
130 131
        auto tensor =
            framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_vars[idx]);
132
        if (tensor->numel() <= 0 || (!tensor->IsInitialized())) {
133 134 135
          continue;
        }
        if (dtype == -1) {
Y
Yu Yang 已提交
136
          dtype = tensor->type();
137
        } else {
Y
Yu Yang 已提交
138
          PADDLE_ENFORCE_EQ(dtype, tensor->type());
139 140 141 142 143
        }
      }
      PADDLE_ENFORCE_NE(dtype, -1,
                        "Sum operator should have at least one tensor");

144
      return framework::OpKernelType(
145 146
          static_cast<framework::proto::VarType::Type>(dtype), ctx.GetPlace(),
          layout, library);
147
    } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
148 149 150
      for (auto& var : x_vars) {
        auto& value = var->Get<framework::SelectedRows>().value();
        if (value.IsInitialized()) {
Y
Yu Yang 已提交
151 152
          return framework::OpKernelType(value.type(), ctx.device_context(),
                                         layout, library);
153 154 155 156
        }
      }
      // if input sparse vars are not initialized, use an default kernel type.
      return framework::OpKernelType(framework::proto::VarType::FP32,
157
                                     ctx.device_context(), layout, library);
158
    } else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
Y
Yang Yang(Tony) 已提交
159 160 161
      for (auto& x_var : x_vars) {
        auto& array = x_var->Get<framework::LoDTensorArray>();
        for (auto& each : array) {
162
          if (each.numel() != 0 && each.IsInitialized()) {
Y
Yu Yang 已提交
163 164
            return framework::OpKernelType(each.type(), ctx.device_context(),
                                           layout, library);
Y
Yang Yang(Tony) 已提交
165
          }
166 167
        }
      }
Y
Yang Yang(Tony) 已提交
168
      PADDLE_THROW("Cannot find the input data type by all input data");
169 170
    }
    PADDLE_THROW("Unexpected branch. Input type is %s",
S
sneaxiy 已提交
171
                 framework::ToTypeName(x_vars[0]->Type()));
172
  }
173 174 175 176
};

/// Declares the proto of the `sum` operator: a duplicable input `X`, the
/// output `Out`, the `use_mkldnn` attribute and the operator documentation.
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Adjacent string literals are concatenated verbatim, so each piece ends
    // with a space where a word break is intended.
    AddInput("X",
             "A Variable list. The shape and data type of the list elements "
             "should be consistent. Variable can be multi-dimensional Tensor "
             "or LoDTensor, and data types can be: float32, float64, int32, "
             "int64.")
        .AsDuplicable();
    AddOutput("Out",
              "The sum of input :code:`x`. Its shape and data types are "
              "consistent with :code:`x`.");
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    // Raw-string content is emitted as-is, so continuation lines start at
    // column 0 to keep stray indentation out of the generated docs.
    AddComment(R"DOC(This OP is used to sum one or more Tensor or LoDTensor
of the input. If the input is LoDTensor, the output only
shares LoD information with the first input.)DOC");
  }
};

/// Infers the variable type of `Out` from the types of the `X` inputs.
///
/// Out becomes LOD_TENSOR_ARRAY if any input is a tensor array (in which case
/// every input must be one), otherwise LOD_TENSOR if any input is a
/// LoDTensor, otherwise SELECTED_ROWS. Out's data type is copied from the
/// first input.
class SumOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    auto& inputs = ctx->Input("X");
    // Out's data type is read from inputs.front(), which needs an input.
    PADDLE_ENFORCE_EQ(inputs.empty(), false,
                      "Input(X) of SumOp should not be empty.");
    auto var_type = framework::proto::VarType::SELECTED_ROWS;
    for (auto& name : inputs) {
      VLOG(10) << name << " " << ctx->GetType(name);
    }

    bool any_input_is_lod_tensor = std::any_of(
        inputs.begin(), inputs.end(), [ctx](const std::string& name) {
          return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR;
        });

    auto is_tensor_array = [ctx](const std::string& name) {
      return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
    };

    bool any_input_is_tensor_array =
        std::any_of(inputs.begin(), inputs.end(), is_tensor_array);
    bool all_inputs_are_tensor_array =
        std::all_of(inputs.begin(), inputs.end(), is_tensor_array);

    if (any_input_is_tensor_array) {
      // Mixing tensor arrays with other variable kinds is rejected; list
      // every input's type to make the offending one easy to spot.
      if (!all_inputs_are_tensor_array) {
        std::ostringstream os;
        for (auto& each : inputs) {
          os << "    " << each << " type is " << ctx->GetType(each) << "\n";
        }
        PADDLE_ENFORCE_EQ(all_inputs_are_tensor_array, true,
                          "Not all inputs are tensor array:\n%s", os.str());
      }
      var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
    } else if (any_input_is_lod_tensor) {
      var_type = framework::proto::VarType::LOD_TENSOR;
    }

    auto out_var_name = ctx->Output("Out").front();
    ctx->SetType(out_var_name, var_type);
    ctx->SetDataType(out_var_name, ctx->GetDataType(inputs.front()));
  }
};

class SumGradDescMaker : public framework::GradOpDescMakerBase {
240
 public:
241
  using framework::GradOpDescMakerBase::GradOpDescMakerBase;
242

Y
Yu Yang 已提交
243
  std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
244
    auto x_grads = InputGrad("X", false);
Y
Yu Yang 已提交
245
    std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
246 247 248 249
    grad_ops.reserve(x_grads.size());
    auto og = OutputGrad("Out");
    std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
                   [&og](const std::string& x_grad) {
Y
Yu Yang 已提交
250
                     auto* grad_op = new framework::OpDesc();
Y
Yu Yang 已提交
251 252 253 254
                     grad_op->SetType("scale");
                     grad_op->SetInput("X", og);
                     grad_op->SetOutput("Out", {x_grad});
                     grad_op->SetAttr("scale", 1.0f);
Y
Yu Yang 已提交
255
                     return std::unique_ptr<framework::OpDesc>(grad_op);
256
                   });
H
hong 已提交
257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280

    return grad_ops;
  }
};

/// Imperative (dygraph) counterpart of SumGradDescMaker: for every gradient
/// var returned by InputGrad("X", false), a `scale` op with scale 1.0 copies
/// the Out gradient into it.
class SumGradOpBaseMaker : public imperative::GradOpBaseMakerBase {
 public:
  using imperative::GradOpBaseMakerBase::GradOpBaseMakerBase;

  std::vector<std::unique_ptr<imperative::OpBase>> operator()() const override {
    auto x_grads = InputGrad("X", false);
    std::vector<std::unique_ptr<imperative::OpBase>> grad_ops;
    grad_ops.reserve(x_grads.size());
    auto og = OutputGrad("Out");
    std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
                   [&og](const std::shared_ptr<imperative::VarBase>& x_grad) {
                     // Owned from the moment of allocation so the op is
                     // released if any setter below throws.
                     std::unique_ptr<imperative::OpBase> grad_op(
                         new imperative::OpBase());
                     grad_op->SetType("scale");
                     grad_op->SetInput("X", og);
                     grad_op->SetOutput("Out", {x_grad});
                     grad_op->SetAttr("scale", 1.0f);
                     return grad_op;
                   });

    return grad_ops;
  }
};

// Declares SumInplace, pairing input "X" with output "Out"; it is passed to
// REGISTER_OPERATOR for `sum`. NOTE(review): presumably lets Out reuse X's
// buffer when the framework deems it safe — semantics live in the macro.
DECLARE_INPLACE_OP_INFERER(SumInplace, {"X", "Out"});
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
291

H
hong 已提交
292 293 294
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradDescMaker,
                  ops::SumGradOpBaseMaker, ops::SumOpVarTypeInference,
                  ops::SumInplace);
295

Q
QI JUN 已提交
296 297 298 299 300
REGISTER_OP_CPU_KERNEL(
    sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, int64_t>);