operator.h 12.9 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

D
dongzhihong 已提交
17
#include <algorithm>
18
#include <atomic>
Q
Qiao Longfei 已提交
19 20 21 22
#include <string>
#include <unordered_map>
#include <vector>

Y
Yu Yang 已提交
23
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
24
#include "paddle/framework/attribute.h"
Q
qiaolongfei 已提交
25
#include "paddle/framework/block_desc.h"
Y
Yu Yang 已提交
26
#include "paddle/framework/framework.pb.h"
27
#include "paddle/framework/lod_tensor.h"
Y
Yu Yang 已提交
28
#include "paddle/framework/op_info.h"
Q
QI JUN 已提交
29
#include "paddle/framework/op_kernel_type.h"
Q
qijun 已提交
30
#include "paddle/framework/scope.h"
Q
QI JUN 已提交
31
#include "paddle/framework/selected_rows.h"
Q
qijun 已提交
32 33
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
Y
Yu Yang 已提交
34
#include "paddle/platform/variant.h"
Q
qijun 已提交
35
#include "paddle/utils/Error.h"
Q
Qiao Longfei 已提交
36 37 38 39

namespace paddle {
namespace framework {

40
/// If a variable is an empty variable, that name will be used.
constexpr char kEmptyVarName[] = "@EMPTY@";

/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be converted to a unique name in scope after OpCreator.
constexpr char kTempVarName[] = "@TEMP@";

/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another variable.
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
constexpr char kGradVarSuffix[] = "@GRAD";

/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";
54 55 56 57 58

// Build the gradient-variable name for `var_name` by appending the
// canonical kGradVarSuffix ("@GRAD").
inline std::string GradVarName(const std::string& var_name) {
  std::string grad_name = var_name;
  grad_name += kGradVarSuffix;
  return grad_name;
}

Q
Qiao Longfei 已提交
59
class OperatorBase;
60
class ExecutionContext;
61

Q
Qiao Longfei 已提交
62 63 64 65 66 67 68 69
/**
 * OperatorBase has the basic element that Net will call to do computation.
 * Only CreateOperator from OpRegistry will new Operator directly. User
 * should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
Y
Yu Yang 已提交
70 71
  OperatorBase(const std::string& type, const VariableNameMap& inputs,
               const VariableNameMap& outputs, const AttributeMap& attrs);
72

Q
Qiao Longfei 已提交
73 74 75
  virtual ~OperatorBase() {}

  template <typename T>
Y
Yu Yang 已提交
76
  inline const T& Attr(const std::string& name) const {
Q
Qiao Longfei 已提交
77 78 79 80 81
    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
                   name);
    return boost::get<T>(attrs_.at(name));
  }

82
  virtual std::string DebugString() const;
Q
Qiao Longfei 已提交
83 84

  /// Net will call this function to Run an op.
D
dzhwinter 已提交
85
  virtual void Run(const Scope& scope, const platform::Place& place) const = 0;
Y
Yu Yang 已提交
86

Y
Yu Yang 已提交
87 88
  virtual bool IsNetOp() const { return false; }

89 90
  virtual bool SupportGPU() const { return false; }

D
dongzhihong 已提交
91 92 93
  /// rename inputs outputs name
  void Rename(const std::string& old_name, const std::string& new_name);

Y
Yu Yang 已提交
94 95
  const VariableNameMap& Inputs() const { return inputs_; }
  const VariableNameMap& Outputs() const { return outputs_; }
96

Y
Yu Yang 已提交
97
  //! Get a input with argument's name described in `op_proto`
98
  std::string Input(const std::string& name) const;
Y
Yu Yang 已提交
99
  //! Get a input which has multiple variables.
Y
Yu Yang 已提交
100
  const std::vector<std::string>& Inputs(const std::string& name) const;
Y
Yi Wang 已提交
101

Q
qijun 已提交
102 103
  std::vector<std::string> InputVars() const;

Y
Yu Yang 已提交
104
  //! Get a output with argument's name described in `op_proto`
105
  std::string Output(const std::string& name) const;
Y
Yu Yang 已提交
106 107
  //! Get an output which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yu Yang 已提交
108
  const std::vector<std::string>& Outputs(const std::string& name) const;
Y
Yan Chunwei 已提交
109

Y
Yu Yang 已提交
110
  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
111

Q
qiaolongfei 已提交
112
  const std::string& Type() const { return type_; }
Q
qiaolongfei 已提交
113
  void SetType(const std::string& type) { type_ = type; }
Y
Yi Wang 已提交
114 115
  const AttributeMap& Attrs() const { return attrs_; }

Y
Yu Yang 已提交
116
  // Return a new operator instance, which is as same as this.
Y
Yu Yang 已提交
117 118
  // Use unique_ptr to prevent caller forget to delete this pointer.
  virtual std::unique_ptr<OperatorBase> Clone() const = 0;
Y
Yu Yang 已提交
119

Q
qiaolongfei 已提交
120
 protected:
Q
Qiao Longfei 已提交
121
  std::string type_;
D
dongzhihong 已提交
122
  // NOTE: in case of OpGrad, inputs_ contains:
123
  // I (Inputs)
D
dongzhihong 已提交
124 125
  // O (Outputs)
  // OG (Output Gradients)
Y
Yu Yang 已提交
126
  VariableNameMap inputs_;
Y
Yu Yang 已提交
127

D
dongzhihong 已提交
128 129
  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
Y
Yu Yang 已提交
130
  VariableNameMap outputs_;
Q
Qiao Longfei 已提交
131
  AttributeMap attrs_;
132 133 134 135

 private:
  void GenerateTemporaryNames();
  void CheckAllInputOutputSet() const;
Y
Yan Chunwei 已提交
136 137
};

Y
Yu Yang 已提交
138 139
// Macro for defining a clone method.
// If you are writing a kernel operator, `Clone` will be defined when you
// register it. i.e. `Clone` method is not needed to define by yourself.
#define DEFINE_OP_CLONE_METHOD(cls)                                            \
  std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final {     \
    return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \
  }
Y
Yu Yang 已提交
145

Y
Yu Yang 已提交
146 147 148 149
// Macro for defining a default constructor for Operator.
// You can also use
//   using PARENT_CLASS::PARENT_CLASS;
// to use parent's constructor.
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls)             \
  cls(const std::string& type,                             \
      const ::paddle::framework::VariableNameMap& inputs,  \
      const ::paddle::framework::VariableNameMap& outputs, \
      const paddle::framework::AttributeMap& attrs)        \
      : parent_cls(type, inputs, outputs, attrs) {}
Y
Yu Yang 已提交
156

157 158
class NOP : public OperatorBase {
 public:
159
  using OperatorBase::OperatorBase;
D
dzhwinter 已提交
160
  void Run(const Scope& scope, const platform::Place& place) const override {}
161 162 163
  std::unique_ptr<OperatorBase> Clone() const override {
    return std::unique_ptr<OperatorBase>(new NOP(*this));
  }
164 165
};

166
class ExecutionContext {
Y
Yan Chunwei 已提交
167
 public:
168 169 170
  ExecutionContext(const OperatorBase& op, const Scope& scope,
                   const platform::DeviceContext& device_context)
      : op_(op), scope_(scope), device_context_(device_context) {}
171

Q
qiaolongfei 已提交
172 173 174 175
  const OperatorBase& op() const { return op_; }

  const Scope& scope() const { return scope_; }

Q
qiaolongfei 已提交
176
  template <typename T>
Y
Yu Yang 已提交
177 178
  inline const T& Attr(const std::string& name) const {
    return op_.Attr<T>(name);
Q
qiaolongfei 已提交
179 180
  }

Y
Yu Yang 已提交
181
  size_t InputSize(const std::string& name) const {
Y
Yu Yang 已提交
182
    return op_.Inputs(name).size();
Y
Yan Chunwei 已提交
183 184
  }

Y
Yu Yang 已提交
185
  size_t OutputSize(const std::string& name) const {
Y
Yu Yang 已提交
186
    return op_.Outputs(name).size();
Y
Yan Chunwei 已提交
187 188
  }

189
  const Variable* InputVar(const std::string& name) const {
190
    auto ipt = op_.Input(name);
Y
Yu Yang 已提交
191
    return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
Y
Yan Chunwei 已提交
192 193
  }

194
  Variable* OutputVar(const std::string& name) const {
195
    auto opt = op_.Output(name);
Y
Yu Yang 已提交
196
    return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
Y
Yan Chunwei 已提交
197 198
  }

199 200
  const std::vector<const Variable*> MultiInputVar(
      const std::string& name) const {
Y
Yan Chunwei 已提交
201 202
    auto names = op_.Inputs(name);
    std::vector<const Variable*> res;
203
    res.reserve(names.size());
204 205
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
206 207
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
208
                   });
Y
Yan Chunwei 已提交
209 210 211
    return res;
  }

212
  std::vector<Variable*> MultiOutputVar(const std::string& name) const {
Y
Yan Chunwei 已提交
213
    auto names = op_.Outputs(name);
214
    std::vector<Variable*> res;
215
    res.reserve(names.size());
216 217
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
218 219
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
220
                   });
Y
Yan Chunwei 已提交
221 222 223
    return res;
  }

224 225
  template <typename T>
  const T* Input(const std::string& name) const {
Y
Yu Yang 已提交
226
    auto* var = InputVar(name);
227
    return var == nullptr ? nullptr : &var->Get<T>();
228 229 230 231
  }

  template <typename T>
  T* Output(const std::string& name) const {
232
    auto var = OutputVar(name);
233
    return var == nullptr ? nullptr : var->GetMutable<T>();
234 235 236 237 238 239 240 241
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
242
                   [&](const std::string& sub_name) {
243
                     auto var = scope_.FindVar(sub_name);
244
                     return var == nullptr ? nullptr : &var->Get<T>();
245 246 247 248 249
                   });
    return res;
  }

  template <typename T>
250
  std::vector<T*> MultiOutput(const std::string& name) const {
251
    auto names = op_.Outputs(name);
252
    std::vector<T*> res;
253 254
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
255
                   [&](const std::string& sub_name) {
256
                     auto var = scope_.FindVar(sub_name);
257
                     return var == nullptr ? nullptr : var->GetMutable<T>();
258 259 260 261
                   });
    return res;
  }

262 263 264 265 266 267
  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const {
    PADDLE_ENFORCE_LT(i, InputSize(in));
    PADDLE_ENFORCE_LT(j, OutputSize(out));
    auto* in_var = MultiInputVar(in)[i];
    auto* out_var = MultiOutputVar(out)[j];
268
    if (!in_var->IsType<LoDTensor>()) return;
269
    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
270
                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
271 272 273
    auto in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_lod(in_tensor.lod());
274 275
  }

276
  platform::Place GetPlace() const { return device_context_.GetPlace(); }
Q
qijun 已提交
277

Q
QI JUN 已提交
278 279 280 281 282
  template <typename DeviceContextType>
  const DeviceContextType& device_context() const {
    return *reinterpret_cast<const DeviceContextType*>(&device_context_);
  }

283
  const platform::DeviceContext& device_context() const {
Q
qijun 已提交
284
    return device_context_;
Q
qijun 已提交
285
  }
Q
qijun 已提交
286

Q
QI JUN 已提交
287 288 289 290 291 292 293 294
#ifdef PADDLE_WITH_CUDA
  const inline platform::CUDADeviceContext& cuda_device_context() const {
    PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
    return *reinterpret_cast<const platform::CUDADeviceContext*>(
        &device_context_);
  }
#endif

D
dzhwinter 已提交
295
  //! Get actual name vector for this input.
D
Dong Zhihong 已提交
296 297 298
  const std::vector<std::string>& Inputs(const std::string& name) const {
    return op_.Inputs(name);
  }
D
Dong Zhihong 已提交
299

D
dzhwinter 已提交
300
  //! Get actual name vector for this output.
D
Dong Zhihong 已提交
301 302 303 304
  const std::vector<std::string>& Outputs(const std::string& name) const {
    return op_.Outputs(name);
  }

305
 private:
306 307
  const OperatorBase& op_;
  const Scope& scope_;
308
  const platform::DeviceContext& device_context_;
Q
Qiao Longfei 已提交
309 310
};

311 312 313 314 315 316 317 318 319 320 321 322 323 324
// Explicit specializations for Tensor (defined in operator.cc): these let
// ctx.Input<Tensor>() etc. work on variables whose concrete storage type
// is resolved there rather than via the generic templates above.
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;

template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
    const std::string& name) const;

template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;

template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
    const std::string& name) const;

Y
Yu Yang 已提交
325
class OpKernelBase {
Q
qijun 已提交
326
 public:
Q
qijun 已提交
327
  /**
328
   * ExecutionContext is the only parameter of Kernel Run function.
Q
qijun 已提交
329 330
   * Run will get input/output variables, state such as momentum and
   * device resource such as CUDA stream, cublas handle, etc. from
331
   * ExecutionContext. User should construct it before run the Operator.
Q
qijun 已提交
332 333
   */

334
  virtual void Compute(const ExecutionContext& context) const = 0;
Y
Yu Yang 已提交
335

Y
Yu Yang 已提交
336 337 338 339 340 341 342
  virtual ~OpKernelBase() = default;
};

// Kernel base parameterized on the element data type it operates on;
// ELEMENT_TYPE exposes that type to the kernel registration machinery.
template <typename T>
class OpKernel : public OpKernelBase {
 public:
  using ELEMENT_TYPE = T;
};

Y
Yu Yang 已提交
345 346
class OperatorWithKernel : public OperatorBase {
 public:
Y
Yu Yang 已提交
347
  using OpKernelMap =
Y
Yu Yang 已提交
348 349
      std::unordered_map<OpKernelType, std::unique_ptr<OpKernelBase>,
                         OpKernelType::Hash>;
Q
Qiao Longfei 已提交
350

Y
Yu Yang 已提交
351 352
  OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
                     const VariableNameMap& outputs, const AttributeMap& attrs)
Y
Yu Yang 已提交
353 354
      : OperatorBase(type, inputs, outputs, attrs) {}

D
dzhwinter 已提交
355
  void Run(const Scope& scope, const platform::Place& place) const final;
Q
Qiao Longfei 已提交
356

Y
Yu Yang 已提交
357 358 359 360
  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
Y
Yu Yang 已提交
361
  }
Y
Yan Chunwei 已提交
362

363
  bool SupportGPU() const override {
Y
Yu Yang 已提交
364 365 366 367 368
    auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
    return std::any_of(op_kernels.begin(), op_kernels.end(),
                       [](OpKernelMap::const_reference kern_pair) {
                         return platform::is_gpu_place(kern_pair.first.place_);
                       });
369 370
  }

371 372 373
  virtual void InferShape(InferShapeContext* ctx) const {
    OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
  }
Y
Yu Yang 已提交
374

Q
qiaolongfei 已提交
375
 protected:
Y
Yu Yang 已提交
376 377 378
  virtual OpKernelType GetKernelType(const ExecutionContext& ctx) const;

 private:
Y
Yu Yang 已提交
379 380
  // indicate kernel DataType by input data. Defaultly all input data must be
  // same.
381
  proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
Q
Qiao Longfei 已提交
382 383
};

Y
Yu Yang 已提交
384 385
// Whether the operator registered under `op_type` has a GPU kernel
// (definition lives in the corresponding .cc — presumably queries
// OperatorWithKernel::AllOpKernels; confirm there).
extern bool OpSupportGPU(const std::string& op_type);

Q
Qiao Longfei 已提交
386 387
}  // namespace framework
}  // namespace paddle