sum_op.cc 8.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
12
#include "paddle/fluid/operators/sum_op.h"
13

14 15
#include <algorithm>
#include <string>
16
#include <vector>
17

Y
Yi Wang 已提交
18 19
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
20

21 22 23 24
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

25 26 27 28 29 30 31 32
namespace paddle {
namespace operators {
using framework::Tensor;

class SumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

33
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
qiaolongfei 已提交
34
    PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null");
35

Q
Qiao Longfei 已提交
36 37
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of SumOp should not be null.");
38 39
    if (ctx->IsRuntime() &&
        ctx->GetOutputsVarType("Out")[0] ==
40
            framework::proto::VarType::LOD_TENSOR_ARRAY) {
41 42
      return;  // skip runtime infershape when is tensor array;
    }
43

44
    auto x_dims = ctx->GetInputsDim("X");
Q
Qiao Longfei 已提交
45
    size_t N = x_dims.size();
46 47 48 49
    PADDLE_ENFORCE_GT(N, 0, "Input tensors count should > 0.");
    if (N == 1) {
      VLOG(3) << "Warning: sum have only one input, may waste memory";
    }
Q
qiaolongfei 已提交
50

51 52 53 54 55 56 57 58 59 60
    framework::DDim in_dim({0});
    for (auto& x_dim : x_dims) {
      if (framework::product(x_dim) == 0) {
        continue;
      }
      if (framework::product(in_dim) == 0) {
        in_dim = x_dim;
      } else {
        PADDLE_ENFORCE_EQ(in_dim, x_dim, "Input tensors must have same shape");
      }
Q
qijun 已提交
61
    }
Q
Qiao Longfei 已提交
62 63
    ctx->SetOutputDim("Out", in_dim);
    ctx->ShareLoD("X", /*->*/ "Out");
64
  }
65 66

 protected:
67
  framework::OpKernelType GetExpectedKernelType(
68 69
      const framework::ExecutionContext& ctx) const override {
    auto x_vars = ctx.MultiInputVar("X");
70 71 72 73 74 75 76 77 78 79 80 81

    framework::LibraryType library{framework::LibraryType::kPlain};
    framework::DataLayout layout{framework::DataLayout::kAnyLayout};

#ifdef PADDLE_WITH_MKLDNN
    if (library == framework::LibraryType::kPlain &&
        platform::CanMKLDNNBeUsed(ctx)) {
      library = framework::LibraryType::kMKLDNN;
      layout = framework::DataLayout::kMKLDNN;
    }
#endif

82
    if (x_vars[0]->IsType<framework::LoDTensor>()) {
83
      int dtype = -1;
C
chengduozh 已提交
84
      for (auto& x_var : x_vars) {
C
chengduozh 已提交
85
        // FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
C
chengduozh 已提交
86 87
        auto tensor = framework::GetTensorFromVar(
            const_cast<framework::Variable*>(x_var));
C
chengduozh 已提交
88
        if (tensor->numel() == 0) {
89 90 91
          continue;
        }
        if (dtype == -1) {
C
chengduozh 已提交
92
          dtype = framework::ToDataType(tensor->type());
93
        } else {
C
chengduozh 已提交
94
          PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(tensor->type()));
95 96 97 98 99
        }
      }
      PADDLE_ENFORCE_NE(dtype, -1,
                        "Sum operator should have at least one tensor");

100
      return framework::OpKernelType(
101 102
          static_cast<framework::proto::VarType::Type>(dtype), ctx.GetPlace(),
          layout, library);
103
    } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
104 105 106 107
      for (auto& var : x_vars) {
        auto& value = var->Get<framework::SelectedRows>().value();
        if (value.IsInitialized()) {
          return framework::OpKernelType(framework::ToDataType(value.type()),
108
                                         ctx.device_context(), layout, library);
109 110 111 112
        }
      }
      // if input sparse vars are not initialized, use an default kernel type.
      return framework::OpKernelType(framework::proto::VarType::FP32,
113
                                     ctx.device_context(), layout, library);
114
    } else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
Y
Yang Yang(Tony) 已提交
115 116 117 118 119
      for (auto& x_var : x_vars) {
        auto& array = x_var->Get<framework::LoDTensorArray>();
        for (auto& each : array) {
          if (each.numel() != 0) {
            return framework::OpKernelType(framework::ToDataType(each.type()),
120 121
                                           ctx.device_context(), layout,
                                           library);
Y
Yang Yang(Tony) 已提交
122
          }
123 124
        }
      }
Y
Yang Yang(Tony) 已提交
125
      PADDLE_THROW("Cannot find the input data type by all input data");
126 127 128 129
    }
    PADDLE_THROW("Unexpected branch. Input type is %s",
                 x_vars[0]->Type().name());
  }
130 131 132 133
};

// Declares the proto of the `sum` operator: a duplicable input "X", a single
// output "Out", the boolean attribute "use_mkldnn" and the operator comment.
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // "X" accepts any number of variables.
    auto&& x_input =
        AddInput("X", "(vector<Tensor>) The input tensors of sum operator.");
    x_input.AsDuplicable();

    AddOutput("Out", "(Tensor) The output tensor of sum operator.");

    // "use_mkldnn" is off unless explicitly enabled.
    auto&& use_mkldnn_attr = AddAttr<bool>(
        "use_mkldnn", "(bool, default false) Only used in mkldnn kernel");
    use_mkldnn_attr.SetDefault(false);

    AddComment(R"DOC(
Sum operator.

This operators sums the input tensors. All the inputs can carry the
LoD (Level of Details) information. However, the output only shares
the LoD information with the first input.
)DOC");
  }
};

Q
QI JUN 已提交
151 152
// Infers the variable type and data type of "Out" from the types of the "X"
// inputs recorded in the block:
//  - if any input is a LOD_TENSOR_ARRAY, all inputs must be, and so is "Out";
//  - otherwise, if any input is a LOD_TENSOR, "Out" is a LOD_TENSOR;
//  - otherwise "Out" is SELECTED_ROWS.
// The data type of "Out" is copied from the first input.
class SumOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    auto& inputs = op_desc.Input("X");
    // inputs.front() is used below; reject an empty input list explicitly
    // instead of calling front() on an empty container.
    PADDLE_ENFORCE(!inputs.empty(),
                   "Inputs(X) of SumOp should not be empty.");

    auto var_type = framework::proto::VarType::SELECTED_ROWS;
    for (auto& name : inputs) {
      VLOG(10) << name << " "
               << block->FindRecursiveOrCreateVar(name).GetType();
    }

    bool any_input_is_lod_tensor = std::any_of(
        inputs.begin(), inputs.end(), [block](const std::string& name) {
          return block->FindRecursiveOrCreateVar(name).GetType() ==
                 framework::proto::VarType::LOD_TENSOR;
        });

    auto is_tensor_array = [block](const std::string& name) {
      return block->FindRecursiveOrCreateVar(name).GetType() ==
             framework::proto::VarType::LOD_TENSOR_ARRAY;
    };

    bool any_input_is_tensor_array =
        std::any_of(inputs.begin(), inputs.end(), is_tensor_array);
    bool all_inputs_are_tensor_array =
        std::all_of(inputs.begin(), inputs.end(), is_tensor_array);

    if (any_input_is_tensor_array) {
      if (!all_inputs_are_tensor_array) {
        // Build a per-input type listing only on the failure path so the
        // error message shows which inputs are not tensor arrays.
        std::ostringstream os;
        for (auto& each : inputs) {
          os << "    " << each << " type is "
             << block->FindRecursiveOrCreateVar(each).GetType() << "\n";
        }
        PADDLE_ENFORCE(all_inputs_are_tensor_array,
                       "Not all inputs are tensor array:\n%s", os.str());
      }
      var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
    } else if (any_input_is_lod_tensor) {
      var_type = framework::proto::VarType::LOD_TENSOR;
    }

    const auto out_var_names = op_desc.Output("Out");
    PADDLE_ENFORCE(!out_var_names.empty(),
                   "Output(Out) of SumOp should not be empty.");
    auto out_var_name = out_var_names.front();
    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
    out_var.SetType(var_type);
    auto& in_var = detail::Ref(block->FindVarRecursive(inputs.front()));
    out_var.SetDataType(in_var.GetDataType());
  }
};

201
class SumGradMaker : public framework::GradOpDescMakerBase {
202
 public:
203
  using framework::GradOpDescMakerBase::GradOpDescMakerBase;
204

Y
Yu Yang 已提交
205
  std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
206
    auto x_grads = InputGrad("X", false);
Y
Yu Yang 已提交
207
    std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
208 209 210 211
    grad_ops.reserve(x_grads.size());
    auto og = OutputGrad("Out");
    std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
                   [&og](const std::string& x_grad) {
Y
Yu Yang 已提交
212
                     auto* grad_op = new framework::OpDesc();
Y
Yu Yang 已提交
213 214 215 216
                     grad_op->SetType("scale");
                     grad_op->SetInput("X", og);
                     grad_op->SetOutput("Out", {x_grad});
                     grad_op->SetAttr("scale", 1.0f);
Y
Yu Yang 已提交
217
                     return std::unique_ptr<framework::OpDesc>(grad_op);
218 219
                   });
    return grad_ops;
220 221 222 223 224 225 226
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
227

Q
QI JUN 已提交
228 229
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker,
                  ops::SumOpVarTypeInference);
230

Q
QI JUN 已提交
231 232 233 234 235
REGISTER_OP_CPU_KERNEL(
    sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SumKernel<paddle::platform::CPUDeviceContext, int64_t>);