operator.h 13.5 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

D
dongzhihong 已提交
17
#include <algorithm>
18
#include <atomic>
Q
Qiao Longfei 已提交
19
#include <string>
D
dzhwinter 已提交
20
#include <tuple>
Q
Qiao Longfei 已提交
21 22 23
#include <unordered_map>
#include <vector>

Y
Yu Yang 已提交
24
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
25 26 27 28 29 30 31 32 33 34 35
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/variant.h"
Q
qijun 已提交
36
#include "paddle/utils/Error.h"
Q
Qiao Longfei 已提交
37 38 39 40

namespace paddle {
namespace framework {

41
/// If a variable is an empty variable, that name will be used.
constexpr char kEmptyVarName[] = "@EMPTY@";

/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be converted to a unique name in scope after OpCreator.
constexpr char kTempVarName[] = "@TEMP@";

/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another variable.
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
constexpr char kGradVarSuffix[] = "@GRAD";

/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";
55

D
dzhwinter 已提交
56
// define some kernel priority
57
/* Define multiple kernel type fallback order*/
D
dzhwinter 已提交
58 59
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;

60 61 62 63
/// Builds the gradient-variable name for `var_name` by appending
/// kGradVarSuffix, e.g. "x" -> "x@GRAD".
inline std::string GradVarName(const std::string& var_name) {
  std::string grad_var_name(var_name);
  grad_var_name.append(kGradVarSuffix);
  return grad_var_name;
}

Q
Qiao Longfei 已提交
64
// Forward declarations; both classes are defined later in this header.
class OperatorBase;
class ExecutionContext;
66

Q
Qiao Longfei 已提交
67 68 69 70 71 72 73 74
/**
 * OperatorBase has the basic element that Net will call to do computation.
 * Only CreateOperator from OpRegistry will new Operator directly. User
 * should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
Y
Yu Yang 已提交
75 76
  OperatorBase(const std::string& type, const VariableNameMap& inputs,
               const VariableNameMap& outputs, const AttributeMap& attrs);
77

Q
Qiao Longfei 已提交
78 79 80
  virtual ~OperatorBase() {}

  template <typename T>
Y
Yu Yang 已提交
81
  inline const T& Attr(const std::string& name) const {
Q
Qiao Longfei 已提交
82 83 84 85 86
    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
                   name);
    return boost::get<T>(attrs_.at(name));
  }

87 88 89 90
  /// if scope is not null, also show dimensions of arguments
  virtual std::string DebugStringEx(const Scope* scope) const;

  std::string DebugString() const { return DebugStringEx(nullptr); }
Q
Qiao Longfei 已提交
91 92

  /// Net will call this function to Run an op.
D
dzhwinter 已提交
93
  virtual void Run(const Scope& scope, const platform::Place& place) const = 0;
Y
Yu Yang 已提交
94

T
typhoonzero 已提交
95 96 97
  // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
  virtual void Stop() {}

Y
Yu Yang 已提交
98 99
  virtual bool IsNetOp() const { return false; }

100 101
  virtual bool SupportGPU() const { return false; }

D
dongzhihong 已提交
102 103 104
  /// rename inputs outputs name
  void Rename(const std::string& old_name, const std::string& new_name);

Y
Yu Yang 已提交
105 106
  const VariableNameMap& Inputs() const { return inputs_; }
  const VariableNameMap& Outputs() const { return outputs_; }
107

Y
Yu Yang 已提交
108
  //! Get a input with argument's name described in `op_proto`
109
  std::string Input(const std::string& name) const;
Y
Yu Yang 已提交
110
  //! Get a input which has multiple variables.
Y
Yu Yang 已提交
111
  const std::vector<std::string>& Inputs(const std::string& name) const;
Y
Yi Wang 已提交
112

Q
qijun 已提交
113 114
  std::vector<std::string> InputVars() const;

Y
Yu Yang 已提交
115
  //! Get a output with argument's name described in `op_proto`
116
  std::string Output(const std::string& name) const;
Y
Yu Yang 已提交
117 118
  //! Get an output which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yu Yang 已提交
119
  const std::vector<std::string>& Outputs(const std::string& name) const;
Y
Yan Chunwei 已提交
120

Y
Yu Yang 已提交
121
  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
122

Q
qiaolongfei 已提交
123
  const std::string& Type() const { return type_; }
Q
qiaolongfei 已提交
124
  void SetType(const std::string& type) { type_ = type; }
Y
Yi Wang 已提交
125 126
  const AttributeMap& Attrs() const { return attrs_; }

Y
Yu Yang 已提交
127
  // Return a new operator instance, which is as same as this.
Y
Yu Yang 已提交
128 129
  // Use unique_ptr to prevent caller forget to delete this pointer.
  virtual std::unique_ptr<OperatorBase> Clone() const = 0;
Y
Yu Yang 已提交
130

Q
qiaolongfei 已提交
131
 protected:
Q
Qiao Longfei 已提交
132
  std::string type_;
D
dongzhihong 已提交
133
  // NOTE: in case of OpGrad, inputs_ contains:
134
  // I (Inputs)
D
dongzhihong 已提交
135 136
  // O (Outputs)
  // OG (Output Gradients)
Y
Yu Yang 已提交
137
  VariableNameMap inputs_;
Y
Yu Yang 已提交
138

D
dongzhihong 已提交
139 140
  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
Y
Yu Yang 已提交
141
  VariableNameMap outputs_;
Q
Qiao Longfei 已提交
142
  AttributeMap attrs_;
143 144 145 146

 private:
  void GenerateTemporaryNames();
  void CheckAllInputOutputSet() const;
Y
Yan Chunwei 已提交
147 148
};

Y
Yu Yang 已提交
149 150
// Macro for defining a clone method.
// If you are writing a kernel operator, `Clone` will be defined when you
// register it, i.e. you do not need to define the `Clone` method yourself.
#define DEFINE_OP_CLONE_METHOD(cls)                                            \
  std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final {     \
    return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \
  }
Y
Yu Yang 已提交
156

Y
Yu Yang 已提交
157 158 159 160
// Macro for defining a default constructor for Operator.
// You can also use
//   using PARENT_CLASS::PARENT_CLASS;
// to use the parent's constructor.
// All types are fully qualified with a leading `::` so the macro expands
// correctly in any namespace.
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls)             \
  cls(const std::string& type,                             \
      const ::paddle::framework::VariableNameMap& inputs,  \
      const ::paddle::framework::VariableNameMap& outputs, \
      const ::paddle::framework::AttributeMap& attrs)      \
      : parent_cls(type, inputs, outputs, attrs) {}
Y
Yu Yang 已提交
167

168 169
class NOP : public OperatorBase {
 public:
170
  using OperatorBase::OperatorBase;
D
dzhwinter 已提交
171
  void Run(const Scope& scope, const platform::Place& place) const override {}
172 173 174
  std::unique_ptr<OperatorBase> Clone() const override {
    return std::unique_ptr<OperatorBase>(new NOP(*this));
  }
175 176
};

177
class ExecutionContext {
Y
Yan Chunwei 已提交
178
 public:
179 180 181
  ExecutionContext(const OperatorBase& op, const Scope& scope,
                   const platform::DeviceContext& device_context)
      : op_(op), scope_(scope), device_context_(device_context) {}
182

Q
qiaolongfei 已提交
183 184 185 186
  const OperatorBase& op() const { return op_; }

  const Scope& scope() const { return scope_; }

Q
qiaolongfei 已提交
187
  template <typename T>
Y
Yu Yang 已提交
188 189
  inline const T& Attr(const std::string& name) const {
    return op_.Attr<T>(name);
Q
qiaolongfei 已提交
190 191
  }

Y
Yu Yang 已提交
192
  size_t InputSize(const std::string& name) const {
Y
Yu Yang 已提交
193
    return op_.Inputs(name).size();
Y
Yan Chunwei 已提交
194 195
  }

Y
Yu Yang 已提交
196
  size_t OutputSize(const std::string& name) const {
Y
Yu Yang 已提交
197
    return op_.Outputs(name).size();
Y
Yan Chunwei 已提交
198 199
  }

200
  const Variable* InputVar(const std::string& name) const {
201
    auto ipt = op_.Input(name);
Y
Yu Yang 已提交
202
    return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
Y
Yan Chunwei 已提交
203 204
  }

205
  Variable* OutputVar(const std::string& name) const {
206
    auto opt = op_.Output(name);
Y
Yu Yang 已提交
207
    return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
Y
Yan Chunwei 已提交
208 209
  }

210 211
  const std::vector<const Variable*> MultiInputVar(
      const std::string& name) const {
Y
Yan Chunwei 已提交
212 213
    auto names = op_.Inputs(name);
    std::vector<const Variable*> res;
214
    res.reserve(names.size());
215 216
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
217 218
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
219
                   });
Y
Yan Chunwei 已提交
220 221 222
    return res;
  }

223
  std::vector<Variable*> MultiOutputVar(const std::string& name) const {
Y
Yan Chunwei 已提交
224
    auto names = op_.Outputs(name);
225
    std::vector<Variable*> res;
226
    res.reserve(names.size());
227 228
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
229 230
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
231
                   });
Y
Yan Chunwei 已提交
232 233 234
    return res;
  }

235 236
  template <typename T>
  const T* Input(const std::string& name) const {
Y
Yu Yang 已提交
237
    auto* var = InputVar(name);
238
    return var == nullptr ? nullptr : &var->Get<T>();
239 240 241 242
  }

  template <typename T>
  T* Output(const std::string& name) const {
243
    auto var = OutputVar(name);
244
    return var == nullptr ? nullptr : var->GetMutable<T>();
245 246 247 248 249 250 251 252
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
253
                   [&](const std::string& sub_name) {
254
                     auto var = scope_.FindVar(sub_name);
255
                     return var == nullptr ? nullptr : &var->Get<T>();
256 257 258 259 260
                   });
    return res;
  }

  template <typename T>
261
  std::vector<T*> MultiOutput(const std::string& name) const {
262
    auto names = op_.Outputs(name);
263
    std::vector<T*> res;
264 265
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
266
                   [&](const std::string& sub_name) {
267
                     auto var = scope_.FindVar(sub_name);
268
                     return var == nullptr ? nullptr : var->GetMutable<T>();
269 270 271 272
                   });
    return res;
  }

273 274 275 276 277 278
  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const {
    PADDLE_ENFORCE_LT(i, InputSize(in));
    PADDLE_ENFORCE_LT(j, OutputSize(out));
    auto* in_var = MultiInputVar(in)[i];
    auto* out_var = MultiOutputVar(out)[j];
279
    if (!in_var->IsType<LoDTensor>()) return;
280
    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
281
                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
282 283 284
    auto in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_lod(in_tensor.lod());
285 286
  }

287
  platform::Place GetPlace() const { return device_context_.GetPlace(); }
Q
qijun 已提交
288

Q
QI JUN 已提交
289 290 291 292 293
  template <typename DeviceContextType>
  const DeviceContextType& device_context() const {
    return *reinterpret_cast<const DeviceContextType*>(&device_context_);
  }

294
  const platform::DeviceContext& device_context() const {
Q
qijun 已提交
295
    return device_context_;
Q
qijun 已提交
296
  }
Q
qijun 已提交
297

Q
QI JUN 已提交
298 299 300 301 302 303 304 305
#ifdef PADDLE_WITH_CUDA
  const inline platform::CUDADeviceContext& cuda_device_context() const {
    PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
    return *reinterpret_cast<const platform::CUDADeviceContext*>(
        &device_context_);
  }
#endif

D
dzhwinter 已提交
306
  //! Get actual name vector for this input.
D
Dong Zhihong 已提交
307 308 309
  const std::vector<std::string>& Inputs(const std::string& name) const {
    return op_.Inputs(name);
  }
D
Dong Zhihong 已提交
310

D
dzhwinter 已提交
311
  //! Get actual name vector for this output.
D
Dong Zhihong 已提交
312 313 314 315
  const std::vector<std::string>& Outputs(const std::string& name) const {
    return op_.Outputs(name);
  }

316
 private:
317 318
  const OperatorBase& op_;
  const Scope& scope_;
319
  const platform::DeviceContext& device_context_;
Q
Qiao Longfei 已提交
320 321
};

322 323 324 325 326 327 328 329 330 331 332 333 334 335
// Explicit specializations of the ExecutionContext accessors for Tensor.
// They are only declared here; the definitions are provided elsewhere.

// Single-variable accessors.
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;

template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;

// Multi-variable accessors.
template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
    const std::string& name) const;

template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
    const std::string& name) const;

Y
Yu Yang 已提交
336
class OpKernelBase {
Q
qijun 已提交
337
 public:
Q
qijun 已提交
338
  /**
339
   * ExecutionContext is the only parameter of Kernel Run function.
Q
qijun 已提交
340 341
   * Run will get input/output variables, state such as momentum and
   * device resource such as CUDA stream, cublas handle, etc. from
342
   * ExecutionContext. User should construct it before run the Operator.
Q
qijun 已提交
343 344
   */

345
  virtual void Compute(const ExecutionContext& context) const = 0;
Y
Yu Yang 已提交
346

Y
Yu Yang 已提交
347 348 349 350 351 352 353
  virtual ~OpKernelBase() = default;
};

/// Kernel base parameterized on an element type; ELEMENT_TYPE exposes T.
template <typename T>
class OpKernel : public OpKernelBase {
 public:
  using ELEMENT_TYPE = T;
};

Y
Yu Yang 已提交
356 357
class OperatorWithKernel : public OperatorBase {
 public:
Y
Yu Yang 已提交
358
  using OpKernelMap =
Y
Yu Yang 已提交
359 360
      std::unordered_map<OpKernelType, std::unique_ptr<OpKernelBase>,
                         OpKernelType::Hash>;
Q
Qiao Longfei 已提交
361

Y
Yu Yang 已提交
362 363
  OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
                     const VariableNameMap& outputs, const AttributeMap& attrs)
Y
Yu Yang 已提交
364 365
      : OperatorBase(type, inputs, outputs, attrs) {}

D
dzhwinter 已提交
366
  void Run(const Scope& scope, const platform::Place& place) const final;
Q
Qiao Longfei 已提交
367

Y
Yu Yang 已提交
368 369 370 371
  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
Y
Yu Yang 已提交
372
  }
Y
Yan Chunwei 已提交
373

374
  bool SupportGPU() const override {
Y
Yu Yang 已提交
375 376 377 378 379
    auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
    return std::any_of(op_kernels.begin(), op_kernels.end(),
                       [](OpKernelMap::const_reference kern_pair) {
                         return platform::is_gpu_place(kern_pair.first.place_);
                       });
380 381
  }

382 383 384
  virtual void InferShape(InferShapeContext* ctx) const {
    OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
  }
Y
Yu Yang 已提交
385

Q
qiaolongfei 已提交
386
 protected:
387 388 389 390
  virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
  virtual OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const OpKernelType& expected_kernel_type) const;
Y
Yu Yang 已提交
391 392

 private:
Y
Yu Yang 已提交
393 394
  // indicate kernel DataType by input data. Defaultly all input data must be
  // same.
395
  proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
Q
Qiao Longfei 已提交
396 397
};

Y
Yu Yang 已提交
398 399
/// Reports whether operator type `op_type` supports GPU. Declared here only;
/// the definition is provided elsewhere.
bool OpSupportGPU(const std::string& op_type);

Q
Qiao Longfei 已提交
400 401
}  // namespace framework
}  // namespace paddle