/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sum_op.h"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/detail/safe_ref.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {
using framework::Tensor;

class SumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
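    // Infer the shape of Out: all non-empty inputs must share one common
    // shape, which becomes the output shape. Empty tensors and uninitialized
    // SelectedRows inputs are skipped below.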
    PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null");

    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of SumOp should not be null.");
    if (ctx->IsRuntime() &&
        ctx->GetOutputsVarType("Out")[0] ==
            framework::proto::VarType::LOD_TENSOR_ARRAY) {
      return;  // skip runtime InferShape when the output is a LoDTensorArray
    }

    auto x_var_types = ctx->GetInputsVarType("X");
    auto x_dims = ctx->GetInputsDim("X");

    size_t N = x_dims.size();
    PADDLE_ENFORCE_GT(N, 0,
                      "The number of input tensors should be greater than 0.");
    if (N == 1) {
      VLOG(3) << "Warning: SumOp has only one input, which may waste memory";
    }

    framework::DDim in_dim({0});
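    // Accumulate the common shape of the inputs into in_dim, skipping empty
    // inputs.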
    for (size_t i = 0; i < x_dims.size(); ++i) {
      auto& x_dim = x_dims[i];
      // x_dim.size() == 1 means the real dim of the SelectedRows input is [0]
      if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS &&
          x_dim.size() == 1) {
        continue;
      }
      if (framework::product(x_dim) == 0) {
        continue;
      }
      if (framework::product(in_dim) == 0) {
        in_dim = x_dim;
      } else {
        if (ctx->IsRuntime()) {
          PADDLE_ENFORCE_EQ(in_dim, x_dim,
                            "Input tensors must have the same shape");
        } else {
          PADDLE_ENFORCE_EQ(in_dim.size(), x_dim.size(),
                            "Input tensors must have the same rank");
          // dimensions that are -1 are unknown at compile time, so skip
          // the equality check for them
          for (int i = 0; i < x_dim.size(); ++i) {
            if (x_dim[i] == -1 || in_dim[i] == -1) {
              continue;
            }
            PADDLE_ENFORCE_EQ(in_dim[i], x_dim[i],
                              "Input tensors must have the same shape, "
                              "except for -1 dimensions");
          }
        }
      }
    }
    ctx->SetOutputDim("Out", in_dim);
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
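  // Decide the kernel's data type from the inputs: every non-empty input
  // must carry the same type, which then forms the kernel key together with
  // the place, layout and library.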
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto x_vars = ctx.MultiInputVar("X");
    auto x_vars_name = ctx.Inputs("X");

    framework::LibraryType library{framework::LibraryType::kPlain};
    framework::DataLayout layout{framework::DataLayout::kAnyLayout};

#ifdef PADDLE_WITH_MKLDNN
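    // Prefer the MKL-DNN kernel when the build and the current execution
    // context allow it; otherwise keep the plain kernel with any layout.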
    if (library == framework::LibraryType::kPlain &&
        platform::CanMKLDNNBeUsed(ctx)) {
      library = framework::LibraryType::kMKLDNN;
      layout = framework::DataLayout::kMKLDNN;
    }
#endif

    if (x_vars[0]->IsType<framework::LoDTensor>()) {
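      // LoDTensor inputs: take the dtype of the first non-empty tensor and
      // require all other non-empty inputs to match it.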
      int dtype = -1;
      for (size_t idx = 0; idx < x_vars.size(); ++idx) {
        PADDLE_ENFORCE(x_vars[idx] != nullptr,
                       "Input var[%s] should not be nullptr", x_vars_name[idx]);
        auto tensor =
            framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_vars[idx]);
        if (tensor->numel() == 0) {
          continue;
        }
        if (dtype == -1) {
          dtype = tensor->type();
        } else {
          PADDLE_ENFORCE_EQ(dtype, tensor->type());
        }
      }
      PADDLE_ENFORCE_NE(dtype, -1,
                        "Sum operator should have at least one tensor");

      return framework::OpKernelType(
          static_cast<framework::proto::VarType::Type>(dtype), ctx.GetPlace(),
          layout, library);
    } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
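      // SelectedRows inputs: use the dtype of the first initialized value
      // tensor.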
      for (auto& var : x_vars) {
        auto& value = var->Get<framework::SelectedRows>().value();
        if (value.IsInitialized()) {
          return framework::OpKernelType(value.type(), ctx.device_context(),
                                         layout, library);
        }
      }
      // if none of the input SelectedRows values is initialized, fall back
      // to a default kernel type (FP32)
      return framework::OpKernelType(framework::proto::VarType::FP32,
                                     ctx.device_context(), layout, library);
    } else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
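      // LoDTensorArray inputs: use the dtype of the first non-empty tensor
      // found in any of the arrays.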
      for (auto& x_var : x_vars) {
        auto& array = x_var->Get<framework::LoDTensorArray>();
        for (auto& each : array) {
          if (each.numel() != 0) {
            return framework::OpKernelType(each.type(), ctx.device_context(),
                                           layout, library);
          }
        }
      }
      PADDLE_THROW("Cannot find the input data type by all input data");
    }
    PADDLE_THROW("Unexpected branch. Input type is %s",
                 framework::ToTypeName(x_vars[0]->Type()));
  }
};

class SumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
        .AsDuplicable();
    AddOutput("Out", "(Tensor) The output tensor of sum operator.");
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddComment(R"DOC(
Sum operator.

This operator sums the input tensors. All the inputs can carry the
LoD (Level of Details) information. However, the output only shares
the LoD information with the first input.
)DOC");
  }
};

class SumOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    auto& inputs = ctx->Input("X");
    auto var_type = framework::proto::VarType::SELECTED_ROWS;
    for (auto& name : ctx->Input("X")) {
      VLOG(10) << name << " " << ctx->GetType(name);
    }

    bool any_input_is_lod_tensor = std::any_of(
        inputs.begin(), inputs.end(), [ctx](const std::string& name) {
          return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR;
        });

    auto is_tensor_array = [ctx](const std::string& name) {
      return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
    };

    bool any_input_is_tensor_array =
        std::any_of(inputs.begin(), inputs.end(), is_tensor_array);
    bool all_inputs_are_tensor_array =
        std::all_of(inputs.begin(), inputs.end(), is_tensor_array);

    if (any_input_is_tensor_array) {
      if (!all_inputs_are_tensor_array) {
        std::ostringstream os;
        for (auto& each : inputs) {
          os << "    " << each << " type is " << ctx->GetType(each) << "\n";
        }
        PADDLE_ENFORCE(all_inputs_are_tensor_array,
                       "Not all inputs are tensor array:\n%s", os.str());
      }
      var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
    } else if (any_input_is_lod_tensor) {
      var_type = framework::proto::VarType::LOD_TENSOR;
    }

    auto out_var_name = ctx->Output("Out").front();
    ctx->SetType(out_var_name, var_type);
    ctx->SetDataType(out_var_name, ctx->GetDataType(inputs.front()));
  }
};

class SumGradMaker : public framework::GradOpDescMakerBase {
 public:
  using framework::GradOpDescMakerBase::GradOpDescMakerBase;

  std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
    auto x_grads = InputGrad("X", false);
    std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
    grad_ops.reserve(x_grads.size());
    auto og = OutputGrad("Out");
    std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
                   [&og](const std::string& x_grad) {
                     auto* grad_op = new framework::OpDesc();
                     grad_op->SetType("scale");
                     grad_op->SetInput("X", og);
                     grad_op->SetOutput("Out", {x_grad});
                     grad_op->SetAttr("scale", 1.0f);
                     return std::unique_ptr<framework::OpDesc>(grad_op);
                   });
    return grad_ops;
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

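// Register the sum operator with its custom gradient maker and var-type
// inference, plus CPU kernels for float, double, int and int64_t.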
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker,
                  ops::SumOpVarTypeInference);

REGISTER_OP_CPU_KERNEL(
    sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, int64_t>);