sum_op.cc 9.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
12
#include "paddle/fluid/operators/sum_op.h"
13

14
#include <algorithm>
M
minqiyang 已提交
15
#include <memory>
16
#include <string>
17
#include <unordered_map>
18
#include <vector>
19

Y
Yi Wang 已提交
20 21
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
22

23 24 25 26
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

27 28 29 30 31 32 33 34
namespace paddle {
namespace operators {
using framework::Tensor;

class SumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

35
  void InferShape(framework::InferShapeContext* ctx) const override {
36 37
    PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
                      "Inputs(X) should not be null");
38

39 40
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      "Output(Out) of SumOp should not be null.");
41 42
    if (ctx->IsRuntime() &&
        ctx->GetOutputsVarType("Out")[0] ==
43
            framework::proto::VarType::LOD_TENSOR_ARRAY) {
44 45
      return;  // skip runtime infershape when is tensor array;
    }
46

47
    auto x_var_types = ctx->GetInputsVarType("X");
48
    auto x_dims = ctx->GetInputsDim("X");
49

50 51 52 53 54 55 56
    auto N = x_dims.size();
    PADDLE_ENFORCE_GT(
        N, 0,
        "ShapeError: The input tensor X's dimensions of SumOp "
        "should be larger than 0. But received X's dimensions %d, "
        "X's shape = [%s].",
        N, &x_dims);
57
    if (N == 1) {
58
      VLOG(3) << "Warning: SumOp have only one input, may waste memory";
59
    }
Q
qiaolongfei 已提交
60

61
    framework::DDim in_dim({0});
62
    for (size_t i = 0; i < x_dims.size(); ++i) {
63 64 65 66
      auto& x_dim = x_dims[i];
      // x_dim.size() == 1 means the real dim of selected rows is [0]
      if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS &&
          x_dim.size() == 1) {
67 68
        continue;
      }
69 70 71 72 73 74
      if (framework::product(x_dim) == 0) {
        continue;
      }
      if (framework::product(in_dim) == 0) {
        in_dim = x_dim;
      } else {
Z
zhaoyuchen 已提交
75
        if (ctx->IsRuntime()) {
76 77 78 79 80
          PADDLE_ENFORCE_EQ(
              in_dim, x_dim,
              "ShapeError: The input tensor X of SumOp must have same shape."
              "But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
              in_dim, i, x_dim);
Z
zhaoyuchen 已提交
81
        } else {
82 83 84 85 86 87
          PADDLE_ENFORCE_EQ(
              in_dim.size(), x_dim.size(),
              "ShapeError: The input tensor X of SumOp must have same "
              "dimensions. But received X[0]'s dimensions = %d, X[0]'s shape = "
              "[%s], X[%d]'s dimensions = %d, X[%d]'s shape = [%s].",
              in_dim.size(), in_dim, i, x_dim.size(), i, x_dim);
Z
zhaoyuchen 已提交
88
          // if in_dim or x_dim has -1, not check equal
89 90
          for (int j = 0; j < x_dim.size(); ++j) {
            if (x_dim[j] == -1 || in_dim[j] == -1) {
Z
zhaoyuchen 已提交
91 92
              continue;
            }
93 94 95 96 97 98
            PADDLE_ENFORCE_EQ(
                in_dim[j], x_dim[j],
                "ShapeError: The input tensor X of SumOp must have same shape "
                "if not -1."
                "But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
                in_dim, i, x_dim);
Z
zhaoyuchen 已提交
99 100
          }
        }
101
      }
Q
qijun 已提交
102
    }
Q
Qiao Longfei 已提交
103 104
    ctx->SetOutputDim("Out", in_dim);
    ctx->ShareLoD("X", /*->*/ "Out");
105
  }
106 107

 protected:
108
  framework::OpKernelType GetExpectedKernelType(
109 110
      const framework::ExecutionContext& ctx) const override {
    auto x_vars = ctx.MultiInputVar("X");
C
chengduo 已提交
111
    auto x_vars_name = ctx.Inputs("X");
112 113 114 115 116 117 118 119 120 121 122 123

    framework::LibraryType library{framework::LibraryType::kPlain};
    framework::DataLayout layout{framework::DataLayout::kAnyLayout};

#ifdef PADDLE_WITH_MKLDNN
    if (library == framework::LibraryType::kPlain &&
        platform::CanMKLDNNBeUsed(ctx)) {
      library = framework::LibraryType::kMKLDNN;
      layout = framework::DataLayout::kMKLDNN;
    }
#endif

124
    if (x_vars[0]->IsType<framework::LoDTensor>()) {
125
      int dtype = -1;
C
chengduo 已提交
126
      for (size_t idx = 0; idx < x_vars.size(); ++idx) {
127 128 129
        PADDLE_ENFORCE_NOT_NULL(x_vars[idx],
                                "Input var[%s] should not be nullptr",
                                x_vars_name[idx]);
C
chengduo 已提交
130 131
        auto tensor =
            framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_vars[idx]);
132
        if (tensor->numel() <= 0 || (!tensor->IsInitialized())) {
133 134 135
          continue;
        }
        if (dtype == -1) {
Y
Yu Yang 已提交
136
          dtype = tensor->type();
137
        } else {
Y
Yu Yang 已提交
138
          PADDLE_ENFORCE_EQ(dtype, tensor->type());
139 140 141 142 143
        }
      }
      PADDLE_ENFORCE_NE(dtype, -1,
                        "Sum operator should have at least one tensor");

144
      return framework::OpKernelType(
145 146
          static_cast<framework::proto::VarType::Type>(dtype), ctx.GetPlace(),
          layout, library);
147
    } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
148 149 150
      for (auto& var : x_vars) {
        auto& value = var->Get<framework::SelectedRows>().value();
        if (value.IsInitialized()) {
Y
Yu Yang 已提交
151 152
          return framework::OpKernelType(value.type(), ctx.device_context(),
                                         layout, library);
153 154 155 156
        }
      }
      // if input sparse vars are not initialized, use an default kernel type.
      return framework::OpKernelType(framework::proto::VarType::FP32,
157
                                     ctx.device_context(), layout, library);
158
    } else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
Y
Yang Yang(Tony) 已提交
159 160 161
      for (auto& x_var : x_vars) {
        auto& array = x_var->Get<framework::LoDTensorArray>();
        for (auto& each : array) {
162
          if (each.numel() != 0 && each.IsInitialized()) {
Y
Yu Yang 已提交
163 164
            return framework::OpKernelType(each.type(), ctx.device_context(),
                                           layout, library);
Y
Yang Yang(Tony) 已提交
165
          }
166 167
        }
      }
Y
Yang Yang(Tony) 已提交
168
      PADDLE_THROW("Cannot find the input data type by all input data");
169 170
    }
    PADDLE_THROW("Unexpected branch. Input type is %s",
S
sneaxiy 已提交
171
                 framework::ToTypeName(x_vars[0]->Type()));
172
  }
173 174 175 176
};

// Declares the proto of the `sum` operator: its input, its output, its
// attribute and the operator-level documentation string.
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // "X" is marked duplicable, i.e. it accepts a list of variables.
    AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
        .AsDuplicable();

    // Single output holding the result.
    AddOutput("Out", "(Tensor) The output tensor of sum operator.");

    // Off by default.
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);

    AddComment(R"DOC(
Sum operator.

This operators sums the input tensors. All the inputs can carry the
LoD (Level of Details) information. However, the output only shares
the LoD information with the first input.
)DOC");
  }
};

Q
QI JUN 已提交
194 195
// Infers the variable type and data type of "Out" from the inputs "X".
//
// Resulting variable type:
//  - LOD_TENSOR_ARRAY if any input is a tensor array (in which case all inputs
//    must be tensor arrays, otherwise an error listing every input's type is
//    raised);
//  - LOD_TENSOR if any input is a LoD tensor;
//  - SELECTED_ROWS otherwise.
// The data type of "Out" is copied from the first input.
class SumOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    auto& inputs = ctx->Input("X");
    // inputs.front() is read at the end; reject an empty list up front
    // instead of invoking undefined behaviour.
    PADDLE_ENFORCE_EQ(inputs.empty(), false,
                      "Inputs(X) of SumOp should not be empty.");

    auto var_type = framework::proto::VarType::SELECTED_ROWS;
    for (auto& name : inputs) {
      VLOG(10) << name << " " << ctx->GetType(name);
    }

    bool any_input_is_lod_tensor = std::any_of(
        inputs.begin(), inputs.end(), [ctx](const std::string& name) {
          return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR;
        });

    auto is_tensor_array = [ctx](const std::string& name) {
      return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
    };

    bool any_input_is_tensor_array =
        std::any_of(inputs.begin(), inputs.end(), is_tensor_array);
    bool all_inputs_are_tensor_array =
        std::all_of(inputs.begin(), inputs.end(), is_tensor_array);

    if (any_input_is_tensor_array) {
      if (!all_inputs_are_tensor_array) {
        // Build a per-input type listing for the error message; the enforce
        // below always fails in this branch.
        std::ostringstream os;
        for (auto& each : inputs) {
          os << "    " << each << " type is " << ctx->GetType(each) << "\n";
        }
        PADDLE_ENFORCE_EQ(all_inputs_are_tensor_array, true,
                          "Not all inputs are tensor array:\n%s", os.str());
      }
      var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
    } else if (any_input_is_lod_tensor) {
      var_type = framework::proto::VarType::LOD_TENSOR;
    }

    auto out_var_name = ctx->Output("Out").front();
    ctx->SetType(out_var_name, var_type);
    ctx->SetDataType(out_var_name, ctx->GetDataType(inputs.front()));
  }
};

237
class SumGradMaker : public framework::GradOpDescMakerBase {
238
 public:
239
  using framework::GradOpDescMakerBase::GradOpDescMakerBase;
240

Y
Yu Yang 已提交
241
  std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
242
    auto x_grads = InputGrad("X", false);
Y
Yu Yang 已提交
243
    std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
244 245 246 247
    grad_ops.reserve(x_grads.size());
    auto og = OutputGrad("Out");
    std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
                   [&og](const std::string& x_grad) {
Y
Yu Yang 已提交
248
                     auto* grad_op = new framework::OpDesc();
Y
Yu Yang 已提交
249 250 251 252
                     grad_op->SetType("scale");
                     grad_op->SetInput("X", og);
                     grad_op->SetOutput("Out", {x_grad});
                     grad_op->SetAttr("scale", 1.0f);
Y
Yu Yang 已提交
253
                     return std::unique_ptr<framework::OpDesc>(grad_op);
254 255
                   });
    return grad_ops;
256 257 258
  }
};

259
// Declares SumInplace from the {"X", "Out"} pair; it is passed to
// REGISTER_OPERATOR for `sum` further down in this file.
// NOTE(review): presumably this lets "Out" reuse the buffer of "X" -- confirm
// against the definition of DECLARE_INPLACE_OP_INFERER.
DECLARE_INPLACE_OP_INFERER(SumInplace, {"X", "Out"});
260

261 262 263 264
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
265

Q
QI JUN 已提交
266
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker,
267
                  ops::SumOpVarTypeInference, ops::SumInplace);
268

Q
QI JUN 已提交
269 270 271 272 273
REGISTER_OP_CPU_KERNEL(
    sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, int64_t>);