/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sum_op.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/detail/safe_ref.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {
using framework::Tensor;

class SumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

35
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
qiaolongfei 已提交
36
    PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null");
37

Q
Qiao Longfei 已提交
38 39
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of SumOp should not be null.");
40 41
    if (ctx->IsRuntime() &&
        ctx->GetOutputsVarType("Out")[0] ==
42
            framework::proto::VarType::LOD_TENSOR_ARRAY) {
43 44
      return;  // skip runtime infershape when is tensor array;
    }
45

46
    auto x_var_types = ctx->GetInputsVarType("X");
47
    auto x_dims = ctx->GetInputsDim("X");
48

Q
Qiao Longfei 已提交
49
    size_t N = x_dims.size();
50 51
    PADDLE_ENFORCE_GT(N, 0, "Input tensors count should > 0.");
    if (N == 1) {
M
minqiyang 已提交
52
      VLOG(3) << "Warning: sum have only one input, may waste memory";
53
    }
Q
qiaolongfei 已提交
54

55
    framework::DDim in_dim({0});
56
    for (size_t i = 0; i < x_dims.size(); ++i) {
57 58 59 60
      auto& x_dim = x_dims[i];
      // x_dim.size() == 1 means the real dim of selected rows is [0]
      if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS &&
          x_dim.size() == 1) {
61 62
        continue;
      }
63 64 65 66 67 68
      if (framework::product(x_dim) == 0) {
        continue;
      }
      if (framework::product(in_dim) == 0) {
        in_dim = x_dim;
      } else {
Z
zhaoyuchen 已提交
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
        if (ctx->IsRuntime()) {
          PADDLE_ENFORCE_EQ(in_dim, x_dim,
                            "Input tensors must have same shape");
        } else {
          PADDLE_ENFORCE_EQ(in_dim.size(), x_dim.size(),
                            "Input tensors must have same shape size");
          // if in_dim or x_dim has -1, not check equal
          for (int i = 0; i < x_dim.size(); ++i) {
            if (x_dim[i] == -1 || in_dim[i] == -1) {
              continue;
            }
            PADDLE_ENFORCE_EQ(in_dim[i], x_dim[i],
                              "Input tensors must have same shape if not -1");
          }
        }
84
      }
Q
qijun 已提交
85
    }
Q
Qiao Longfei 已提交
86 87
    ctx->SetOutputDim("Out", in_dim);
    ctx->ShareLoD("X", /*->*/ "Out");
88
  }
89 90

 protected:
91
  framework::OpKernelType GetExpectedKernelType(
92 93
      const framework::ExecutionContext& ctx) const override {
    auto x_vars = ctx.MultiInputVar("X");
C
chengduo 已提交
94
    auto x_vars_name = ctx.Inputs("X");
95 96 97 98 99 100 101 102 103 104 105 106

    framework::LibraryType library{framework::LibraryType::kPlain};
    framework::DataLayout layout{framework::DataLayout::kAnyLayout};

#ifdef PADDLE_WITH_MKLDNN
    if (library == framework::LibraryType::kPlain &&
        platform::CanMKLDNNBeUsed(ctx)) {
      library = framework::LibraryType::kMKLDNN;
      layout = framework::DataLayout::kMKLDNN;
    }
#endif

107
    if (x_vars[0]->IsType<framework::LoDTensor>()) {
108
      int dtype = -1;
C
chengduo 已提交
109 110 111
      for (size_t idx = 0; idx < x_vars.size(); ++idx) {
        PADDLE_ENFORCE(x_vars[idx] != nullptr,
                       "Input var[%s] should not be nullptr", x_vars_name[idx]);
C
chengduo 已提交
112 113
        auto tensor =
            framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_vars[idx]);
114
        if (tensor->numel() <= 0 || (!tensor->IsInitialized())) {
115 116 117
          continue;
        }
        if (dtype == -1) {
Y
Yu Yang 已提交
118
          dtype = tensor->type();
119
        } else {
Y
Yu Yang 已提交
120
          PADDLE_ENFORCE_EQ(dtype, tensor->type());
121 122 123 124 125
        }
      }
      PADDLE_ENFORCE_NE(dtype, -1,
                        "Sum operator should have at least one tensor");

126
      return framework::OpKernelType(
127 128
          static_cast<framework::proto::VarType::Type>(dtype), ctx.GetPlace(),
          layout, library);
129
    } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
130 131 132
      for (auto& var : x_vars) {
        auto& value = var->Get<framework::SelectedRows>().value();
        if (value.IsInitialized()) {
Y
Yu Yang 已提交
133 134
          return framework::OpKernelType(value.type(), ctx.device_context(),
                                         layout, library);
135 136 137 138
        }
      }
      // if input sparse vars are not initialized, use an default kernel type.
      return framework::OpKernelType(framework::proto::VarType::FP32,
139
                                     ctx.device_context(), layout, library);
140
    } else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
Y
Yang Yang(Tony) 已提交
141 142 143 144
      for (auto& x_var : x_vars) {
        auto& array = x_var->Get<framework::LoDTensorArray>();
        for (auto& each : array) {
          if (each.numel() != 0) {
Y
Yu Yang 已提交
145 146
            return framework::OpKernelType(each.type(), ctx.device_context(),
                                           layout, library);
Y
Yang Yang(Tony) 已提交
147
          }
148 149
        }
      }
Y
Yang Yang(Tony) 已提交
150
      PADDLE_THROW("Cannot find the input data type by all input data");
151 152
    }
    PADDLE_THROW("Unexpected branch. Input type is %s",
S
sneaxiy 已提交
153
                 framework::ToTypeName(x_vars[0]->Type()));
154
  }
155 156 157 158
};

// Declares the sum operator's interface: a duplicable input X, a single
// output Out, the `use_mkldnn` attribute (default false), and the
// user-facing operator documentation.
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
        .AsDuplicable();
    AddOutput("Out", "(Tensor) The output tensor of sum operator.");
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddComment(R"DOC(
Sum operator.

This operator sums the input tensors. All the inputs can carry the
LoD (Level of Details) information. However, the output only shares
the LoD information with the first input.
)DOC");
  }
};

// Infers the variable type and data type of Out from the X inputs:
//  - LOD_TENSOR_ARRAY if any input is a LoDTensorArray. In that case every
//    input must be one, otherwise an error listing each input's type is
//    raised.
//  - otherwise LOD_TENSOR if any input is a LoDTensor;
//  - otherwise SELECTED_ROWS.
// Out's data type is copied from the first input.
class SumOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    auto& inputs = ctx->Input("X");
    // inputs.front() is read below; an empty X would be undefined behavior.
    PADDLE_ENFORCE(!inputs.empty(), "Input(X) of SumOp should not be empty.");

    for (auto& name : inputs) {
      VLOG(10) << name << " " << ctx->GetType(name);
    }

    bool any_input_is_lod_tensor = std::any_of(
        inputs.begin(), inputs.end(), [ctx](const std::string& name) {
          return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR;
        });

    auto is_tensor_array = [ctx](const std::string& name) {
      return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
    };

    bool any_input_is_tensor_array =
        std::any_of(inputs.begin(), inputs.end(), is_tensor_array);
    bool all_inputs_are_tensor_array =
        std::all_of(inputs.begin(), inputs.end(), is_tensor_array);

    auto var_type = framework::proto::VarType::SELECTED_ROWS;
    if (any_input_is_tensor_array) {
      if (!all_inputs_are_tensor_array) {
        // Build a per-input type listing for the error message.
        std::ostringstream os;
        for (auto& each : inputs) {
          os << "    " << each << " type is " << ctx->GetType(each) << "\n";
        }
        PADDLE_ENFORCE(all_inputs_are_tensor_array,
                       "Not all inputs are tensor array:\n%s", os.str());
      }
      var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
    } else if (any_input_is_lod_tensor) {
      var_type = framework::proto::VarType::LOD_TENSOR;
    }

    auto out_var_name = ctx->Output("Out").front();
    ctx->SetType(out_var_name, var_type);
    ctx->SetDataType(out_var_name, ctx->GetDataType(inputs.front()));
  }
};

class SumGradMaker : public framework::GradOpDescMakerBase {
220
 public:
221
  using framework::GradOpDescMakerBase::GradOpDescMakerBase;
222

Y
Yu Yang 已提交
223
  std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
224
    auto x_grads = InputGrad("X", false);
Y
Yu Yang 已提交
225
    std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
226 227 228 229
    grad_ops.reserve(x_grads.size());
    auto og = OutputGrad("Out");
    std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
                   [&og](const std::string& x_grad) {
Y
Yu Yang 已提交
230
                     auto* grad_op = new framework::OpDesc();
Y
Yu Yang 已提交
231 232 233 234
                     grad_op->SetType("scale");
                     grad_op->SetInput("X", og);
                     grad_op->SetOutput("Out", {x_grad});
                     grad_op->SetAttr("scale", 1.0f);
Y
Yu Yang 已提交
235
                     return std::unique_ptr<framework::OpDesc>(grad_op);
236 237
                   });
    return grad_ops;
238 239 240
  }
};

241 242 243 244 245 246 247 248
// In-place inference for sum: always reports the single pairing
// "X" -> "Out", regardless of the op description or the target device.
class SumInplace : public framework::InplaceOpInference {
 public:
  std::unordered_map<std::string, std::string> operator()(
      const framework::OpDesc& op_desc, bool use_cuda) const override {
    std::unordered_map<std::string, std::string> in_to_out;
    in_to_out.emplace("X", "Out");
    return in_to_out;
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
253

Q
QI JUN 已提交
254
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker,
255
                  ops::SumOpVarTypeInference, ops::SumInplace);
256

Q
QI JUN 已提交
257 258 259 260 261
REGISTER_OP_CPU_KERNEL(
    sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, int64_t>);