operator.h 21.8 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

D
dongzhihong 已提交
17
#include <algorithm>
18
#include <atomic>
Q
Qiao Longfei 已提交
19 20 21 22
#include <string>
#include <unordered_map>
#include <vector>

Y
Yu Yang 已提交
23
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
24
#include "paddle/framework/attribute.h"
Q
qiaolongfei 已提交
25
#include "paddle/framework/block_desc.h"
Y
Yu Yang 已提交
26
#include "paddle/framework/data_type.h"
Y
Yu Yang 已提交
27
#include "paddle/framework/framework.pb.h"
28
#include "paddle/framework/lod_tensor.h"
Y
Yu Yang 已提交
29
#include "paddle/framework/op_info.h"
Q
qijun 已提交
30
#include "paddle/framework/scope.h"
Q
Qiao Longfei 已提交
31
#include "paddle/framework/shape_inference.h"
Q
qijun 已提交
32 33 34
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
Y
Yu Yang 已提交
35
#include "paddle/platform/variant.h"
Q
qijun 已提交
36
#include "paddle/utils/Error.h"
Q
Qiao Longfei 已提交
37 38 39 40

namespace paddle {
namespace framework {

41
/// If a variable is an empty variable, that name will be used.
42
constexpr char kEmptyVarName[] = "@EMPTY@";
43 44 45

/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be converted to a unique name in scope after OpCreator.
46
constexpr char kTempVarName[] = "@TEMP@";
47 48 49 50

/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another variable.
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
51
constexpr char kGradVarSuffix[] = "@GRAD";
52 53

/// Variables with this suffix are supposed to be filled up with zeros.
54
constexpr char kZeroVarSuffix[] = "@ZERO";
55 56 57 58 59

/// Builds the name of the gradient variable that belongs to `var_name`,
/// i.e. `var_name` followed by kGradVarSuffix.
inline std::string GradVarName(const std::string& var_name) {
  std::string grad_name(var_name);
  grad_name += kGradVarSuffix;
  return grad_name;
}

Q
Qiao Longfei 已提交
60
class OperatorBase;
61
class ExecutionContext;
62

Q
Qiao Longfei 已提交
63 64 65
extern const Tensor* GetTensorFromVar(const Variable* var);
extern Tensor* GetTensorFromVar(Variable* var);

Q
Qiao Longfei 已提交
66 67 68 69 70 71 72 73
/**
 * OperatorBase has the basic element that Net will call to do computation.
 * Only CreateOperator from OpRegistry will new Operator directly. User
 * should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
Y
Yu Yang 已提交
74 75
  OperatorBase(const std::string& type, const VariableNameMap& inputs,
               const VariableNameMap& outputs, const AttributeMap& attrs);
76

Q
Qiao Longfei 已提交
77 78 79
  virtual ~OperatorBase() {}

  template <typename T>
Y
Yu Yang 已提交
80
  inline const T& Attr(const std::string& name) const {
Q
Qiao Longfei 已提交
81 82 83 84 85
    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
                   name);
    return boost::get<T>(attrs_.at(name));
  }

86
  virtual std::string DebugString() const;
Q
Qiao Longfei 已提交
87 88

  /// Net will call this function to Run an op.
Y
Yu Yang 已提交
89
  virtual void Run(const Scope& scope,
Y
Yu Yang 已提交
90 91
                   const platform::DeviceContext& dev_ctx) const = 0;

Y
Yu Yang 已提交
92 93
  virtual bool IsNetOp() const { return false; }

94 95
  virtual bool SupportGPU() const { return false; }

D
dongzhihong 已提交
96 97 98
  /// rename inputs outputs name
  void Rename(const std::string& old_name, const std::string& new_name);

Y
Yu Yang 已提交
99 100
  const VariableNameMap& Inputs() const { return inputs_; }
  const VariableNameMap& Outputs() const { return outputs_; }
101

Y
Yu Yang 已提交
102
  //! Get a input with argument's name described in `op_proto`
103
  std::string Input(const std::string& name) const;
Y
Yu Yang 已提交
104
  //! Get a input which has multiple variables.
Y
Yu Yang 已提交
105
  const std::vector<std::string>& Inputs(const std::string& name) const;
Y
Yi Wang 已提交
106

Q
qijun 已提交
107 108
  std::vector<std::string> InputVars() const;

Y
Yu Yang 已提交
109
  //! Get a output with argument's name described in `op_proto`
110
  std::string Output(const std::string& name) const;
Y
Yu Yang 已提交
111 112
  //! Get an output which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yu Yang 已提交
113
  const std::vector<std::string>& Outputs(const std::string& name) const;
Y
Yan Chunwei 已提交
114

Y
Yu Yang 已提交
115
  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
116

Q
qiaolongfei 已提交
117
  const std::string& Type() const { return type_; }
Q
qiaolongfei 已提交
118
  void SetType(const std::string& type) { type_ = type; }
Y
Yi Wang 已提交
119 120
  const AttributeMap& Attrs() const { return attrs_; }

Y
Yu Yang 已提交
121
  // Return a new operator instance, which is as same as this.
Y
Yu Yang 已提交
122 123
  // Use unique_ptr to prevent caller forget to delete this pointer.
  virtual std::unique_ptr<OperatorBase> Clone() const = 0;
Y
Yu Yang 已提交
124

Q
qiaolongfei 已提交
125
 protected:
Q
Qiao Longfei 已提交
126
  std::string type_;
D
dongzhihong 已提交
127
  // NOTE: in case of OpGrad, inputs_ contains:
128
  // I (Inputs)
D
dongzhihong 已提交
129 130
  // O (Outputs)
  // OG (Output Gradients)
Y
Yu Yang 已提交
131
  VariableNameMap inputs_;
Y
Yu Yang 已提交
132

D
dongzhihong 已提交
133 134
  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
Y
Yu Yang 已提交
135
  VariableNameMap outputs_;
Q
Qiao Longfei 已提交
136
  AttributeMap attrs_;
137 138 139 140

 private:
  void GenerateTemporaryNames();
  void CheckAllInputOutputSet() const;
Y
Yan Chunwei 已提交
141 142
};

Y
Yu Yang 已提交
143 144
// Macro that defines the `Clone` method of operator class `cls`: it
// copy-constructs a new `cls` from `*this` and returns it as a unique_ptr to
// OperatorBase.
// If you are writing a kernel operator, `Clone` will be defined when you
// register it, i.e. you do not need to define the `Clone` method yourself.
#define DEFINE_OP_CLONE_METHOD(cls)                                          \
  std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final {   \
    std::unique_ptr<::paddle::framework::OperatorBase> copy(new cls(*this)); \
    return copy;                                                             \
  }
Y
Yu Yang 已提交
150

Y
Yu Yang 已提交
151 152 153 154
// Macro that defines the standard four-argument constructor for operator
// class `cls`; every argument is forwarded unchanged to `parent_cls`.
// You can also use
//   using PARENT_CLASS::PARENT_CLASS;
// to use the parent's constructor.
// Every framework type is spelled with a leading `::` so that all three
// parameters are qualified the same way wherever the macro is expanded.
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls)             \
  cls(const std::string& type,                             \
      const ::paddle::framework::VariableNameMap& inputs,  \
      const ::paddle::framework::VariableNameMap& outputs, \
      const ::paddle::framework::AttributeMap& attrs)      \
      : parent_cls(type, inputs, outputs, attrs) {}
Y
Yu Yang 已提交
161

162 163
class NOP : public OperatorBase {
 public:
164
  using OperatorBase::OperatorBase;
165 166
  void Run(const Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {}
167 168 169
  std::unique_ptr<OperatorBase> Clone() const override {
    return std::unique_ptr<OperatorBase>(new NOP(*this));
  }
170 171
};

172
class ExecutionContext {
Y
Yan Chunwei 已提交
173
 public:
174 175 176
  ExecutionContext(const OperatorBase& op, const Scope& scope,
                   const platform::DeviceContext& device_context)
      : op_(op), scope_(scope), device_context_(device_context) {}
177

Q
qiaolongfei 已提交
178 179 180 181
  const OperatorBase& op() const { return op_; }

  const Scope& scope() const { return scope_; }

Q
qiaolongfei 已提交
182
  template <typename T>
Y
Yu Yang 已提交
183 184
  inline const T& Attr(const std::string& name) const {
    return op_.Attr<T>(name);
Q
qiaolongfei 已提交
185 186
  }

Y
Yu Yang 已提交
187
  size_t InputSize(const std::string& name) const {
Y
Yu Yang 已提交
188
    return op_.Inputs(name).size();
Y
Yan Chunwei 已提交
189 190
  }

Y
Yu Yang 已提交
191
  size_t OutputSize(const std::string& name) const {
Y
Yu Yang 已提交
192
    return op_.Outputs(name).size();
Y
Yan Chunwei 已提交
193 194
  }

195
  const Variable* InputVar(const std::string& name) const {
196
    auto ipt = op_.Input(name);
Y
Yu Yang 已提交
197
    return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
Y
Yan Chunwei 已提交
198 199
  }

200
  Variable* OutputVar(const std::string& name) const {
201
    auto opt = op_.Output(name);
Y
Yu Yang 已提交
202
    return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
Y
Yan Chunwei 已提交
203 204
  }

205 206
  const std::vector<const Variable*> MultiInputVar(
      const std::string& name) const {
Y
Yan Chunwei 已提交
207 208
    auto names = op_.Inputs(name);
    std::vector<const Variable*> res;
209
    res.reserve(names.size());
210 211
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
212 213
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
214
                   });
Y
Yan Chunwei 已提交
215 216 217
    return res;
  }

218
  std::vector<Variable*> MultiOutputVar(const std::string& name) const {
Y
Yan Chunwei 已提交
219
    auto names = op_.Outputs(name);
220
    std::vector<Variable*> res;
221
    res.reserve(names.size());
222 223
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
224 225
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
226
                   });
Y
Yan Chunwei 已提交
227 228 229
    return res;
  }

230 231
  template <typename T>
  const T* Input(const std::string& name) const {
Y
Yu Yang 已提交
232
    auto* var = InputVar(name);
233
    return var == nullptr ? nullptr : &var->Get<T>();
234 235 236 237
  }

  template <typename T>
  T* Output(const std::string& name) const {
238
    auto var = OutputVar(name);
239
    return var == nullptr ? nullptr : var->GetMutable<T>();
240 241 242 243 244 245 246 247
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
248
                   [&](const std::string& sub_name) {
249
                     auto var = scope_.FindVar(sub_name);
250
                     return var == nullptr ? nullptr : &var->Get<T>();
251 252 253 254 255
                   });
    return res;
  }

  template <typename T>
256
  std::vector<T*> MultiOutput(const std::string& name) const {
257
    auto names = op_.Outputs(name);
258
    std::vector<T*> res;
259 260
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
261
                   [&](const std::string& sub_name) {
262
                     auto var = scope_.FindVar(sub_name);
263
                     return var == nullptr ? nullptr : var->GetMutable<T>();
264 265 266 267
                   });
    return res;
  }

268 269 270 271 272 273
  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const {
    PADDLE_ENFORCE_LT(i, InputSize(in));
    PADDLE_ENFORCE_LT(j, OutputSize(out));
    auto* in_var = MultiInputVar(in)[i];
    auto* out_var = MultiOutputVar(out)[j];
274
    if (!in_var->IsType<LoDTensor>()) return;
275
    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
276
                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
277 278 279
    auto in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_lod(in_tensor.lod());
280 281
  }

Q
qijun 已提交
282
  template <typename PlaceType,
283 284
            typename DeviceType = typename platform::EigenDeviceConverter<
                PlaceType>::EigenDeviceType>
285
  DeviceType& GetEigenDevice() const;
Q
qijun 已提交
286

287
  platform::Place GetPlace() const { return device_context_.GetPlace(); }
Q
qijun 已提交
288

289
  const platform::DeviceContext& device_context() const {
Q
qijun 已提交
290
    return device_context_;
Q
qijun 已提交
291
  }
Q
qijun 已提交
292

D
Dong Zhihong 已提交
293
  //! Get variables vector with same input name.
D
Dong Zhihong 已提交
294 295 296
  const std::vector<std::string>& Inputs(const std::string& name) const {
    return op_.Inputs(name);
  }
D
Dong Zhihong 已提交
297 298

  //! Get variables vector with same output name.
D
Dong Zhihong 已提交
299 300 301 302
  const std::vector<std::string>& Outputs(const std::string& name) const {
    return op_.Outputs(name);
  }

武毅 已提交
303 304 305 306 307 308 309 310 311
#ifdef PADDLE_WITH_CUDA
  const platform::CUDADeviceContext& cuda_device_context() const {
    PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
    auto cuda_ctx =
        reinterpret_cast<const platform::CUDADeviceContext*>(&device_context_);
    return *cuda_ctx;
  }
#endif

312
 private:
313 314
  const OperatorBase& op_;
  const Scope& scope_;
315
  const platform::DeviceContext& device_context_;
Q
Qiao Longfei 已提交
316 317
};

318 319 320 321 322 323 324 325 326 327 328 329 330 331
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;

template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
    const std::string& name) const;

template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;

template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
    const std::string& name) const;

332
class CompileTimeInferShapeContext : public InferShapeContext {
Q
tmp  
qiaolongfei 已提交
333
 public:
Q
qiaolongfei 已提交
334 335
  CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
      : op_(op), block_(block) {}
Q
tmp  
qiaolongfei 已提交
336

Q
qiaolongfei 已提交
337
  bool HasInput(const std::string& name) const override {
Q
qiaolongfei 已提交
338
    const std::vector<std::string>& input_names = op_.Input(name);
339
    auto length = input_names.size();
340 341 342
    if (length == 0) {
      return false;
    }
343 344 345 346
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input(%s) should have only one value, "
                      "but it have %d now",
                      name, length);
347
    return block_.HasVarRecursive(input_names[0]);
Q
tmp  
qiaolongfei 已提交
348 349
  }

Q
qiaolongfei 已提交
350
  bool HasOutput(const std::string& name) const override {
Q
qiaolongfei 已提交
351
    const std::vector<std::string>& output_names = op_.Output(name);
352
    auto length = output_names.size();
353 354 355
    if (length == 0) {
      return false;
    }
356 357 358 359
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output(%s) should have only one value, "
                      "but it have %d now",
                      name, length);
360
    return block_.HasVarRecursive(output_names[0]);
Q
tmp  
qiaolongfei 已提交
361 362
  }

Q
qiaolongfei 已提交
363
  bool HasInputs(const std::string& name) const override {
Q
qiaolongfei 已提交
364
    const std::vector<std::string>& input_names = op_.Input(name);
365 366 367
    if (input_names.empty()) {
      return false;
    }
Q
qiaolongfei 已提交
368
    for (auto& input : input_names) {
369
      if (!block_.HasVarRecursive(input)) return false;
Q
tmp  
qiaolongfei 已提交
370 371 372 373
    }
    return true;
  }

Q
qiaolongfei 已提交
374
  bool HasOutputs(const std::string& name) const override {
Q
qiaolongfei 已提交
375
    const std::vector<std::string>& output_names = op_.Output(name);
376 377 378
    if (output_names.empty()) {
      return false;
    }
Q
qiaolongfei 已提交
379
    for (auto& output : output_names) {
380
      if (!block_.HasVarRecursive(output)) return false;
Q
tmp  
qiaolongfei 已提交
381 382 383 384
    }
    return true;
  }

Q
qiaolongfei 已提交
385
  DDim GetInputDim(const std::string& name) const override {
Q
qiaolongfei 已提交
386
    std::vector<DDim> ddims = GetInputsDim(name);
387 388 389 390 391
    auto length = ddims.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input(%s) should have 1 value, "
                      "but it has %d now",
                      name, length);
Q
qiaolongfei 已提交
392
    return ddims[0];
Q
tmp  
qiaolongfei 已提交
393 394
  }

Q
qiaolongfei 已提交
395
  void SetInputDim(const std::string& name, const DDim& dim) override {
Q
qiaolongfei 已提交
396
    SetInputsDim(name, {dim});
Q
tmp  
qiaolongfei 已提交
397 398
  }

Q
qiaolongfei 已提交
399
  DDim GetOutputDim(const std::string& name) const override {
Q
qiaolongfei 已提交
400
    std::vector<DDim> ddims = GetOutputsDim(name);
401 402 403 404 405
    auto length = ddims.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output(%s) should have 1 value, "
                      "but it has %d now",
                      name, length);
Q
qiaolongfei 已提交
406
    return ddims[0];
Q
tmp  
qiaolongfei 已提交
407 408
  }

Q
qiaolongfei 已提交
409
  void SetOutputDim(const std::string& name, const DDim& dim) override {
Q
qiaolongfei 已提交
410
    SetOutputsDim(name, {dim});
Q
tmp  
qiaolongfei 已提交
411 412
  }

Q
qiaolongfei 已提交
413
  AttrReader Attrs() const override { return AttrReader(op_.GetAttrMap()); }
Q
tmp  
qiaolongfei 已提交
414

Q
qiaolongfei 已提交
415 416
  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
Q
qiaolongfei 已提交
417
    return op_.Input(name);
Q
tmp  
qiaolongfei 已提交
418 419
  }

Q
qiaolongfei 已提交
420 421
  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
Q
qiaolongfei 已提交
422
    return op_.Output(name);
Q
tmp  
qiaolongfei 已提交
423 424 425
  }

 private:
Q
qiaolongfei 已提交
426
  DDim GetDim(const std::string& name) const override {
427
    return framework::make_ddim(block_.FindVarRecursive(name)->Shape());
Q
tmp  
qiaolongfei 已提交
428 429
  }

Q
qiaolongfei 已提交
430
  void SetDim(const std::string& name, const DDim& dim) override {
431
    block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
Q
tmp  
qiaolongfei 已提交
432 433
  }

Q
qiaolongfei 已提交
434 435
  const OpDescBind& op_;
  const BlockDescBind& block_;
Q
tmp  
qiaolongfei 已提交
436 437
};

438
class RuntimeInferShapeContext : public InferShapeContext {
Q
Qiao Longfei 已提交
439 440 441 442
 public:
  RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
      : op_(op), scope_(scope) {}

Q
qiaolongfei 已提交
443
  bool HasInput(const std::string& name) const override {
444 445 446 447 448 449 450 451
    auto& ins = Inputs(name);
    size_t length = ins.size();
    if (length == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
                      name);
    auto ipt = ins[0];
Q
Qiao Longfei 已提交
452 453 454 455
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

Q
qiaolongfei 已提交
456
  bool HasOutput(const std::string& name) const override {
457 458 459 460 461 462 463 464
    auto& outs = Outputs(name);
    size_t length = outs.size();
    if (length == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
                      name);
    auto ipt = outs[0];
Q
Qiao Longfei 已提交
465 466 467 468
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

Q
qiaolongfei 已提交
469
  bool HasInputs(const std::string& name) const override {
470
    auto inputs = op_.Inputs(name);
Q
qiaolongfei 已提交
471
    if (inputs.empty()) {
472 473 474 475 476 477 478 479 480 481
      return false;
    }
    for (auto& input : inputs) {
      if (scope_.FindVar(input) == nullptr) {
        return false;
      }
    }
    return true;
  }

Q
qiaolongfei 已提交
482
  bool HasOutputs(const std::string& name) const override {
483
    auto outputs = op_.Outputs(name);
Q
qiaolongfei 已提交
484
    if (outputs.empty()) {
485 486 487 488 489 490 491 492 493 494
      return false;
    }
    for (auto& output : outputs) {
      if (scope_.FindVar(output) == nullptr) {
        return false;
      }
    }
    return true;
  }

Q
qiaolongfei 已提交
495
  DDim GetInputDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
496 497 498
    return GetDim(op_.Input(name));
  }

Q
qiaolongfei 已提交
499
  void SetInputDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
500 501 502
    SetDim(op_.Input(name), dim);
  }

Q
qiaolongfei 已提交
503
  DDim GetOutputDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
504 505 506
    return GetDim(op_.Output(name));
  }

Q
qiaolongfei 已提交
507
  void SetOutputDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
508 509 510
    SetDim(op_.Output(name), dim);
  }

Q
qiaolongfei 已提交
511
  AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
Q
Qiao Longfei 已提交
512

Q
qiaolongfei 已提交
513 514
  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
Q
Qiao Longfei 已提交
515 516 517
    return op_.Inputs(name);
  }

Q
qiaolongfei 已提交
518 519
  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
Q
Qiao Longfei 已提交
520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539
    return op_.Outputs(name);
  }

 private:
  template <bool Allocate>
  Tensor* GetTensor(const std::string& name) const {
    Tensor* t = nullptr;
    auto* var = scope_.FindVar(name);
    if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
      if (Allocate) {
        t = var->GetMutable<LoDTensor>();
      } else {
        PADDLE_THROW("Variable(%s) should be tensor", name);
      }
    } else {
      t = GetTensorFromVar(scope_.FindVar(name));
    }
    return t;
  }

Q
qiaolongfei 已提交
540
  DDim GetDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
541 542 543
    return GetTensor<false>(name)->dims();
  }

Q
qiaolongfei 已提交
544
  void SetDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
545 546 547 548 549 550 551
    GetTensor<true>(name)->Resize(dim);
  }

  const OperatorBase& op_;
  const Scope& scope_;
};

Y
Yu Yang 已提交
552
class OpKernelBase {
Q
qijun 已提交
553
 public:
Q
qijun 已提交
554
  /**
555
   * ExecutionContext is the only parameter of Kernel Run function.
Q
qijun 已提交
556 557
   * Run will get input/output variables, state such as momentum and
   * device resource such as CUDA stream, cublas handle, etc. from
558
   * ExecutionContext. User should construct it before run the Operator.
Q
qijun 已提交
559 560
   */

561
  virtual void Compute(const ExecutionContext& context) const = 0;
Y
Yu Yang 已提交
562

Y
Yu Yang 已提交
563 564 565 566 567 568 569
  virtual ~OpKernelBase() = default;
};

/// Kernel base class parameterised by element type. It adds nothing to
/// OpKernelBase except exposing `T` as ELEMENT_TYPE.
template <typename T>
class OpKernel : public OpKernelBase {
 public:
  /// The element type this kernel was instantiated with.
  using ELEMENT_TYPE = T;
};

Q
Qiao Longfei 已提交
572 573
class OperatorWithKernel : public OperatorBase {
 public:
Y
Yu Yang 已提交
574 575
  struct OpKernelKey {
    platform::Place place_;
Y
Yu Yang 已提交
576
    DataType data_type_;
Q
Qiao Longfei 已提交
577

Y
Yu Yang 已提交
578 579 580 581 582
    OpKernelKey(DataType data_type, platform::Place place)
        : place_(place), data_type_(data_type) {}

    OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
        : place_(dev_ctx.GetPlace()), data_type_(data_type) {}
Y
Yu Yang 已提交
583

Q
qijun 已提交
584
    bool operator==(const OpKernelKey& o) const {
Y
Yu Yang 已提交
585 586
      return platform::places_are_same_class(place_, o.place_) &&
             data_type_ == o.data_type_;
Q
qijun 已提交
587
    }
Y
Yu Yang 已提交
588 589 590
  };

  struct OpKernelHash {
Y
Yu Yang 已提交
591
    std::hash<int> hash_;
Y
Yu Yang 已提交
592
    size_t operator()(const OpKernelKey& key) const {
Y
Yu Yang 已提交
593 594
      int place = key.place_.which();
      int data_type = static_cast<int>(key.data_type_);
Y
Yu Yang 已提交
595 596
      int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
                     (place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
Y
Yu Yang 已提交
597
      return hash_(pre_hash);
Y
Yu Yang 已提交
598 599 600 601
    }
  };

  using OpKernelMap =
Y
Yu Yang 已提交
602 603
      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
                         OpKernelHash>;
Q
Qiao Longfei 已提交
604

Y
Yu Yang 已提交
605 606
  OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
                     const VariableNameMap& outputs, const AttributeMap& attrs)
Y
Yu Yang 已提交
607 608
      : OperatorBase(type, inputs, outputs, attrs) {}

Y
Yu Yang 已提交
609
  void Run(const Scope& scope,
Y
Yu Yang 已提交
610
           const platform::DeviceContext& dev_ctx) const final {
Y
Yu Yang 已提交
611
    VLOG(3) << "Running operator " << this->Type();
Y
Yu Yang 已提交
612 613 614
    RuntimeInferShapeContext infer_shape_ctx(*this, scope);
    this->InferShape(&infer_shape_ctx);

Y
Yu Yang 已提交
615
    ExecutionContext ctx(*this, scope, dev_ctx);
616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634

    // check if op[type] has kernel registered.
    auto& all_op_kernels = AllOpKernels();
    auto kernels_iter = all_op_kernels.find(type_);
    if (kernels_iter == all_op_kernels.end()) {
      PADDLE_THROW("op[%s] has no kernel", type_);
    }

    // check if op[type] have kernel for kernel_key
    OpKernelMap& kernels = kernels_iter->second;
    auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
    auto kernel_iter = kernels.find(kernel_key);

    if (kernel_iter == kernels.end()) {
      PADDLE_THROW("op[%s] has no kernel with kernel_key[%s]", type_,
                   kernel_key);
    }

    kernel_iter->second->Compute(ctx);
Q
Qiao Longfei 已提交
635 636
  }

Y
Yu Yang 已提交
637 638 639 640
  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
Y
Yu Yang 已提交
641
  }
Y
Yan Chunwei 已提交
642

643
  bool SupportGPU() const override {
Y
Yu Yang 已提交
644 645 646 647 648
    auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
    return std::any_of(op_kernels.begin(), op_kernels.end(),
                       [](OpKernelMap::const_reference kern_pair) {
                         return platform::is_gpu_place(kern_pair.first.place_);
                       });
649 650
  }

651
  virtual void InferShape(InferShapeContext* ctx) const = 0;
Y
Yu Yang 已提交
652

Q
qiaolongfei 已提交
653
 protected:
Y
Yu Yang 已提交
654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680
  // indicate kernel DataType by input data. Defaultly all input data must be
  // same.
  virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
    auto& scope = ctx.scope();
    int data_type = -1;
    for (auto& input : this->inputs_) {
      for (auto& ipt_name : input.second) {
        auto* var = scope.FindVar(ipt_name);
        if (var != nullptr) {
          const Tensor* t = nullptr;
          if (var->IsType<Tensor>()) {
            t = &var->Get<Tensor>();
          } else if (var->IsType<LoDTensor>()) {
            t = &var->Get<LoDTensor>();
          }
          if (t != nullptr) {
            int tmp = static_cast<int>(ToDataType(t->type()));
            PADDLE_ENFORCE(tmp == data_type || data_type == -1,
                           "DataType of Paddle Op must be same.");
            data_type = tmp;
          }
        }
      }
    }
    PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
    return static_cast<DataType>(data_type);
  }
Q
Qiao Longfei 已提交
681 682
};

683 684 685
std::ostream& operator<<(std::ostream& os,
                         const OperatorWithKernel::OpKernelKey& kernel_key);

Y
Yu Yang 已提交
686 687
extern bool OpSupportGPU(const std::string& op_type);

Q
Qiao Longfei 已提交
688 689
}  // namespace framework
}  // namespace paddle