/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <atomic>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "glog/logging.h"  // For VLOG
#include "paddle/framework/attribute.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/platform/variant.h"
#include "paddle/utils/Error.h"

namespace paddle {
namespace framework {

/// If a variable is an empty variable, this name will be used.
constexpr char kEmptyVarName[] = "@EMPTY@";

/// If a variable is a temporary variable, this name will be set in Python,
/// but it will be converted to a unique name in scope after OpCreator.
constexpr char kTempVarName[] = "@TEMP@";

/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another variable.
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
constexpr char kGradVarSuffix[] = "@GRAD";

/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";

/// Returns the gradient variable name for `var_name`, i.e. `var_name` with
/// kGradVarSuffix appended.
inline std::string GradVarName(const std::string& var_name) {
  return var_name + kGradVarSuffix;
}

class OperatorBase;
class ExecutionContext;

// Accessors that obtain the Tensor stored in a Variable. Only declared here;
// the definitions are not part of this header.
extern const Tensor* GetTensorFromVar(const Variable* var);
extern Tensor* GetTensorFromVar(Variable* var);

Q
Qiao Longfei 已提交
66 67 68 69 70 71 72 73
/**
 * OperatorBase has the basic element that Net will call to do computation.
 * Only CreateOperator from OpRegistry will new Operator directly. User
 * should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
Y
Yu Yang 已提交
74 75
  OperatorBase(const std::string& type, const VariableNameMap& inputs,
               const VariableNameMap& outputs, const AttributeMap& attrs);
76

Q
Qiao Longfei 已提交
77 78 79
  virtual ~OperatorBase() {}

  template <typename T>
Y
Yu Yang 已提交
80
  inline const T& Attr(const std::string& name) const {
Q
Qiao Longfei 已提交
81 82 83 84 85
    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
                   name);
    return boost::get<T>(attrs_.at(name));
  }

86
  virtual std::string DebugString() const;
Q
Qiao Longfei 已提交
87 88

  /// Net will call this function to Run an op.
Y
Yu Yang 已提交
89
  virtual void Run(const Scope& scope,
Y
Yu Yang 已提交
90 91
                   const platform::DeviceContext& dev_ctx) const = 0;

Y
Yu Yang 已提交
92 93
  virtual bool IsNetOp() const { return false; }

94 95
  virtual bool SupportGPU() const { return false; }

D
dongzhihong 已提交
96 97 98
  /// rename inputs outputs name
  void Rename(const std::string& old_name, const std::string& new_name);

Y
Yu Yang 已提交
99 100
  const VariableNameMap& Inputs() const { return inputs_; }
  const VariableNameMap& Outputs() const { return outputs_; }
101

Y
Yu Yang 已提交
102
  //! Get a input with argument's name described in `op_proto`
103
  std::string Input(const std::string& name) const;
Y
Yu Yang 已提交
104
  //! Get a input which has multiple variables.
Y
Yu Yang 已提交
105
  const std::vector<std::string>& Inputs(const std::string& name) const;
Y
Yi Wang 已提交
106

Q
qijun 已提交
107 108
  std::vector<std::string> InputVars() const;

Y
Yu Yang 已提交
109
  //! Get a output with argument's name described in `op_proto`
110
  std::string Output(const std::string& name) const;
Y
Yu Yang 已提交
111 112
  //! Get an output which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yu Yang 已提交
113
  const std::vector<std::string>& Outputs(const std::string& name) const;
Y
Yan Chunwei 已提交
114

Y
Yu Yang 已提交
115
  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
116

Q
qiaolongfei 已提交
117
  const std::string& Type() const { return type_; }
Q
qiaolongfei 已提交
118
  void SetType(const std::string& type) { type_ = type; }
Y
Yi Wang 已提交
119 120
  const AttributeMap& Attrs() const { return attrs_; }

Y
Yu Yang 已提交
121
  // Return a new operator instance, which is as same as this.
Y
Yu Yang 已提交
122 123
  // Use unique_ptr to prevent caller forget to delete this pointer.
  virtual std::unique_ptr<OperatorBase> Clone() const = 0;
Y
Yu Yang 已提交
124

Q
qiaolongfei 已提交
125
 protected:
Q
Qiao Longfei 已提交
126
  std::string type_;
D
dongzhihong 已提交
127
  // NOTE: in case of OpGrad, inputs_ contains:
Y
Yu Yang 已提交
128
  // I (Inputs)opear
D
dongzhihong 已提交
129 130
  // O (Outputs)
  // OG (Output Gradients)
Y
Yu Yang 已提交
131
  VariableNameMap inputs_;
Y
Yu Yang 已提交
132

D
dongzhihong 已提交
133 134
  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
Y
Yu Yang 已提交
135
  VariableNameMap outputs_;
Q
Qiao Longfei 已提交
136
  AttributeMap attrs_;
137 138 139 140

 private:
  void GenerateTemporaryNames();
  void CheckAllInputOutputSet() const;
Y
Yan Chunwei 已提交
141 142
};

// Macro for defining a Clone method.
// If you are writing a kernel operator, `Clone` will be defined when you
// register it, i.e. you do not need to define the `Clone` method yourself.
// Expands to a `final` override that copy-constructs `cls` from `*this`.
#define DEFINE_OP_CLONE_METHOD(cls)                                            \
  std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final {     \
    return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \
  }

// Macro for defining a default constructor for an Operator. The generated
// constructor forwards all four arguments to `parent_cls`.
// You can also use
//   using PARENT_CLASS::PARENT_CLASS;
// to use the parent's constructor.
// All framework types are fully qualified from the global namespace so the
// macro expands correctly regardless of the enclosing namespace.
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls)             \
  cls(const std::string& type,                             \
      const ::paddle::framework::VariableNameMap& inputs,  \
      const ::paddle::framework::VariableNameMap& outputs, \
      const ::paddle::framework::AttributeMap& attrs)      \
      : parent_cls(type, inputs, outputs, attrs) {}

162 163
class NOP : public OperatorBase {
 public:
164
  using OperatorBase::OperatorBase;
165 166
  void Run(const Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {}
167 168 169
  std::unique_ptr<OperatorBase> Clone() const override {
    return std::unique_ptr<OperatorBase>(new NOP(*this));
  }
170 171
};

172
class ExecutionContext {
Y
Yan Chunwei 已提交
173
 public:
174 175 176
  ExecutionContext(const OperatorBase& op, const Scope& scope,
                   const platform::DeviceContext& device_context)
      : op_(op), scope_(scope), device_context_(device_context) {}
177

Q
qiaolongfei 已提交
178 179 180 181
  const OperatorBase& op() const { return op_; }

  const Scope& scope() const { return scope_; }

Q
qiaolongfei 已提交
182
  template <typename T>
Y
Yu Yang 已提交
183 184
  inline const T& Attr(const std::string& name) const {
    return op_.Attr<T>(name);
Q
qiaolongfei 已提交
185 186
  }

Y
Yu Yang 已提交
187
  size_t InputSize(const std::string& name) const {
Y
Yu Yang 已提交
188
    return op_.Inputs(name).size();
Y
Yan Chunwei 已提交
189 190
  }

Y
Yu Yang 已提交
191
  size_t OutputSize(const std::string& name) const {
Y
Yu Yang 已提交
192
    return op_.Outputs(name).size();
Y
Yan Chunwei 已提交
193 194
  }

195
  const Variable* InputVar(const std::string& name) const {
196
    auto ipt = op_.Input(name);
Y
Yu Yang 已提交
197
    return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
Y
Yan Chunwei 已提交
198 199
  }

200
  Variable* OutputVar(const std::string& name) const {
201
    auto opt = op_.Output(name);
Y
Yu Yang 已提交
202
    return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
Y
Yan Chunwei 已提交
203 204
  }

205 206
  const std::vector<const Variable*> MultiInputVar(
      const std::string& name) const {
Y
Yan Chunwei 已提交
207 208
    auto names = op_.Inputs(name);
    std::vector<const Variable*> res;
209
    res.reserve(names.size());
210 211
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
212 213
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
214
                   });
Y
Yan Chunwei 已提交
215 216 217
    return res;
  }

218
  std::vector<Variable*> MultiOutputVar(const std::string& name) const {
Y
Yan Chunwei 已提交
219
    auto names = op_.Outputs(name);
220
    std::vector<Variable*> res;
221
    res.reserve(names.size());
222 223
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
224 225
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
226
                   });
Y
Yan Chunwei 已提交
227 228 229
    return res;
  }

230 231
  template <typename T>
  const T* Input(const std::string& name) const {
Y
Yu Yang 已提交
232
    auto* var = InputVar(name);
233
    return var == nullptr ? nullptr : &var->Get<T>();
234 235 236 237
  }

  template <typename T>
  T* Output(const std::string& name) const {
238
    auto var = OutputVar(name);
239
    return var == nullptr ? nullptr : var->GetMutable<T>();
240 241 242 243 244 245 246 247
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
248
                   [&](const std::string& sub_name) {
249
                     auto var = scope_.FindVar(sub_name);
250
                     return var == nullptr ? nullptr : &var->Get<T>();
251 252 253 254 255
                   });
    return res;
  }

  template <typename T>
256
  std::vector<T*> MultiOutput(const std::string& name) const {
257
    auto names = op_.Outputs(name);
258
    std::vector<T*> res;
259 260
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
261
                   [&](const std::string& sub_name) {
262
                     auto var = scope_.FindVar(sub_name);
263
                     return var == nullptr ? nullptr : var->GetMutable<T>();
264 265 266 267
                   });
    return res;
  }

268 269 270 271 272 273
  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const {
    PADDLE_ENFORCE_LT(i, InputSize(in));
    PADDLE_ENFORCE_LT(j, OutputSize(out));
    auto* in_var = MultiInputVar(in)[i];
    auto* out_var = MultiOutputVar(out)[j];
274
    if (!in_var->IsType<LoDTensor>()) return;
275
    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
276
                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
277 278 279
    auto in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_lod(in_tensor.lod());
280 281
  }

Q
qijun 已提交
282
  template <typename PlaceType,
283 284
            typename DeviceType = typename platform::EigenDeviceConverter<
                PlaceType>::EigenDeviceType>
285
  DeviceType& GetEigenDevice() const;
Q
qijun 已提交
286

287
  platform::Place GetPlace() const { return device_context_.GetPlace(); }
Q
qijun 已提交
288

289
  const platform::DeviceContext& device_context() const {
Q
qijun 已提交
290
    return device_context_;
Q
qijun 已提交
291
  }
Q
qijun 已提交
292

武毅 已提交
293 294 295 296 297 298 299 300 301
#ifdef PADDLE_WITH_CUDA
  const platform::CUDADeviceContext& cuda_device_context() const {
    PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
    auto cuda_ctx =
        reinterpret_cast<const platform::CUDADeviceContext*>(&device_context_);
    return *cuda_ctx;
  }
#endif

302
 private:
303 304
  const OperatorBase& op_;
  const Scope& scope_;
305
  const platform::DeviceContext& device_context_;
Q
Qiao Longfei 已提交
306 307
};

// Explicit specializations of the ExecutionContext accessors for Tensor.
// They are only declared here; the definitions are not part of this header.
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;

template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
    const std::string& name) const;

template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;

template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
    const std::string& name) const;

322
class CompileTimeInferShapeContext : public InferShapeContext {
Q
tmp  
qiaolongfei 已提交
323
 public:
Q
qiaolongfei 已提交
324 325
  CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
      : op_(op), block_(block) {}
Q
tmp  
qiaolongfei 已提交
326

Q
qiaolongfei 已提交
327
  bool HasInput(const std::string& name) const override {
Q
qiaolongfei 已提交
328
    const std::vector<std::string>& input_names = op_.Input(name);
329 330 331 332 333
    auto length = input_names.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input(%s) should have only one value, "
                      "but it have %d now",
                      name, length);
Q
qiaolongfei 已提交
334
    return block_.HasVar(input_names[0]);
Q
tmp  
qiaolongfei 已提交
335 336
  }

Q
qiaolongfei 已提交
337
  bool HasOutput(const std::string& name) const override {
Q
qiaolongfei 已提交
338
    const std::vector<std::string>& output_names = op_.Output(name);
339 340 341 342 343
    auto length = output_names.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output(%s) should have only one value, "
                      "but it have %d now",
                      name, length);
Q
qiaolongfei 已提交
344
    return block_.HasVar(output_names[0]);
Q
tmp  
qiaolongfei 已提交
345 346
  }

Q
qiaolongfei 已提交
347
  bool HasInputs(const std::string& name) const override {
Q
qiaolongfei 已提交
348
    const std::vector<std::string>& input_names = op_.Input(name);
Q
qiaolongfei 已提交
349
    PADDLE_ENFORCE(!input_names.empty(), "Inputs(%s) length is 0", name);
Q
qiaolongfei 已提交
350 351
    for (auto& input : input_names) {
      if (!block_.HasVar(input)) return false;
Q
tmp  
qiaolongfei 已提交
352 353 354 355
    }
    return true;
  }

Q
qiaolongfei 已提交
356
  bool HasOutputs(const std::string& name) const override {
Q
qiaolongfei 已提交
357
    const std::vector<std::string>& output_names = op_.Output(name);
Q
qiaolongfei 已提交
358
    PADDLE_ENFORCE(!output_names.empty(), "Inputs(%s) length is 0", name);
Q
qiaolongfei 已提交
359
    for (auto& output : output_names) {
Q
qiaolongfei 已提交
360
      if (!block_.HasVar(output)) return false;
Q
tmp  
qiaolongfei 已提交
361 362 363 364
    }
    return true;
  }

Q
qiaolongfei 已提交
365
  DDim GetInputDim(const std::string& name) const override {
Q
qiaolongfei 已提交
366
    std::vector<DDim> ddims = GetInputsDim(name);
367 368 369 370 371
    auto length = ddims.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input(%s) should have 1 value, "
                      "but it has %d now",
                      name, length);
Q
qiaolongfei 已提交
372
    return ddims[0];
Q
tmp  
qiaolongfei 已提交
373 374
  }

Q
qiaolongfei 已提交
375
  void SetInputDim(const std::string& name, const DDim& dim) override {
Q
qiaolongfei 已提交
376
    SetInputsDim(name, {dim});
Q
tmp  
qiaolongfei 已提交
377 378
  }

Q
qiaolongfei 已提交
379
  DDim GetOutputDim(const std::string& name) const override {
Q
qiaolongfei 已提交
380
    std::vector<DDim> ddims = GetOutputsDim(name);
381 382 383 384 385
    auto length = ddims.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output(%s) should have 1 value, "
                      "but it has %d now",
                      name, length);
Q
qiaolongfei 已提交
386
    return ddims[0];
Q
tmp  
qiaolongfei 已提交
387 388
  }

Q
qiaolongfei 已提交
389
  void SetOutputDim(const std::string& name, const DDim& dim) override {
Q
qiaolongfei 已提交
390
    SetOutputsDim(name, {dim});
Q
tmp  
qiaolongfei 已提交
391 392
  }

Q
qiaolongfei 已提交
393
  AttrReader Attrs() const override { return AttrReader(op_.GetAttrMap()); }
Q
tmp  
qiaolongfei 已提交
394

Q
qiaolongfei 已提交
395 396
  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
Q
qiaolongfei 已提交
397
    return op_.Input(name);
Q
tmp  
qiaolongfei 已提交
398 399
  }

Q
qiaolongfei 已提交
400 401
  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
Q
qiaolongfei 已提交
402
    return op_.Output(name);
Q
tmp  
qiaolongfei 已提交
403 404 405
  }

 private:
Q
qiaolongfei 已提交
406
  DDim GetDim(const std::string& name) const override {
D
Dong Zhihong 已提交
407
    return framework::make_ddim(block_.FindVar(name)->Shape());
Q
tmp  
qiaolongfei 已提交
408 409
  }

Q
qiaolongfei 已提交
410
  void SetDim(const std::string& name, const DDim& dim) override {
D
Dong Zhihong 已提交
411
    block_.FindVar(name)->SetShape(framework::vectorize(dim));
Q
tmp  
qiaolongfei 已提交
412 413
  }

Q
qiaolongfei 已提交
414 415
  const OpDescBind& op_;
  const BlockDescBind& block_;
Q
tmp  
qiaolongfei 已提交
416 417
};

418
class RuntimeInferShapeContext : public InferShapeContext {
Q
Qiao Longfei 已提交
419 420 421 422
 public:
  RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
      : op_(op), scope_(scope) {}

Q
qiaolongfei 已提交
423
  bool HasInput(const std::string& name) const override {
Q
Qiao Longfei 已提交
424 425 426 427 428
    auto ipt = op_.Input(name);
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

Q
qiaolongfei 已提交
429
  bool HasOutput(const std::string& name) const override {
Q
Qiao Longfei 已提交
430 431 432 433 434
    auto ipt = op_.Output(name);
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

Q
qiaolongfei 已提交
435
  bool HasInputs(const std::string& name) const override {
436
    auto inputs = op_.Inputs(name);
Q
qiaolongfei 已提交
437
    if (inputs.empty()) {
438 439 440 441 442 443 444 445 446 447
      return false;
    }
    for (auto& input : inputs) {
      if (scope_.FindVar(input) == nullptr) {
        return false;
      }
    }
    return true;
  }

Q
qiaolongfei 已提交
448
  bool HasOutputs(const std::string& name) const override {
449
    auto outputs = op_.Outputs(name);
Q
qiaolongfei 已提交
450
    if (outputs.empty()) {
451 452 453 454 455 456 457 458 459 460
      return false;
    }
    for (auto& output : outputs) {
      if (scope_.FindVar(output) == nullptr) {
        return false;
      }
    }
    return true;
  }

Q
qiaolongfei 已提交
461
  DDim GetInputDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
462 463 464
    return GetDim(op_.Input(name));
  }

Q
qiaolongfei 已提交
465
  void SetInputDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
466 467 468
    SetDim(op_.Input(name), dim);
  }

Q
qiaolongfei 已提交
469
  DDim GetOutputDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
470 471 472
    return GetDim(op_.Output(name));
  }

Q
qiaolongfei 已提交
473
  void SetOutputDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
474 475 476
    SetDim(op_.Output(name), dim);
  }

Q
qiaolongfei 已提交
477
  AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
Q
Qiao Longfei 已提交
478

Q
qiaolongfei 已提交
479 480
  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
Q
Qiao Longfei 已提交
481 482 483
    return op_.Inputs(name);
  }

Q
qiaolongfei 已提交
484 485
  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
Q
Qiao Longfei 已提交
486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505
    return op_.Outputs(name);
  }

 private:
  template <bool Allocate>
  Tensor* GetTensor(const std::string& name) const {
    Tensor* t = nullptr;
    auto* var = scope_.FindVar(name);
    if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
      if (Allocate) {
        t = var->GetMutable<LoDTensor>();
      } else {
        PADDLE_THROW("Variable(%s) should be tensor", name);
      }
    } else {
      t = GetTensorFromVar(scope_.FindVar(name));
    }
    return t;
  }

Q
qiaolongfei 已提交
506
  DDim GetDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
507 508 509
    return GetTensor<false>(name)->dims();
  }

Q
qiaolongfei 已提交
510
  void SetDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
511 512 513 514 515 516 517
    GetTensor<true>(name)->Resize(dim);
  }

  const OperatorBase& op_;
  const Scope& scope_;
};

Y
Yu Yang 已提交
518
class OpKernelBase {
Q
qijun 已提交
519
 public:
Q
qijun 已提交
520
  /**
521
   * ExecutionContext is the only parameter of Kernel Run function.
Q
qijun 已提交
522 523
   * Run will get input/output variables, state such as momentum and
   * device resource such as CUDA stream, cublas handle, etc. from
524
   * ExecutionContext. User should construct it before run the Operator.
Q
qijun 已提交
525 526
   */

527
  virtual void Compute(const ExecutionContext& context) const = 0;
Y
Yu Yang 已提交
528

Y
Yu Yang 已提交
529 530 531 532 533 534 535
  virtual ~OpKernelBase() = default;
};

/// Kernel base class that records its element type T as ELEMENT_TYPE.
template <typename T>
class OpKernel : public OpKernelBase {
 public:
  using ELEMENT_TYPE = T;
};

Q
Qiao Longfei 已提交
538 539
class OperatorWithKernel : public OperatorBase {
 public:
Y
Yu Yang 已提交
540 541
  struct OpKernelKey {
    platform::Place place_;
Y
Yu Yang 已提交
542
    DataType data_type_;
Q
Qiao Longfei 已提交
543

Y
Yu Yang 已提交
544 545 546 547 548
    OpKernelKey(DataType data_type, platform::Place place)
        : place_(place), data_type_(data_type) {}

    OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
        : place_(dev_ctx.GetPlace()), data_type_(data_type) {}
Y
Yu Yang 已提交
549

Q
qijun 已提交
550
    bool operator==(const OpKernelKey& o) const {
Y
Yu Yang 已提交
551 552
      return platform::places_are_same_class(place_, o.place_) &&
             data_type_ == o.data_type_;
Q
qijun 已提交
553
    }
Y
Yu Yang 已提交
554 555 556
  };

  struct OpKernelHash {
Y
Yu Yang 已提交
557
    std::hash<int> hash_;
Y
Yu Yang 已提交
558
    size_t operator()(const OpKernelKey& key) const {
Y
Yu Yang 已提交
559 560
      int place = key.place_.which();
      int data_type = static_cast<int>(key.data_type_);
Y
Yu Yang 已提交
561 562
      int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
                     (place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
Y
Yu Yang 已提交
563
      return hash_(pre_hash);
Y
Yu Yang 已提交
564 565 566 567
    }
  };

  using OpKernelMap =
Y
Yu Yang 已提交
568 569
      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
                         OpKernelHash>;
Q
Qiao Longfei 已提交
570

Y
Yu Yang 已提交
571 572
  OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
                     const VariableNameMap& outputs, const AttributeMap& attrs)
Y
Yu Yang 已提交
573 574
      : OperatorBase(type, inputs, outputs, attrs) {}

Y
Yu Yang 已提交
575
  void Run(const Scope& scope,
Y
Yu Yang 已提交
576
           const platform::DeviceContext& dev_ctx) const final {
Y
Yu Yang 已提交
577
    VLOG(3) << "Running operator " << this->Type();
Y
Yu Yang 已提交
578 579 580
    RuntimeInferShapeContext infer_shape_ctx(*this, scope);
    this->InferShape(&infer_shape_ctx);

Y
Yu Yang 已提交
581
    ExecutionContext ctx(*this, scope, dev_ctx);
582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600

    // check if op[type] has kernel registered.
    auto& all_op_kernels = AllOpKernels();
    auto kernels_iter = all_op_kernels.find(type_);
    if (kernels_iter == all_op_kernels.end()) {
      PADDLE_THROW("op[%s] has no kernel", type_);
    }

    // check if op[type] have kernel for kernel_key
    OpKernelMap& kernels = kernels_iter->second;
    auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
    auto kernel_iter = kernels.find(kernel_key);

    if (kernel_iter == kernels.end()) {
      PADDLE_THROW("op[%s] has no kernel with kernel_key[%s]", type_,
                   kernel_key);
    }

    kernel_iter->second->Compute(ctx);
Q
Qiao Longfei 已提交
601 602
  }

Y
Yu Yang 已提交
603 604 605 606
  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
Y
Yu Yang 已提交
607
  }
Y
Yan Chunwei 已提交
608

609
  bool SupportGPU() const override {
Y
Yu Yang 已提交
610 611 612 613 614
    auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
    return std::any_of(op_kernels.begin(), op_kernels.end(),
                       [](OpKernelMap::const_reference kern_pair) {
                         return platform::is_gpu_place(kern_pair.first.place_);
                       });
615 616
  }

617
  virtual void InferShape(InferShapeContext* ctx) const = 0;
Y
Yu Yang 已提交
618

Q
qiaolongfei 已提交
619
 protected:
Y
Yu Yang 已提交
620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646
  // indicate kernel DataType by input data. Defaultly all input data must be
  // same.
  virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
    auto& scope = ctx.scope();
    int data_type = -1;
    for (auto& input : this->inputs_) {
      for (auto& ipt_name : input.second) {
        auto* var = scope.FindVar(ipt_name);
        if (var != nullptr) {
          const Tensor* t = nullptr;
          if (var->IsType<Tensor>()) {
            t = &var->Get<Tensor>();
          } else if (var->IsType<LoDTensor>()) {
            t = &var->Get<LoDTensor>();
          }
          if (t != nullptr) {
            int tmp = static_cast<int>(ToDataType(t->type()));
            PADDLE_ENFORCE(tmp == data_type || data_type == -1,
                           "DataType of Paddle Op must be same.");
            data_type = tmp;
          }
        }
      }
    }
    PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
    return static_cast<DataType>(data_type);
  }
Q
Qiao Longfei 已提交
647 648
};

649 650 651
std::ostream& operator<<(std::ostream& os,
                         const OperatorWithKernel::OpKernelKey& kernel_key);

Q
Qiao Longfei 已提交
652 653
}  // namespace framework
}  // namespace paddle