operator.h 21.5 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

D
dongzhihong 已提交
17
#include <algorithm>
18
#include <atomic>
Q
Qiao Longfei 已提交
19 20 21 22
#include <string>
#include <unordered_map>
#include <vector>

Y
Yu Yang 已提交
23
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
24
#include "paddle/framework/attribute.h"
Q
qiaolongfei 已提交
25
#include "paddle/framework/block_desc.h"
Y
Yu Yang 已提交
26
#include "paddle/framework/data_type.h"
Y
Yu Yang 已提交
27
#include "paddle/framework/framework.pb.h"
28
#include "paddle/framework/lod_tensor.h"
Y
Yu Yang 已提交
29
#include "paddle/framework/op_info.h"
Q
qijun 已提交
30
#include "paddle/framework/scope.h"
Q
Qiao Longfei 已提交
31
#include "paddle/framework/shape_inference.h"
Q
qijun 已提交
32 33 34
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
Y
Yu Yang 已提交
35
#include "paddle/platform/variant.h"
Q
qijun 已提交
36
#include "paddle/utils/Error.h"
Q
Qiao Longfei 已提交
37 38 39 40

namespace paddle {
namespace framework {

41
/// If a variable is an empty variable, that name will be used.
42
constexpr char kEmptyVarName[] = "@EMPTY@";
43 44 45

/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be converted to a unique name in scope after OpCreator.
46
constexpr char kTempVarName[] = "@TEMP@";
47 48 49 50

/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another variable.
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
51
constexpr char kGradVarSuffix[] = "@GRAD";
52 53

/// Variables with this suffix are supposed to be filled up with zeros.
54
constexpr char kZeroVarSuffix[] = "@ZERO";
55 56 57 58 59

/// Builds the name of the gradient variable that belongs to `var_name` by
/// appending kGradVarSuffix to it.
inline std::string GradVarName(const std::string& var_name) {
  std::string grad_name(var_name);
  grad_name.append(kGradVarSuffix);
  return grad_name;
}

Q
Qiao Longfei 已提交
60
class OperatorBase;
61
class ExecutionContext;
62

Q
Qiao Longfei 已提交
63 64 65
extern const Tensor* GetTensorFromVar(const Variable* var);
extern Tensor* GetTensorFromVar(Variable* var);

Q
Qiao Longfei 已提交
66 67 68 69 70 71 72 73
/**
 * OperatorBase has the basic element that Net will call to do computation.
 * Only CreateOperator from OpRegistry will new Operator directly. User
 * should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
Y
Yu Yang 已提交
74 75
  OperatorBase(const std::string& type, const VariableNameMap& inputs,
               const VariableNameMap& outputs, const AttributeMap& attrs);
76

Q
Qiao Longfei 已提交
77 78 79
  virtual ~OperatorBase() {}

  template <typename T>
Y
Yu Yang 已提交
80
  inline const T& Attr(const std::string& name) const {
Q
Qiao Longfei 已提交
81 82 83 84 85
    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
                   name);
    return boost::get<T>(attrs_.at(name));
  }

86
  virtual std::string DebugString() const;
Q
Qiao Longfei 已提交
87 88

  /// Net will call this function to Run an op.
Y
Yu Yang 已提交
89
  virtual void Run(const Scope& scope,
Y
Yu Yang 已提交
90 91
                   const platform::DeviceContext& dev_ctx) const = 0;

Y
Yu Yang 已提交
92 93
  virtual bool IsNetOp() const { return false; }

94 95
  virtual bool SupportGPU() const { return false; }

D
dongzhihong 已提交
96 97 98
  /// rename inputs outputs name
  void Rename(const std::string& old_name, const std::string& new_name);

Y
Yu Yang 已提交
99 100
  const VariableNameMap& Inputs() const { return inputs_; }
  const VariableNameMap& Outputs() const { return outputs_; }
101

Y
Yu Yang 已提交
102
  //! Get a input with argument's name described in `op_proto`
103
  std::string Input(const std::string& name) const;
Y
Yu Yang 已提交
104
  //! Get a input which has multiple variables.
Y
Yu Yang 已提交
105
  const std::vector<std::string>& Inputs(const std::string& name) const;
Y
Yi Wang 已提交
106

Q
qijun 已提交
107 108
  std::vector<std::string> InputVars() const;

Y
Yu Yang 已提交
109
  //! Get a output with argument's name described in `op_proto`
110
  std::string Output(const std::string& name) const;
Y
Yu Yang 已提交
111 112
  //! Get an output which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yu Yang 已提交
113
  const std::vector<std::string>& Outputs(const std::string& name) const;
Y
Yan Chunwei 已提交
114

Y
Yu Yang 已提交
115
  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
116

Q
qiaolongfei 已提交
117
  const std::string& Type() const { return type_; }
Q
qiaolongfei 已提交
118
  void SetType(const std::string& type) { type_ = type; }
Y
Yi Wang 已提交
119 120
  const AttributeMap& Attrs() const { return attrs_; }

Y
Yu Yang 已提交
121
  // Return a new operator instance, which is as same as this.
Y
Yu Yang 已提交
122 123
  // Use unique_ptr to prevent caller forget to delete this pointer.
  virtual std::unique_ptr<OperatorBase> Clone() const = 0;
Y
Yu Yang 已提交
124

Q
qiaolongfei 已提交
125
 protected:
Q
Qiao Longfei 已提交
126
  std::string type_;
D
dongzhihong 已提交
127
  // NOTE: in case of OpGrad, inputs_ contains:
Y
Yu Yang 已提交
128
  // I (Inputs)
D
dongzhihong 已提交
129 130
  // O (Outputs)
  // OG (Output Gradients)
Y
Yu Yang 已提交
131
  VariableNameMap inputs_;
Y
Yu Yang 已提交
132

D
dongzhihong 已提交
133 134
  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
Y
Yu Yang 已提交
135
  VariableNameMap outputs_;
Q
Qiao Longfei 已提交
136
  AttributeMap attrs_;
137 138 139 140

 private:
  void GenerateTemporaryNames();
  void CheckAllInputOutputSet() const;
Y
Yan Chunwei 已提交
141 142
};

Y
Yu Yang 已提交
143 144
// Macro for defining the `Clone` method of an operator class `cls`.
// When you write a kernel operator, `Clone` is defined for you when you
// register it, i.e. you do not need to define `Clone` yourself.
#define DEFINE_OP_CLONE_METHOD(cls)                                        \
  std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final { \
    using OpBasePtr = std::unique_ptr<::paddle::framework::OperatorBase>;  \
    return OpBasePtr(new cls(*this));                                      \
  }
Y
Yu Yang 已提交
150

Y
Yu Yang 已提交
151 152 153 154
// Macro for defining the standard (type, inputs, outputs, attrs) constructor
// of an Operator class `cls`; it only forwards its arguments to `parent_cls`.
// You can also use
//   using PARENT_CLASS::PARENT_CLASS;
// to inherit the parent's constructors instead.
// All framework types are spelled with a leading `::` so the macro resolves
// the same way regardless of the namespace it is expanded in.
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls)             \
  cls(const std::string& type,                             \
      const ::paddle::framework::VariableNameMap& inputs,  \
      const ::paddle::framework::VariableNameMap& outputs, \
      const ::paddle::framework::AttributeMap& attrs)      \
      : parent_cls(type, inputs, outputs, attrs) {}
Y
Yu Yang 已提交
161

162 163
class NOP : public OperatorBase {
 public:
164
  using OperatorBase::OperatorBase;
165 166
  void Run(const Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {}
167 168 169
  std::unique_ptr<OperatorBase> Clone() const override {
    return std::unique_ptr<OperatorBase>(new NOP(*this));
  }
170 171
};

172
class ExecutionContext {
Y
Yan Chunwei 已提交
173
 public:
174 175 176
  ExecutionContext(const OperatorBase& op, const Scope& scope,
                   const platform::DeviceContext& device_context)
      : op_(op), scope_(scope), device_context_(device_context) {}
177

Q
qiaolongfei 已提交
178 179 180 181
  const OperatorBase& op() const { return op_; }

  const Scope& scope() const { return scope_; }

Q
qiaolongfei 已提交
182
  template <typename T>
Y
Yu Yang 已提交
183 184
  inline const T& Attr(const std::string& name) const {
    return op_.Attr<T>(name);
Q
qiaolongfei 已提交
185 186
  }

Y
Yu Yang 已提交
187
  size_t InputSize(const std::string& name) const {
Y
Yu Yang 已提交
188
    return op_.Inputs(name).size();
Y
Yan Chunwei 已提交
189 190
  }

Y
Yu Yang 已提交
191
  size_t OutputSize(const std::string& name) const {
Y
Yu Yang 已提交
192
    return op_.Outputs(name).size();
Y
Yan Chunwei 已提交
193 194
  }

195
  /// Looks up the variable bound to the single-valued input slot `name`.
  /// Returns nullptr when the slot holds kEmptyVarName; otherwise returns
  /// whatever scope_.FindVar yields for the variable name.
  const Variable* InputVar(const std::string& name) const {
    const std::string var_name = op_.Input(name);
    if (var_name == kEmptyVarName) {
      return nullptr;
    }
    return scope_.FindVar(var_name);
  }

200
  /// Looks up the variable bound to the single-valued output slot `name`.
  /// Returns nullptr when the slot holds kEmptyVarName; otherwise returns
  /// whatever scope_.FindVar yields for the variable name.
  Variable* OutputVar(const std::string& name) const {
    const std::string var_name = op_.Output(name);
    if (var_name == kEmptyVarName) {
      return nullptr;
    }
    return scope_.FindVar(var_name);
  }

205 206
  const std::vector<const Variable*> MultiInputVar(
      const std::string& name) const {
Y
Yan Chunwei 已提交
207 208
    auto names = op_.Inputs(name);
    std::vector<const Variable*> res;
209
    res.reserve(names.size());
210 211
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
212 213
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
214
                   });
Y
Yan Chunwei 已提交
215 216 217
    return res;
  }

218
  std::vector<Variable*> MultiOutputVar(const std::string& name) const {
Y
Yan Chunwei 已提交
219
    auto names = op_.Outputs(name);
220
    std::vector<Variable*> res;
221
    res.reserve(names.size());
222 223
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
224 225
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
226
                   });
Y
Yan Chunwei 已提交
227 228 229
    return res;
  }

230 231
  template <typename T>
  const T* Input(const std::string& name) const {
Y
Yu Yang 已提交
232
    auto* var = InputVar(name);
233
    return var == nullptr ? nullptr : &var->Get<T>();
234 235 236 237
  }

  template <typename T>
  T* Output(const std::string& name) const {
238
    auto var = OutputVar(name);
239
    return var == nullptr ? nullptr : var->GetMutable<T>();
240 241 242 243 244 245 246 247
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
248
                   [&](const std::string& sub_name) {
249
                     auto var = scope_.FindVar(sub_name);
250
                     return var == nullptr ? nullptr : &var->Get<T>();
251 252 253 254 255
                   });
    return res;
  }

  template <typename T>
256
  std::vector<T*> MultiOutput(const std::string& name) const {
257
    auto names = op_.Outputs(name);
258
    std::vector<T*> res;
259 260
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
261
                   [&](const std::string& sub_name) {
262
                     auto var = scope_.FindVar(sub_name);
263
                     return var == nullptr ? nullptr : var->GetMutable<T>();
264 265 266 267
                   });
    return res;
  }

268 269 270 271 272 273
  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const {
    PADDLE_ENFORCE_LT(i, InputSize(in));
    PADDLE_ENFORCE_LT(j, OutputSize(out));
    auto* in_var = MultiInputVar(in)[i];
    auto* out_var = MultiOutputVar(out)[j];
274
    if (!in_var->IsType<LoDTensor>()) return;
275
    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
276
                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
277 278 279
    auto in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_lod(in_tensor.lod());
280 281
  }

Q
qijun 已提交
282
  template <typename PlaceType,
283 284
            typename DeviceType = typename platform::EigenDeviceConverter<
                PlaceType>::EigenDeviceType>
285
  DeviceType& GetEigenDevice() const;
Q
qijun 已提交
286

287
  platform::Place GetPlace() const { return device_context_.GetPlace(); }
Q
qijun 已提交
288

289
  const platform::DeviceContext& device_context() const {
Q
qijun 已提交
290
    return device_context_;
Q
qijun 已提交
291
  }
Q
qijun 已提交
292

武毅 已提交
293 294 295 296 297 298 299 300 301
#ifdef PADDLE_WITH_CUDA
  /// Returns the device context as a CUDADeviceContext.
  /// Enforces that the context's place is a GPU place before casting.
  const platform::CUDADeviceContext& cuda_device_context() const {
    PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
    // A base-to-derived conversion belongs to static_cast, which adjusts the
    // address when necessary; reinterpret_cast only reinterprets the bits.
    return static_cast<const platform::CUDADeviceContext&>(device_context_);
  }
#endif

302
 private:
303 304
  const OperatorBase& op_;
  const Scope& scope_;
305
  const platform::DeviceContext& device_context_;
Q
Qiao Longfei 已提交
306 307
};

308 309 310 311 312 313 314 315 316 317 318 319 320 321
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;

template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
    const std::string& name) const;

template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;

template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
    const std::string& name) const;

322
class CompileTimeInferShapeContext : public InferShapeContext {
Q
tmp  
qiaolongfei 已提交
323
 public:
Q
qiaolongfei 已提交
324 325
  CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
      : op_(op), block_(block) {}
Q
tmp  
qiaolongfei 已提交
326

Q
qiaolongfei 已提交
327
  bool HasInput(const std::string& name) const override {
Q
qiaolongfei 已提交
328
    const std::vector<std::string>& input_names = op_.Input(name);
329
    auto length = input_names.size();
330 331 332
    if (length == 0) {
      return false;
    }
333 334 335 336
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input(%s) should have only one value, "
                      "but it have %d now",
                      name, length);
337
    return block_.HasVarRecursive(input_names[0]);
Q
tmp  
qiaolongfei 已提交
338 339
  }

Q
qiaolongfei 已提交
340
  bool HasOutput(const std::string& name) const override {
Q
qiaolongfei 已提交
341
    const std::vector<std::string>& output_names = op_.Output(name);
342
    auto length = output_names.size();
343 344 345
    if (length == 0) {
      return false;
    }
346 347 348 349
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output(%s) should have only one value, "
                      "but it have %d now",
                      name, length);
350
    return block_.HasVarRecursive(output_names[0]);
Q
tmp  
qiaolongfei 已提交
351 352
  }

Q
qiaolongfei 已提交
353
  bool HasInputs(const std::string& name) const override {
Q
qiaolongfei 已提交
354
    const std::vector<std::string>& input_names = op_.Input(name);
355 356 357
    if (input_names.empty()) {
      return false;
    }
Q
qiaolongfei 已提交
358
    for (auto& input : input_names) {
359
      if (!block_.HasVarRecursive(input)) return false;
Q
tmp  
qiaolongfei 已提交
360 361 362 363
    }
    return true;
  }

Q
qiaolongfei 已提交
364
  bool HasOutputs(const std::string& name) const override {
Q
qiaolongfei 已提交
365
    const std::vector<std::string>& output_names = op_.Output(name);
366 367 368
    if (output_names.empty()) {
      return false;
    }
Q
qiaolongfei 已提交
369
    for (auto& output : output_names) {
370
      if (!block_.HasVarRecursive(output)) return false;
Q
tmp  
qiaolongfei 已提交
371 372 373 374
    }
    return true;
  }

Q
qiaolongfei 已提交
375
  DDim GetInputDim(const std::string& name) const override {
Q
qiaolongfei 已提交
376
    std::vector<DDim> ddims = GetInputsDim(name);
377 378 379 380 381
    auto length = ddims.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input(%s) should have 1 value, "
                      "but it has %d now",
                      name, length);
Q
qiaolongfei 已提交
382
    return ddims[0];
Q
tmp  
qiaolongfei 已提交
383 384
  }

Q
qiaolongfei 已提交
385
  void SetInputDim(const std::string& name, const DDim& dim) override {
Q
qiaolongfei 已提交
386
    SetInputsDim(name, {dim});
Q
tmp  
qiaolongfei 已提交
387 388
  }

Q
qiaolongfei 已提交
389
  DDim GetOutputDim(const std::string& name) const override {
Q
qiaolongfei 已提交
390
    std::vector<DDim> ddims = GetOutputsDim(name);
391 392 393 394 395
    auto length = ddims.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output(%s) should have 1 value, "
                      "but it has %d now",
                      name, length);
Q
qiaolongfei 已提交
396
    return ddims[0];
Q
tmp  
qiaolongfei 已提交
397 398
  }

Q
qiaolongfei 已提交
399
  void SetOutputDim(const std::string& name, const DDim& dim) override {
Q
qiaolongfei 已提交
400
    SetOutputsDim(name, {dim});
Q
tmp  
qiaolongfei 已提交
401 402
  }

Q
qiaolongfei 已提交
403
  AttrReader Attrs() const override { return AttrReader(op_.GetAttrMap()); }
Q
tmp  
qiaolongfei 已提交
404

Q
qiaolongfei 已提交
405 406
  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
Q
qiaolongfei 已提交
407
    return op_.Input(name);
Q
tmp  
qiaolongfei 已提交
408 409
  }

Q
qiaolongfei 已提交
410 411
  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
Q
qiaolongfei 已提交
412
    return op_.Output(name);
Q
tmp  
qiaolongfei 已提交
413 414 415
  }

 private:
Q
qiaolongfei 已提交
416
  DDim GetDim(const std::string& name) const override {
417
    return framework::make_ddim(block_.FindVarRecursive(name)->Shape());
Q
tmp  
qiaolongfei 已提交
418 419
  }

Q
qiaolongfei 已提交
420
  void SetDim(const std::string& name, const DDim& dim) override {
421
    block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
Q
tmp  
qiaolongfei 已提交
422 423
  }

Q
qiaolongfei 已提交
424 425
  const OpDescBind& op_;
  const BlockDescBind& block_;
Q
tmp  
qiaolongfei 已提交
426 427
};

428
class RuntimeInferShapeContext : public InferShapeContext {
Q
Qiao Longfei 已提交
429 430 431 432
 public:
  RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
      : op_(op), scope_(scope) {}

Q
qiaolongfei 已提交
433
  bool HasInput(const std::string& name) const override {
434 435 436 437 438 439 440 441
    auto& ins = Inputs(name);
    size_t length = ins.size();
    if (length == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
                      name);
    auto ipt = ins[0];
Q
Qiao Longfei 已提交
442 443 444 445
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

Q
qiaolongfei 已提交
446
  bool HasOutput(const std::string& name) const override {
447 448 449 450 451 452 453 454
    auto& outs = Outputs(name);
    size_t length = outs.size();
    if (length == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
                      name);
    auto ipt = outs[0];
Q
Qiao Longfei 已提交
455 456 457 458
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

Q
qiaolongfei 已提交
459
  bool HasInputs(const std::string& name) const override {
460
    auto inputs = op_.Inputs(name);
Q
qiaolongfei 已提交
461
    if (inputs.empty()) {
462 463 464 465 466 467 468 469 470 471
      return false;
    }
    for (auto& input : inputs) {
      if (scope_.FindVar(input) == nullptr) {
        return false;
      }
    }
    return true;
  }

Q
qiaolongfei 已提交
472
  bool HasOutputs(const std::string& name) const override {
473
    auto outputs = op_.Outputs(name);
Q
qiaolongfei 已提交
474
    if (outputs.empty()) {
475 476 477 478 479 480 481 482 483 484
      return false;
    }
    for (auto& output : outputs) {
      if (scope_.FindVar(output) == nullptr) {
        return false;
      }
    }
    return true;
  }

Q
qiaolongfei 已提交
485
  DDim GetInputDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
486 487 488
    return GetDim(op_.Input(name));
  }

Q
qiaolongfei 已提交
489
  void SetInputDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
490 491 492
    SetDim(op_.Input(name), dim);
  }

Q
qiaolongfei 已提交
493
  DDim GetOutputDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
494 495 496
    return GetDim(op_.Output(name));
  }

Q
qiaolongfei 已提交
497
  void SetOutputDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
498 499 500
    SetDim(op_.Output(name), dim);
  }

Q
qiaolongfei 已提交
501
  AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
Q
Qiao Longfei 已提交
502

Q
qiaolongfei 已提交
503 504
  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
Q
Qiao Longfei 已提交
505 506 507
    return op_.Inputs(name);
  }

Q
qiaolongfei 已提交
508 509
  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
Q
Qiao Longfei 已提交
510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529
    return op_.Outputs(name);
  }

 private:
  /// Returns the tensor held by the scope variable `name`.
  /// Enforces that the variable exists in the scope. If it already holds a
  /// LoDTensor or a Tensor, the tensor is obtained via GetTensorFromVar.
  /// Otherwise, with Allocate == true the variable is turned into a LoDTensor
  /// through GetMutable<LoDTensor>(); with Allocate == false an error is
  /// thrown.
  template <bool Allocate>
  Tensor* GetTensor(const std::string& name) const {
    // Look the variable up once and fail with a message if it is missing,
    // instead of dereferencing a null pointer.
    auto* var = scope_.FindVar(name);
    PADDLE_ENFORCE(var != nullptr, "Variable(%s) should be in scope", name);
    Tensor* t = nullptr;
    if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
      if (Allocate) {
        t = var->GetMutable<LoDTensor>();
      } else {
        PADDLE_THROW("Variable(%s) should be tensor", name);
      }
    } else {
      t = GetTensorFromVar(var);
    }
    return t;
  }

Q
qiaolongfei 已提交
530
  DDim GetDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
531 532 533
    return GetTensor<false>(name)->dims();
  }

Q
qiaolongfei 已提交
534
  void SetDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
535 536 537 538 539 540 541
    GetTensor<true>(name)->Resize(dim);
  }

  const OperatorBase& op_;
  const Scope& scope_;
};

Y
Yu Yang 已提交
542
class OpKernelBase {
Q
qijun 已提交
543
 public:
Q
qijun 已提交
544
  /**
545
   * ExecutionContext is the only parameter of Kernel Run function.
Q
qijun 已提交
546 547
   * Run will get input/output variables, state such as momentum and
   * device resource such as CUDA stream, cublas handle, etc. from
548
   * ExecutionContext. User should construct it before run the Operator.
Q
qijun 已提交
549 550
   */

551
  virtual void Compute(const ExecutionContext& context) const = 0;
Y
Yu Yang 已提交
552

Y
Yu Yang 已提交
553 554 555 556 557 558 559
  virtual ~OpKernelBase() = default;
};

template <typename T>
class OpKernel : public OpKernelBase {
 public:
  using ELEMENT_TYPE = T;
Y
Yu Yang 已提交
560 561
};

Q
Qiao Longfei 已提交
562 563
class OperatorWithKernel : public OperatorBase {
 public:
Y
Yu Yang 已提交
564 565
  struct OpKernelKey {
    platform::Place place_;
Y
Yu Yang 已提交
566
    DataType data_type_;
Q
Qiao Longfei 已提交
567

Y
Yu Yang 已提交
568 569 570 571 572
    OpKernelKey(DataType data_type, platform::Place place)
        : place_(place), data_type_(data_type) {}

    OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
        : place_(dev_ctx.GetPlace()), data_type_(data_type) {}
Y
Yu Yang 已提交
573

Q
qijun 已提交
574
    bool operator==(const OpKernelKey& o) const {
Y
Yu Yang 已提交
575 576
      return platform::places_are_same_class(place_, o.place_) &&
             data_type_ == o.data_type_;
Q
qijun 已提交
577
    }
Y
Yu Yang 已提交
578 579 580
  };

  struct OpKernelHash {
Y
Yu Yang 已提交
581
    std::hash<int> hash_;
Y
Yu Yang 已提交
582
    size_t operator()(const OpKernelKey& key) const {
Y
Yu Yang 已提交
583 584
      int place = key.place_.which();
      int data_type = static_cast<int>(key.data_type_);
Y
Yu Yang 已提交
585 586
      int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
                     (place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
Y
Yu Yang 已提交
587
      return hash_(pre_hash);
Y
Yu Yang 已提交
588 589 590 591
    }
  };

  using OpKernelMap =
Y
Yu Yang 已提交
592 593
      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
                         OpKernelHash>;
Q
Qiao Longfei 已提交
594

Y
Yu Yang 已提交
595 596
  OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
                     const VariableNameMap& outputs, const AttributeMap& attrs)
Y
Yu Yang 已提交
597 598
      : OperatorBase(type, inputs, outputs, attrs) {}

Y
Yu Yang 已提交
599
  void Run(const Scope& scope,
Y
Yu Yang 已提交
600
           const platform::DeviceContext& dev_ctx) const final {
Y
Yu Yang 已提交
601
    VLOG(3) << "Running operator " << this->Type();
Y
Yu Yang 已提交
602 603 604
    RuntimeInferShapeContext infer_shape_ctx(*this, scope);
    this->InferShape(&infer_shape_ctx);

Y
Yu Yang 已提交
605
    ExecutionContext ctx(*this, scope, dev_ctx);
606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624

    // check if op[type] has kernel registered.
    auto& all_op_kernels = AllOpKernels();
    auto kernels_iter = all_op_kernels.find(type_);
    if (kernels_iter == all_op_kernels.end()) {
      PADDLE_THROW("op[%s] has no kernel", type_);
    }

    // check if op[type] have kernel for kernel_key
    OpKernelMap& kernels = kernels_iter->second;
    auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
    auto kernel_iter = kernels.find(kernel_key);

    if (kernel_iter == kernels.end()) {
      PADDLE_THROW("op[%s] has no kernel with kernel_key[%s]", type_,
                   kernel_key);
    }

    kernel_iter->second->Compute(ctx);
Q
Qiao Longfei 已提交
625 626
  }

Y
Yu Yang 已提交
627 628 629 630
  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
Y
Yu Yang 已提交
631
  }
Y
Yan Chunwei 已提交
632

633
  bool SupportGPU() const override {
Y
Yu Yang 已提交
634 635 636 637 638
    auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
    return std::any_of(op_kernels.begin(), op_kernels.end(),
                       [](OpKernelMap::const_reference kern_pair) {
                         return platform::is_gpu_place(kern_pair.first.place_);
                       });
639 640
  }

641
  virtual void InferShape(InferShapeContext* ctx) const = 0;
Y
Yu Yang 已提交
642

Q
qiaolongfei 已提交
643
 protected:
Y
Yu Yang 已提交
644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661
  // Determines the kernel DataType from the operator's input data. By
  // default, every input tensor found in the scope must have the same data
  // type. Variables that are missing from the scope, or that hold neither a
  // Tensor nor a LoDTensor, are skipped. Enforces that at least one input
  // tensor determined the type.
  virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
    auto& scope = ctx.scope();
    int data_type = -1;  // -1 means "no input tensor seen yet"
    for (auto& input : this->inputs_) {
      for (auto& ipt_name : input.second) {
        auto* var = scope.FindVar(ipt_name);
        if (var == nullptr) {
          continue;
        }
        const Tensor* t = nullptr;
        if (var->IsType<Tensor>()) {
          t = &var->Get<Tensor>();
        } else if (var->IsType<LoDTensor>()) {
          t = &var->Get<LoDTensor>();
        }
        if (t == nullptr) {
          continue;
        }
        int tmp = static_cast<int>(ToDataType(t->type()));
        PADDLE_ENFORCE(tmp == data_type || data_type == -1,
                       "DataType of Paddle Op must be the same.");
        data_type = tmp;
      }
    }
    PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
    return static_cast<DataType>(data_type);
  }
Q
Qiao Longfei 已提交
671 672
};

673 674 675
std::ostream& operator<<(std::ostream& os,
                         const OperatorWithKernel::OpKernelKey& kernel_key);

Y
Yu Yang 已提交
676 677
extern bool OpSupportGPU(const std::string& op_type);

Q
Qiao Longfei 已提交
678 679
}  // namespace framework
}  // namespace paddle