operator.h 21.8 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

D
dongzhihong 已提交
17
#include <algorithm>
18
#include <atomic>
Q
Qiao Longfei 已提交
19 20 21 22
#include <string>
#include <unordered_map>
#include <vector>

Y
Yu Yang 已提交
23
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
24
#include "paddle/framework/attribute.h"
Q
qiaolongfei 已提交
25
#include "paddle/framework/block_desc.h"
Y
Yu Yang 已提交
26
#include "paddle/framework/data_type.h"
Y
Yu Yang 已提交
27
#include "paddle/framework/framework.pb.h"
28
#include "paddle/framework/lod_tensor.h"
Y
Yu Yang 已提交
29
#include "paddle/framework/op_info.h"
Q
qijun 已提交
30
#include "paddle/framework/scope.h"
Q
Qiao Longfei 已提交
31
#include "paddle/framework/shape_inference.h"
Q
qijun 已提交
32 33 34
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
Y
Yu Yang 已提交
35
#include "paddle/platform/variant.h"
Q
qijun 已提交
36
#include "paddle/utils/Error.h"
Q
Qiao Longfei 已提交
37 38 39 40

namespace paddle {
namespace framework {

41
/// If a variable is an empty variable, this name will be used.
42
constexpr char kEmptyVarName[] = "@EMPTY@";
43 44 45

/// If a variable is a temporary variable, this name will be set in Python,
/// but it will be converted to a unique name in scope after OpCreator.
46
constexpr char kTempVarName[] = "@TEMP@";
47 48 49 50

/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another variable.
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
51
constexpr char kGradVarSuffix[] = "@GRAD";
52 53

/// Variables with this suffix are supposed to be filled up with zeros.
54
constexpr char kZeroVarSuffix[] = "@ZERO";
55 56 57 58 59

inline std::string GradVarName(const std::string& var_name) {
  return var_name + kGradVarSuffix;
}

Q
Qiao Longfei 已提交
60
class OperatorBase;
61
class ExecutionContext;
62

Q
Qiao Longfei 已提交
63 64 65
extern const Tensor* GetTensorFromVar(const Variable* var);
extern Tensor* GetTensorFromVar(Variable* var);

Q
Qiao Longfei 已提交
66 67 68 69 70 71 72 73
/**
 * OperatorBase has the basic element that Net will call to do computation.
 * Only CreateOperator from OpRegistry will new Operator directly. User
 * should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
Y
Yu Yang 已提交
74 75
  OperatorBase(const std::string& type, const VariableNameMap& inputs,
               const VariableNameMap& outputs, const AttributeMap& attrs);
76

Q
Qiao Longfei 已提交
77 78 79
  virtual ~OperatorBase() {}

  template <typename T>
Y
Yu Yang 已提交
80
  inline const T& Attr(const std::string& name) const {
Q
Qiao Longfei 已提交
81 82 83 84 85
    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
                   name);
    return boost::get<T>(attrs_.at(name));
  }

86
  virtual std::string DebugString() const;
Q
Qiao Longfei 已提交
87 88

  /// Net will call this function to Run an op.
Y
Yu Yang 已提交
89
  virtual void Run(const Scope& scope,
Y
Yu Yang 已提交
90 91
                   const platform::DeviceContext& dev_ctx) const = 0;

Y
Yu Yang 已提交
92 93
  virtual bool IsNetOp() const { return false; }

94 95
  virtual bool SupportGPU() const { return false; }

D
dongzhihong 已提交
96 97 98
  /// rename inputs outputs name
  void Rename(const std::string& old_name, const std::string& new_name);

Y
Yu Yang 已提交
99 100
  const VariableNameMap& Inputs() const { return inputs_; }
  const VariableNameMap& Outputs() const { return outputs_; }
101

Y
Yu Yang 已提交
102
  //! Get a input with argument's name described in `op_proto`
103
  std::string Input(const std::string& name) const;
Y
Yu Yang 已提交
104
  //! Get a input which has multiple variables.
Y
Yu Yang 已提交
105
  const std::vector<std::string>& Inputs(const std::string& name) const;
Y
Yi Wang 已提交
106

Q
qijun 已提交
107 108
  std::vector<std::string> InputVars() const;

Y
Yu Yang 已提交
109
  //! Get a output with argument's name described in `op_proto`
110
  std::string Output(const std::string& name) const;
Y
Yu Yang 已提交
111 112
  //! Get an output which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yu Yang 已提交
113
  const std::vector<std::string>& Outputs(const std::string& name) const;
Y
Yan Chunwei 已提交
114

Y
Yu Yang 已提交
115
  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
116

Q
qiaolongfei 已提交
117
  const std::string& Type() const { return type_; }
Q
qiaolongfei 已提交
118
  void SetType(const std::string& type) { type_ = type; }
Y
Yi Wang 已提交
119 120
  const AttributeMap& Attrs() const { return attrs_; }

Y
Yu Yang 已提交
121
  // Return a new operator instance that is a copy of this one.
Y
Yu Yang 已提交
122 123
  // Use unique_ptr to prevent caller forget to delete this pointer.
  virtual std::unique_ptr<OperatorBase> Clone() const = 0;
Y
Yu Yang 已提交
124

Q
qiaolongfei 已提交
125
 protected:
Q
Qiao Longfei 已提交
126
  std::string type_;
D
dongzhihong 已提交
127
  // NOTE: in case of OpGrad, inputs_ contains:
128
  // I (Inputs)
D
dongzhihong 已提交
129 130
  // O (Outputs)
  // OG (Output Gradients)
Y
Yu Yang 已提交
131
  VariableNameMap inputs_;
Y
Yu Yang 已提交
132

D
dongzhihong 已提交
133 134
  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
Y
Yu Yang 已提交
135
  VariableNameMap outputs_;
Q
Qiao Longfei 已提交
136
  AttributeMap attrs_;
137 138 139 140

 private:
  void GenerateTemporaryNames();
  void CheckAllInputOutputSet() const;
Y
Yan Chunwei 已提交
141 142
};

Y
Yu Yang 已提交
143 144
// Macro for defining a clone method.
// If you are writing a kernel operator, `Clone` will be defined when you
145
// register it, i.e. you do not need to define the `Clone` method yourself.
146 147 148
#define DEFINE_OP_CLONE_METHOD(cls)                                            \
  std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final {     \
    return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \
Y
Yu Yang 已提交
149
  }
Y
Yu Yang 已提交
150

Y
Yu Yang 已提交
151 152 153 154
// Macro for defining a default constructor for Operator.
// You can also use
//   using PARENT_CLASS::PARENT_CLASS;
// to use parent's constructor.
Y
Yu Yang 已提交
155 156
// Expands to a constructor for `cls` that forwards (type, inputs, outputs,
// attrs) to `parent_cls`. Every framework type is spelled with a leading `::`
// so the expansion resolves the same way regardless of the namespace the
// macro is used in.
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls)             \
  cls(const std::string& type,                             \
      const ::paddle::framework::VariableNameMap& inputs,  \
      const ::paddle::framework::VariableNameMap& outputs, \
      const ::paddle::framework::AttributeMap& attrs)      \
      : parent_cls(type, inputs, outputs, attrs) {}
Y
Yu Yang 已提交
161

162 163
class NOP : public OperatorBase {
 public:
164
  using OperatorBase::OperatorBase;
165 166
  void Run(const Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {}
167 168 169
  std::unique_ptr<OperatorBase> Clone() const override {
    return std::unique_ptr<OperatorBase>(new NOP(*this));
  }
170 171
};

172
class ExecutionContext {
Y
Yan Chunwei 已提交
173
 public:
174 175 176
  ExecutionContext(const OperatorBase& op, const Scope& scope,
                   const platform::DeviceContext& device_context)
      : op_(op), scope_(scope), device_context_(device_context) {}
177

Q
qiaolongfei 已提交
178 179 180 181
  const OperatorBase& op() const { return op_; }

  const Scope& scope() const { return scope_; }

Q
qiaolongfei 已提交
182
  template <typename T>
Y
Yu Yang 已提交
183 184
  inline const T& Attr(const std::string& name) const {
    return op_.Attr<T>(name);
Q
qiaolongfei 已提交
185 186
  }

Y
Yu Yang 已提交
187
  size_t InputSize(const std::string& name) const {
Y
Yu Yang 已提交
188
    return op_.Inputs(name).size();
Y
Yan Chunwei 已提交
189 190
  }

Y
Yu Yang 已提交
191
  size_t OutputSize(const std::string& name) const {
Y
Yu Yang 已提交
192
    return op_.Outputs(name).size();
Y
Yan Chunwei 已提交
193 194
  }

195
  const Variable* InputVar(const std::string& name) const {
196
    auto ipt = op_.Input(name);
Y
Yu Yang 已提交
197
    return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
Y
Yan Chunwei 已提交
198 199
  }

200
  Variable* OutputVar(const std::string& name) const {
201
    auto opt = op_.Output(name);
Y
Yu Yang 已提交
202
    return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
Y
Yan Chunwei 已提交
203 204
  }

205 206
  const std::vector<const Variable*> MultiInputVar(
      const std::string& name) const {
Y
Yan Chunwei 已提交
207 208
    auto names = op_.Inputs(name);
    std::vector<const Variable*> res;
209
    res.reserve(names.size());
210 211
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
212 213
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
214
                   });
Y
Yan Chunwei 已提交
215 216 217
    return res;
  }

218
  std::vector<Variable*> MultiOutputVar(const std::string& name) const {
Y
Yan Chunwei 已提交
219
    auto names = op_.Outputs(name);
220
    std::vector<Variable*> res;
221
    res.reserve(names.size());
222 223
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
224 225
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
226
                   });
Y
Yan Chunwei 已提交
227 228 229
    return res;
  }

230 231
  template <typename T>
  const T* Input(const std::string& name) const {
Y
Yu Yang 已提交
232
    auto* var = InputVar(name);
233
    return var == nullptr ? nullptr : &var->Get<T>();
234 235 236 237
  }

  template <typename T>
  T* Output(const std::string& name) const {
238
    auto var = OutputVar(name);
239
    return var == nullptr ? nullptr : var->GetMutable<T>();
240 241 242 243 244 245 246 247
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
248
                   [&](const std::string& sub_name) {
249
                     auto var = scope_.FindVar(sub_name);
250
                     return var == nullptr ? nullptr : &var->Get<T>();
251 252 253 254 255
                   });
    return res;
  }

  template <typename T>
256
  std::vector<T*> MultiOutput(const std::string& name) const {
257
    auto names = op_.Outputs(name);
258
    std::vector<T*> res;
259 260
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
261
                   [&](const std::string& sub_name) {
262
                     auto var = scope_.FindVar(sub_name);
263
                     return var == nullptr ? nullptr : var->GetMutable<T>();
264 265 266 267
                   });
    return res;
  }

268 269 270 271 272 273
  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const {
    PADDLE_ENFORCE_LT(i, InputSize(in));
    PADDLE_ENFORCE_LT(j, OutputSize(out));
    auto* in_var = MultiInputVar(in)[i];
    auto* out_var = MultiOutputVar(out)[j];
274
    if (!in_var->IsType<LoDTensor>()) return;
275
    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
276
                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
277 278 279
    auto in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_lod(in_tensor.lod());
280 281
  }

Q
qijun 已提交
282
  template <typename PlaceType,
283 284
            typename DeviceType = typename platform::EigenDeviceConverter<
                PlaceType>::EigenDeviceType>
285
  DeviceType& GetEigenDevice() const;
Q
qijun 已提交
286

287
  platform::Place GetPlace() const { return device_context_.GetPlace(); }
Q
qijun 已提交
288

289
  const platform::DeviceContext& device_context() const {
Q
qijun 已提交
290
    return device_context_;
Q
qijun 已提交
291
  }
Q
qijun 已提交
292

D
Dong Zhihong 已提交
293 294 295 296 297 298 299 300 301
  //! Get a input which has multiple variables.
  const std::vector<std::string>& Inputs(const std::string& name) const {
    return op_.Inputs(name);
  }
  //! Get an output which has multiple variables.
  const std::vector<std::string>& Outputs(const std::string& name) const {
    return op_.Outputs(name);
  }

武毅 已提交
302 303 304 305 306 307 308 309 310
#ifdef PADDLE_WITH_CUDA
  /// Returns the device context viewed as a CUDADeviceContext.
  /// Enforces first that the context's place is a GPU place.
  const platform::CUDADeviceContext& cuda_device_context() const {
    PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
    // static_cast is the proper cast for a base-to-derived conversion and is
    // checked by the compiler, unlike reinterpret_cast.
    // NOTE(review): assumes CUDADeviceContext publicly derives from
    // DeviceContext - confirm against platform/device_context.h.
    return *static_cast<const platform::CUDADeviceContext*>(&device_context_);
  }
#endif

311
 private:
312 313
  const OperatorBase& op_;
  const Scope& scope_;
314
  const platform::DeviceContext& device_context_;
Q
Qiao Longfei 已提交
315 316
};

317 318 319 320 321 322 323 324 325 326 327 328 329 330
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;

template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
    const std::string& name) const;

template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;

template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
    const std::string& name) const;

331
class CompileTimeInferShapeContext : public InferShapeContext {
Q
tmp  
qiaolongfei 已提交
332
 public:
Q
qiaolongfei 已提交
333 334
  CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
      : op_(op), block_(block) {}
Q
tmp  
qiaolongfei 已提交
335

Q
qiaolongfei 已提交
336
  bool HasInput(const std::string& name) const override {
Q
qiaolongfei 已提交
337
    const std::vector<std::string>& input_names = op_.Input(name);
338
    auto length = input_names.size();
339 340 341
    if (length == 0) {
      return false;
    }
342 343 344 345
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input(%s) should have only one value, "
                      "but it have %d now",
                      name, length);
346
    return block_.HasVarRecursive(input_names[0]);
Q
tmp  
qiaolongfei 已提交
347 348
  }

Q
qiaolongfei 已提交
349
  bool HasOutput(const std::string& name) const override {
Q
qiaolongfei 已提交
350
    const std::vector<std::string>& output_names = op_.Output(name);
351
    auto length = output_names.size();
352 353 354
    if (length == 0) {
      return false;
    }
355 356 357 358
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output(%s) should have only one value, "
                      "but it have %d now",
                      name, length);
359
    return block_.HasVarRecursive(output_names[0]);
Q
tmp  
qiaolongfei 已提交
360 361
  }

Q
qiaolongfei 已提交
362
  bool HasInputs(const std::string& name) const override {
Q
qiaolongfei 已提交
363
    const std::vector<std::string>& input_names = op_.Input(name);
364 365 366
    if (input_names.empty()) {
      return false;
    }
Q
qiaolongfei 已提交
367
    for (auto& input : input_names) {
368
      if (!block_.HasVarRecursive(input)) return false;
Q
tmp  
qiaolongfei 已提交
369 370 371 372
    }
    return true;
  }

Q
qiaolongfei 已提交
373
  bool HasOutputs(const std::string& name) const override {
Q
qiaolongfei 已提交
374
    const std::vector<std::string>& output_names = op_.Output(name);
375 376 377
    if (output_names.empty()) {
      return false;
    }
Q
qiaolongfei 已提交
378
    for (auto& output : output_names) {
379
      if (!block_.HasVarRecursive(output)) return false;
Q
tmp  
qiaolongfei 已提交
380 381 382 383
    }
    return true;
  }

Q
qiaolongfei 已提交
384
  DDim GetInputDim(const std::string& name) const override {
Q
qiaolongfei 已提交
385
    std::vector<DDim> ddims = GetInputsDim(name);
386 387 388 389 390
    auto length = ddims.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input(%s) should have 1 value, "
                      "but it has %d now",
                      name, length);
Q
qiaolongfei 已提交
391
    return ddims[0];
Q
tmp  
qiaolongfei 已提交
392 393
  }

Q
qiaolongfei 已提交
394
  void SetInputDim(const std::string& name, const DDim& dim) override {
Q
qiaolongfei 已提交
395
    SetInputsDim(name, {dim});
Q
tmp  
qiaolongfei 已提交
396 397
  }

Q
qiaolongfei 已提交
398
  DDim GetOutputDim(const std::string& name) const override {
Q
qiaolongfei 已提交
399
    std::vector<DDim> ddims = GetOutputsDim(name);
400 401 402 403 404
    auto length = ddims.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output(%s) should have 1 value, "
                      "but it has %d now",
                      name, length);
Q
qiaolongfei 已提交
405
    return ddims[0];
Q
tmp  
qiaolongfei 已提交
406 407
  }

Q
qiaolongfei 已提交
408
  void SetOutputDim(const std::string& name, const DDim& dim) override {
Q
qiaolongfei 已提交
409
    SetOutputsDim(name, {dim});
Q
tmp  
qiaolongfei 已提交
410 411
  }

Q
qiaolongfei 已提交
412
  AttrReader Attrs() const override { return AttrReader(op_.GetAttrMap()); }
Q
tmp  
qiaolongfei 已提交
413

Q
qiaolongfei 已提交
414 415
  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
Q
qiaolongfei 已提交
416
    return op_.Input(name);
Q
tmp  
qiaolongfei 已提交
417 418
  }

Q
qiaolongfei 已提交
419 420
  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
Q
qiaolongfei 已提交
421
    return op_.Output(name);
Q
tmp  
qiaolongfei 已提交
422 423 424
  }

 private:
Q
qiaolongfei 已提交
425
  DDim GetDim(const std::string& name) const override {
426
    return framework::make_ddim(block_.FindVarRecursive(name)->Shape());
Q
tmp  
qiaolongfei 已提交
427 428
  }

Q
qiaolongfei 已提交
429
  void SetDim(const std::string& name, const DDim& dim) override {
430
    block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
Q
tmp  
qiaolongfei 已提交
431 432
  }

Q
qiaolongfei 已提交
433 434
  const OpDescBind& op_;
  const BlockDescBind& block_;
Q
tmp  
qiaolongfei 已提交
435 436
};

437
class RuntimeInferShapeContext : public InferShapeContext {
Q
Qiao Longfei 已提交
438 439 440 441
 public:
  RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
      : op_(op), scope_(scope) {}

Q
qiaolongfei 已提交
442
  bool HasInput(const std::string& name) const override {
443 444 445 446 447 448 449 450
    auto& ins = Inputs(name);
    size_t length = ins.size();
    if (length == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
                      name);
    auto ipt = ins[0];
Q
Qiao Longfei 已提交
451 452 453 454
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

Q
qiaolongfei 已提交
455
  bool HasOutput(const std::string& name) const override {
456 457 458 459 460 461 462 463
    auto& outs = Outputs(name);
    size_t length = outs.size();
    if (length == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
                      name);
    auto ipt = outs[0];
Q
Qiao Longfei 已提交
464 465 466 467
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

Q
qiaolongfei 已提交
468
  bool HasInputs(const std::string& name) const override {
469
    auto inputs = op_.Inputs(name);
Q
qiaolongfei 已提交
470
    if (inputs.empty()) {
471 472 473 474 475 476 477 478 479 480
      return false;
    }
    for (auto& input : inputs) {
      if (scope_.FindVar(input) == nullptr) {
        return false;
      }
    }
    return true;
  }

Q
qiaolongfei 已提交
481
  bool HasOutputs(const std::string& name) const override {
482
    auto outputs = op_.Outputs(name);
Q
qiaolongfei 已提交
483
    if (outputs.empty()) {
484 485 486 487 488 489 490 491 492 493
      return false;
    }
    for (auto& output : outputs) {
      if (scope_.FindVar(output) == nullptr) {
        return false;
      }
    }
    return true;
  }

Q
qiaolongfei 已提交
494
  DDim GetInputDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
495 496 497
    return GetDim(op_.Input(name));
  }

Q
qiaolongfei 已提交
498
  void SetInputDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
499 500 501
    SetDim(op_.Input(name), dim);
  }

Q
qiaolongfei 已提交
502
  DDim GetOutputDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
503 504 505
    return GetDim(op_.Output(name));
  }

Q
qiaolongfei 已提交
506
  void SetOutputDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
507 508 509
    SetDim(op_.Output(name), dim);
  }

Q
qiaolongfei 已提交
510
  AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
Q
Qiao Longfei 已提交
511

Q
qiaolongfei 已提交
512 513
  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
Q
Qiao Longfei 已提交
514 515 516
    return op_.Inputs(name);
  }

Q
qiaolongfei 已提交
517 518
  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
Q
Qiao Longfei 已提交
519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538
    return op_.Outputs(name);
  }

 private:
  /// Returns the tensor stored in scope variable `name`.
  /// When the variable holds neither a LoDTensor nor a Tensor: with
  /// Allocate == true it is turned into a LoDTensor via GetMutable, with
  /// Allocate == false an error is thrown.
  template <bool Allocate>
  Tensor* GetTensor(const std::string& name) const {
    // Look the variable up once and fail with a message when it is missing,
    // instead of dereferencing a null pointer.
    auto* var = scope_.FindVar(name);
    PADDLE_ENFORCE(var != nullptr, "Variable(%s) is not found in scope", name);

    Tensor* t = nullptr;
    if (var->IsType<LoDTensor>() || var->IsType<Tensor>()) {
      t = GetTensorFromVar(var);
    } else if (Allocate) {
      t = var->GetMutable<LoDTensor>();
    } else {
      PADDLE_THROW("Variable(%s) should be tensor", name);
    }
    return t;
  }

Q
qiaolongfei 已提交
539
  DDim GetDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
540 541 542
    return GetTensor<false>(name)->dims();
  }

Q
qiaolongfei 已提交
543
  void SetDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
544 545 546 547 548 549 550
    GetTensor<true>(name)->Resize(dim);
  }

  const OperatorBase& op_;
  const Scope& scope_;
};

Y
Yu Yang 已提交
551
class OpKernelBase {
Q
qijun 已提交
552
 public:
Q
qijun 已提交
553
  /**
554
   * ExecutionContext is the only parameter of Kernel Run function.
Q
qijun 已提交
555 556
   * Run will get input/output variables, state such as momentum and
   * device resource such as CUDA stream, cublas handle, etc. from
557
   * ExecutionContext. User should construct it before run the Operator.
Q
qijun 已提交
558 559
   */

560
  virtual void Compute(const ExecutionContext& context) const = 0;
Y
Yu Yang 已提交
561

Y
Yu Yang 已提交
562 563 564 565 566 567 568
  virtual ~OpKernelBase() = default;
};

template <typename T>
class OpKernel : public OpKernelBase {
 public:
  using ELEMENT_TYPE = T;
Y
Yu Yang 已提交
569 570
};

Q
Qiao Longfei 已提交
571 572
class OperatorWithKernel : public OperatorBase {
 public:
Y
Yu Yang 已提交
573 574
  struct OpKernelKey {
    platform::Place place_;
Y
Yu Yang 已提交
575
    DataType data_type_;
Q
Qiao Longfei 已提交
576

Y
Yu Yang 已提交
577 578 579 580 581
    OpKernelKey(DataType data_type, platform::Place place)
        : place_(place), data_type_(data_type) {}

    OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
        : place_(dev_ctx.GetPlace()), data_type_(data_type) {}
Y
Yu Yang 已提交
582

Q
qijun 已提交
583
    bool operator==(const OpKernelKey& o) const {
Y
Yu Yang 已提交
584 585
      return platform::places_are_same_class(place_, o.place_) &&
             data_type_ == o.data_type_;
Q
qijun 已提交
586
    }
Y
Yu Yang 已提交
587 588 589
  };

  struct OpKernelHash {
Y
Yu Yang 已提交
590
    std::hash<int> hash_;
Y
Yu Yang 已提交
591
    size_t operator()(const OpKernelKey& key) const {
Y
Yu Yang 已提交
592 593
      int place = key.place_.which();
      int data_type = static_cast<int>(key.data_type_);
Y
Yu Yang 已提交
594 595
      int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
                     (place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
Y
Yu Yang 已提交
596
      return hash_(pre_hash);
Y
Yu Yang 已提交
597 598 599 600
    }
  };

  using OpKernelMap =
Y
Yu Yang 已提交
601 602
      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
                         OpKernelHash>;
Q
Qiao Longfei 已提交
603

Y
Yu Yang 已提交
604 605
  OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
                     const VariableNameMap& outputs, const AttributeMap& attrs)
Y
Yu Yang 已提交
606 607
      : OperatorBase(type, inputs, outputs, attrs) {}

Y
Yu Yang 已提交
608
  void Run(const Scope& scope,
Y
Yu Yang 已提交
609
           const platform::DeviceContext& dev_ctx) const final {
Y
Yu Yang 已提交
610
    VLOG(3) << "Running operator " << this->Type();
Y
Yu Yang 已提交
611 612 613
    RuntimeInferShapeContext infer_shape_ctx(*this, scope);
    this->InferShape(&infer_shape_ctx);

Y
Yu Yang 已提交
614
    ExecutionContext ctx(*this, scope, dev_ctx);
615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633

    // check if op[type] has kernel registered.
    auto& all_op_kernels = AllOpKernels();
    auto kernels_iter = all_op_kernels.find(type_);
    if (kernels_iter == all_op_kernels.end()) {
      PADDLE_THROW("op[%s] has no kernel", type_);
    }

    // check if op[type] have kernel for kernel_key
    OpKernelMap& kernels = kernels_iter->second;
    auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
    auto kernel_iter = kernels.find(kernel_key);

    if (kernel_iter == kernels.end()) {
      PADDLE_THROW("op[%s] has no kernel with kernel_key[%s]", type_,
                   kernel_key);
    }

    kernel_iter->second->Compute(ctx);
Q
Qiao Longfei 已提交
634 635
  }

Y
Yu Yang 已提交
636 637 638 639
  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
Y
Yu Yang 已提交
640
  }
Y
Yan Chunwei 已提交
641

642
  bool SupportGPU() const override {
Y
Yu Yang 已提交
643 644 645 646 647
    auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
    return std::any_of(op_kernels.begin(), op_kernels.end(),
                       [](OpKernelMap::const_reference kern_pair) {
                         return platform::is_gpu_place(kern_pair.first.place_);
                       });
648 649
  }

650
  virtual void InferShape(InferShapeContext* ctx) const = 0;
Y
Yu Yang 已提交
651

Q
qiaolongfei 已提交
652
 protected:
Y
Yu Yang 已提交
653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679
  // Determine the kernel DataType from the input data. By default, every
  // input tensor found in the scope must have the same data type.
  virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
    auto& scope = ctx.scope();
    int data_type = -1;  // -1 means "no input tensor seen yet"
    for (auto& slot : this->inputs_) {
      for (auto& var_name : slot.second) {
        auto* var = scope.FindVar(var_name);
        if (var == nullptr) {
          continue;  // names missing from the scope are ignored
        }

        // Only Tensor and LoDTensor variables contribute a data type.
        const Tensor* tensor = nullptr;
        if (var->IsType<Tensor>()) {
          tensor = &var->Get<Tensor>();
        } else if (var->IsType<LoDTensor>()) {
          tensor = &var->Get<LoDTensor>();
        }
        if (tensor == nullptr) {
          continue;
        }

        int tmp = static_cast<int>(ToDataType(tensor->type()));
        PADDLE_ENFORCE(tmp == data_type || data_type == -1,
                       "DataType of Paddle Op must be same.");
        data_type = tmp;
      }
    }
    PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
    return static_cast<DataType>(data_type);
  }
Q
Qiao Longfei 已提交
680 681
};

682 683 684
std::ostream& operator<<(std::ostream& os,
                         const OperatorWithKernel::OpKernelKey& kernel_key);

Y
Yu Yang 已提交
685 686
extern bool OpSupportGPU(const std::string& op_type);

Q
Qiao Longfei 已提交
687 688
}  // namespace framework
}  // namespace paddle