operator.h 21.7 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

D
dongzhihong 已提交
17
#include <algorithm>
18
#include <atomic>
Q
Qiao Longfei 已提交
19 20 21 22
#include <string>
#include <unordered_map>
#include <vector>

Y
Yu Yang 已提交
23
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
24
#include "paddle/framework/attribute.h"
Q
qiaolongfei 已提交
25
#include "paddle/framework/block_desc.h"
Y
Yu Yang 已提交
26
#include "paddle/framework/data_type.h"
Y
Yu Yang 已提交
27
#include "paddle/framework/framework.pb.h"
28
#include "paddle/framework/lod_tensor.h"
Y
Yu Yang 已提交
29
#include "paddle/framework/op_info.h"
Q
qijun 已提交
30
#include "paddle/framework/scope.h"
Q
Qiao Longfei 已提交
31
#include "paddle/framework/shape_inference.h"
Q
qijun 已提交
32 33 34
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
Y
Yu Yang 已提交
35
#include "paddle/platform/variant.h"
Q
qijun 已提交
36
#include "paddle/utils/Error.h"
Q
Qiao Longfei 已提交
37 38 39 40

namespace paddle {
namespace framework {

41
/// Name used for a variable slot that is intentionally left empty.
/// Lookups in this header treat a variable with this name as "no variable".
constexpr char kEmptyVarName[] = "@EMPTY@";

/// If a variable is a temporary variable, this name is set for it in Python;
/// it is converted to a unique name in the scope after OpCreator runs.
constexpr char kTempVarName[] = "@TEMP@";

/// If a variable's name has this suffix, the variable is the gradient of
/// another variable, e.g. variable "x@GRAD" is the gradient of variable "x".
constexpr char kGradVarSuffix[] = "@GRAD";

/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";

/// Returns the name of the gradient variable of `var_name`, i.e. `var_name`
/// followed by kGradVarSuffix ("x" -> "x@GRAD").
inline std::string GradVarName(const std::string& var_name) {
  return var_name + kGradVarSuffix;
}

// Forward declarations for types defined further down in this header.
class ExecutionContext;
class OperatorBase;

/// Returns the Tensor held by `var`; defined out of line.
/// NOTE(review): callers in this header pass variables that hold either a
/// Tensor or a LoDTensor.
const Tensor* GetTensorFromVar(const Variable* var);

/// Mutable overload of GetTensorFromVar; defined out of line.
Tensor* GetTensorFromVar(Variable* var);

/**
 * OperatorBase has the basic element that Net will call to do computation.
 * Only CreateOperator from OpRegistry will new Operator directly. User
 * should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
Y
Yu Yang 已提交
74 75
  OperatorBase(const std::string& type, const VariableNameMap& inputs,
               const VariableNameMap& outputs, const AttributeMap& attrs);
76

Q
Qiao Longfei 已提交
77 78 79
  virtual ~OperatorBase() {}

  template <typename T>
Y
Yu Yang 已提交
80
  inline const T& Attr(const std::string& name) const {
Q
Qiao Longfei 已提交
81 82 83 84 85
    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
                   name);
    return boost::get<T>(attrs_.at(name));
  }

86
  virtual std::string DebugString() const;
Q
Qiao Longfei 已提交
87 88

  /// Net will call this function to Run an op.
Y
Yu Yang 已提交
89
  virtual void Run(const Scope& scope,
Y
Yu Yang 已提交
90 91
                   const platform::DeviceContext& dev_ctx) const = 0;

Y
Yu Yang 已提交
92 93
  virtual bool IsNetOp() const { return false; }

94 95
  virtual bool SupportGPU() const { return false; }

D
dongzhihong 已提交
96 97 98
  /// rename inputs outputs name
  void Rename(const std::string& old_name, const std::string& new_name);

Y
Yu Yang 已提交
99 100
  const VariableNameMap& Inputs() const { return inputs_; }
  const VariableNameMap& Outputs() const { return outputs_; }
101

Y
Yu Yang 已提交
102
  //! Get a input with argument's name described in `op_proto`
103
  std::string Input(const std::string& name) const;
Y
Yu Yang 已提交
104
  //! Get a input which has multiple variables.
Y
Yu Yang 已提交
105
  const std::vector<std::string>& Inputs(const std::string& name) const;
Y
Yi Wang 已提交
106

Q
qijun 已提交
107 108
  std::vector<std::string> InputVars() const;

Y
Yu Yang 已提交
109
  //! Get a output with argument's name described in `op_proto`
110
  std::string Output(const std::string& name) const;
Y
Yu Yang 已提交
111 112
  //! Get an output which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yu Yang 已提交
113
  const std::vector<std::string>& Outputs(const std::string& name) const;
Y
Yan Chunwei 已提交
114

Y
Yu Yang 已提交
115
  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
116

Q
qiaolongfei 已提交
117
  const std::string& Type() const { return type_; }
Q
qiaolongfei 已提交
118
  void SetType(const std::string& type) { type_ = type; }
Y
Yi Wang 已提交
119 120
  const AttributeMap& Attrs() const { return attrs_; }

Y
Yu Yang 已提交
121
  // Return a new operator instance, which is as same as this.
Y
Yu Yang 已提交
122 123
  // Use unique_ptr to prevent caller forget to delete this pointer.
  virtual std::unique_ptr<OperatorBase> Clone() const = 0;
Y
Yu Yang 已提交
124

Q
qiaolongfei 已提交
125
 protected:
Q
Qiao Longfei 已提交
126
  std::string type_;
D
dongzhihong 已提交
127
  // NOTE: in case of OpGrad, inputs_ contains:
Y
Yu Yang 已提交
128
  // I (Inputs)opear
D
dongzhihong 已提交
129 130
  // O (Outputs)
  // OG (Output Gradients)
Y
Yu Yang 已提交
131
  VariableNameMap inputs_;
Y
Yu Yang 已提交
132

D
dongzhihong 已提交
133 134
  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
Y
Yu Yang 已提交
135
  VariableNameMap outputs_;
Q
Qiao Longfei 已提交
136
  AttributeMap attrs_;
137 138 139 140

 private:
  void GenerateTemporaryNames();
  void CheckAllInputOutputSet() const;
Y
Yan Chunwei 已提交
141 142
};

Y
Yu Yang 已提交
143 144
// Macro that defines the `Clone` method of operator class `cls` by
// copy-constructing `*this` into a new object owned by a unique_ptr.
// If you are writing a kernel operator, `Clone` will be defined when you
// register it, i.e. you do not need to define the `Clone` method yourself.
#define DEFINE_OP_CLONE_METHOD(cls)                                            \
  std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final {     \
    return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \
  }

// Macro that defines a constructor for operator class `cls` taking
// (type, inputs, outputs, attrs) and forwarding them to `parent_cls`.
// You can also use
//   using PARENT_CLASS::PARENT_CLASS;
// to inherit the parent's constructor instead.
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls)             \
  cls(const std::string& type,                             \
      const ::paddle::framework::VariableNameMap& inputs,  \
      const ::paddle::framework::VariableNameMap& outputs, \
      const ::paddle::framework::AttributeMap& attrs)      \
      : parent_cls(type, inputs, outputs, attrs) {}

class NOP : public OperatorBase {
 public:
164
  using OperatorBase::OperatorBase;
165 166
  void Run(const Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {}
167 168 169
  std::unique_ptr<OperatorBase> Clone() const override {
    return std::unique_ptr<OperatorBase>(new NOP(*this));
  }
170 171
};

172
class ExecutionContext {
Y
Yan Chunwei 已提交
173
 public:
174 175 176
  ExecutionContext(const OperatorBase& op, const Scope& scope,
                   const platform::DeviceContext& device_context)
      : op_(op), scope_(scope), device_context_(device_context) {}
177

Q
qiaolongfei 已提交
178 179 180 181
  const OperatorBase& op() const { return op_; }

  const Scope& scope() const { return scope_; }

Q
qiaolongfei 已提交
182
  template <typename T>
Y
Yu Yang 已提交
183 184
  inline const T& Attr(const std::string& name) const {
    return op_.Attr<T>(name);
Q
qiaolongfei 已提交
185 186
  }

Y
Yu Yang 已提交
187
  size_t InputSize(const std::string& name) const {
Y
Yu Yang 已提交
188
    return op_.Inputs(name).size();
Y
Yan Chunwei 已提交
189 190
  }

Y
Yu Yang 已提交
191
  size_t OutputSize(const std::string& name) const {
Y
Yu Yang 已提交
192
    return op_.Outputs(name).size();
Y
Yan Chunwei 已提交
193 194
  }

195
  const Variable* InputVar(const std::string& name) const {
196
    auto ipt = op_.Input(name);
Y
Yu Yang 已提交
197
    return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
Y
Yan Chunwei 已提交
198 199
  }

200
  Variable* OutputVar(const std::string& name) const {
201
    auto opt = op_.Output(name);
Y
Yu Yang 已提交
202
    return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
Y
Yan Chunwei 已提交
203 204
  }

205 206
  const std::vector<const Variable*> MultiInputVar(
      const std::string& name) const {
Y
Yan Chunwei 已提交
207 208
    auto names = op_.Inputs(name);
    std::vector<const Variable*> res;
209
    res.reserve(names.size());
210 211
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
212 213
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
214
                   });
Y
Yan Chunwei 已提交
215 216 217
    return res;
  }

218
  std::vector<Variable*> MultiOutputVar(const std::string& name) const {
Y
Yan Chunwei 已提交
219
    auto names = op_.Outputs(name);
220
    std::vector<Variable*> res;
221
    res.reserve(names.size());
222 223
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
224 225
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
226
                   });
Y
Yan Chunwei 已提交
227 228 229
    return res;
  }

230 231
  template <typename T>
  const T* Input(const std::string& name) const {
Y
Yu Yang 已提交
232
    auto* var = InputVar(name);
233
    return var == nullptr ? nullptr : &var->Get<T>();
234 235 236 237
  }

  template <typename T>
  T* Output(const std::string& name) const {
238
    auto var = OutputVar(name);
239
    return var == nullptr ? nullptr : var->GetMutable<T>();
240 241 242 243 244 245 246 247
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
248
                   [&](const std::string& sub_name) {
249
                     auto var = scope_.FindVar(sub_name);
250
                     return var == nullptr ? nullptr : &var->Get<T>();
251 252 253 254 255
                   });
    return res;
  }

  template <typename T>
256
  std::vector<T*> MultiOutput(const std::string& name) const {
257
    auto names = op_.Outputs(name);
258
    std::vector<T*> res;
259 260
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
261
                   [&](const std::string& sub_name) {
262
                     auto var = scope_.FindVar(sub_name);
263
                     return var == nullptr ? nullptr : var->GetMutable<T>();
264 265 266 267
                   });
    return res;
  }

268 269 270 271 272 273
  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const {
    PADDLE_ENFORCE_LT(i, InputSize(in));
    PADDLE_ENFORCE_LT(j, OutputSize(out));
    auto* in_var = MultiInputVar(in)[i];
    auto* out_var = MultiOutputVar(out)[j];
274
    if (!in_var->IsType<LoDTensor>()) return;
275
    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
276
                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
277 278 279
    auto in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_lod(in_tensor.lod());
280 281
  }

Q
qijun 已提交
282
  template <typename PlaceType,
283 284
            typename DeviceType = typename platform::EigenDeviceConverter<
                PlaceType>::EigenDeviceType>
285
  DeviceType& GetEigenDevice() const;
Q
qijun 已提交
286

287
  platform::Place GetPlace() const { return device_context_.GetPlace(); }
Q
qijun 已提交
288

289
  const platform::DeviceContext& device_context() const {
Q
qijun 已提交
290
    return device_context_;
Q
qijun 已提交
291
  }
Q
qijun 已提交
292

武毅 已提交
293 294 295 296 297 298 299 300 301
#ifdef PADDLE_WITH_CUDA
  const platform::CUDADeviceContext& cuda_device_context() const {
    PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
    auto cuda_ctx =
        reinterpret_cast<const platform::CUDADeviceContext*>(&device_context_);
    return *cuda_ctx;
  }
#endif

302
 private:
303 304
  const OperatorBase& op_;
  const Scope& scope_;
305
  const platform::DeviceContext& device_context_;
Q
Qiao Longfei 已提交
306 307
};

308 309 310 311 312 313 314 315 316 317 318 319 320 321
// Tensor specializations of the ExecutionContext accessors. They are only
// declared here; the definitions live out of line.
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;

template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;

template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
    const std::string& name) const;

template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
    const std::string& name) const;

class CompileTimeInferShapeContext : public InferShapeContext {
Q
tmp  
qiaolongfei 已提交
323
 public:
Q
qiaolongfei 已提交
324 325
  CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
      : op_(op), block_(block) {}
Q
tmp  
qiaolongfei 已提交
326

Q
qiaolongfei 已提交
327
  bool HasInput(const std::string& name) const override {
Q
qiaolongfei 已提交
328
    const std::vector<std::string>& input_names = op_.Input(name);
329
    auto length = input_names.size();
330 331 332
    if (length == 0) {
      return false;
    }
333 334 335 336
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input(%s) should have only one value, "
                      "but it have %d now",
                      name, length);
337
    return block_.HasVarRecursive(input_names[0]);
Q
tmp  
qiaolongfei 已提交
338 339
  }

Q
qiaolongfei 已提交
340
  bool HasOutput(const std::string& name) const override {
Q
qiaolongfei 已提交
341
    const std::vector<std::string>& output_names = op_.Output(name);
342
    auto length = output_names.size();
343 344 345
    if (length == 0) {
      return false;
    }
346 347 348 349
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output(%s) should have only one value, "
                      "but it have %d now",
                      name, length);
350
    return block_.HasVarRecursive(output_names[0]);
Q
tmp  
qiaolongfei 已提交
351 352
  }

Q
qiaolongfei 已提交
353
  bool HasInputs(const std::string& name) const override {
Q
qiaolongfei 已提交
354
    const std::vector<std::string>& input_names = op_.Input(name);
355 356 357
    if (input_names.empty()) {
      return false;
    }
Q
qiaolongfei 已提交
358
    for (auto& input : input_names) {
359
      if (!block_.HasVarRecursive(input)) return false;
Q
tmp  
qiaolongfei 已提交
360 361 362 363
    }
    return true;
  }

Q
qiaolongfei 已提交
364
  bool HasOutputs(const std::string& name) const override {
Q
qiaolongfei 已提交
365
    const std::vector<std::string>& output_names = op_.Output(name);
366 367 368
    if (output_names.empty()) {
      return false;
    }
Q
qiaolongfei 已提交
369
    for (auto& output : output_names) {
370
      if (!block_.HasVarRecursive(output)) return false;
Q
tmp  
qiaolongfei 已提交
371 372 373 374
    }
    return true;
  }

Q
qiaolongfei 已提交
375
  DDim GetInputDim(const std::string& name) const override {
Q
qiaolongfei 已提交
376
    std::vector<DDim> ddims = GetInputsDim(name);
377 378 379 380 381
    auto length = ddims.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input(%s) should have 1 value, "
                      "but it has %d now",
                      name, length);
Q
qiaolongfei 已提交
382
    return ddims[0];
Q
tmp  
qiaolongfei 已提交
383 384
  }

Q
qiaolongfei 已提交
385
  void SetInputDim(const std::string& name, const DDim& dim) override {
Q
qiaolongfei 已提交
386
    SetInputsDim(name, {dim});
Q
tmp  
qiaolongfei 已提交
387 388
  }

Q
qiaolongfei 已提交
389
  DDim GetOutputDim(const std::string& name) const override {
Q
qiaolongfei 已提交
390
    std::vector<DDim> ddims = GetOutputsDim(name);
391 392 393 394 395
    auto length = ddims.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output(%s) should have 1 value, "
                      "but it has %d now",
                      name, length);
Q
qiaolongfei 已提交
396
    return ddims[0];
Q
tmp  
qiaolongfei 已提交
397 398
  }

Q
qiaolongfei 已提交
399
  void SetOutputDim(const std::string& name, const DDim& dim) override {
Q
qiaolongfei 已提交
400
    SetOutputsDim(name, {dim});
Q
tmp  
qiaolongfei 已提交
401 402
  }

Q
qiaolongfei 已提交
403
  AttrReader Attrs() const override { return AttrReader(op_.GetAttrMap()); }
Q
tmp  
qiaolongfei 已提交
404

Q
qiaolongfei 已提交
405 406
  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
Q
qiaolongfei 已提交
407
    return op_.Input(name);
Q
tmp  
qiaolongfei 已提交
408 409
  }

Q
qiaolongfei 已提交
410 411
  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
Q
qiaolongfei 已提交
412
    return op_.Output(name);
Q
tmp  
qiaolongfei 已提交
413 414 415
  }

 private:
Q
qiaolongfei 已提交
416
  DDim GetDim(const std::string& name) const override {
Y
Yu Yang 已提交
417 418 419
    auto var = block_.FindVarRecursive(name);
    PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
    return framework::make_ddim(var->Shape());
Q
tmp  
qiaolongfei 已提交
420 421
  }

Q
qiaolongfei 已提交
422
  void SetDim(const std::string& name, const DDim& dim) override {
423
    block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
Q
tmp  
qiaolongfei 已提交
424 425
  }

Q
qiaolongfei 已提交
426 427
  const OpDescBind& op_;
  const BlockDescBind& block_;
Q
tmp  
qiaolongfei 已提交
428 429
};

430
class RuntimeInferShapeContext : public InferShapeContext {
Q
Qiao Longfei 已提交
431 432 433 434
 public:
  RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
      : op_(op), scope_(scope) {}

Q
qiaolongfei 已提交
435
  bool HasInput(const std::string& name) const override {
436 437 438 439 440 441 442 443
    auto& ins = Inputs(name);
    size_t length = ins.size();
    if (length == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
                      name);
    auto ipt = ins[0];
Q
Qiao Longfei 已提交
444 445 446 447
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

Q
qiaolongfei 已提交
448
  bool HasOutput(const std::string& name) const override {
449 450 451 452 453 454 455 456
    auto& outs = Outputs(name);
    size_t length = outs.size();
    if (length == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
                      name);
    auto ipt = outs[0];
Q
Qiao Longfei 已提交
457 458 459 460
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

Q
qiaolongfei 已提交
461
  bool HasInputs(const std::string& name) const override {
462
    auto inputs = op_.Inputs(name);
Q
qiaolongfei 已提交
463
    if (inputs.empty()) {
464 465 466 467 468 469 470 471 472 473
      return false;
    }
    for (auto& input : inputs) {
      if (scope_.FindVar(input) == nullptr) {
        return false;
      }
    }
    return true;
  }

Q
qiaolongfei 已提交
474
  bool HasOutputs(const std::string& name) const override {
475
    auto outputs = op_.Outputs(name);
Q
qiaolongfei 已提交
476
    if (outputs.empty()) {
477 478 479 480 481 482 483 484 485 486
      return false;
    }
    for (auto& output : outputs) {
      if (scope_.FindVar(output) == nullptr) {
        return false;
      }
    }
    return true;
  }

Q
qiaolongfei 已提交
487
  DDim GetInputDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
488 489 490
    return GetDim(op_.Input(name));
  }

Q
qiaolongfei 已提交
491
  void SetInputDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
492 493 494
    SetDim(op_.Input(name), dim);
  }

Q
qiaolongfei 已提交
495
  DDim GetOutputDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
496 497 498
    return GetDim(op_.Output(name));
  }

Q
qiaolongfei 已提交
499
  void SetOutputDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
500 501 502
    SetDim(op_.Output(name), dim);
  }

Q
qiaolongfei 已提交
503
  AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
Q
Qiao Longfei 已提交
504

Q
qiaolongfei 已提交
505 506
  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
Q
Qiao Longfei 已提交
507 508 509
    return op_.Inputs(name);
  }

Q
qiaolongfei 已提交
510 511
  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
Q
Qiao Longfei 已提交
512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531
    return op_.Outputs(name);
  }

 private:
  template <bool Allocate>
  Tensor* GetTensor(const std::string& name) const {
    Tensor* t = nullptr;
    auto* var = scope_.FindVar(name);
    if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
      if (Allocate) {
        t = var->GetMutable<LoDTensor>();
      } else {
        PADDLE_THROW("Variable(%s) should be tensor", name);
      }
    } else {
      t = GetTensorFromVar(scope_.FindVar(name));
    }
    return t;
  }

Q
qiaolongfei 已提交
532
  DDim GetDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
533 534 535
    return GetTensor<false>(name)->dims();
  }

Q
qiaolongfei 已提交
536
  void SetDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
537 538 539 540 541 542 543
    GetTensor<true>(name)->Resize(dim);
  }

  const OperatorBase& op_;
  const Scope& scope_;
};

Y
Yu Yang 已提交
544
class OpKernelBase {
Q
qijun 已提交
545
 public:
Q
qijun 已提交
546
  /**
547
   * ExecutionContext is the only parameter of Kernel Run function.
Q
qijun 已提交
548 549
   * Run will get input/output variables, state such as momentum and
   * device resource such as CUDA stream, cublas handle, etc. from
550
   * ExecutionContext. User should construct it before run the Operator.
Q
qijun 已提交
551 552
   */

553
  virtual void Compute(const ExecutionContext& context) const = 0;
Y
Yu Yang 已提交
554

Y
Yu Yang 已提交
555 556 557 558 559 560 561
  virtual ~OpKernelBase() = default;
};

/// Kernel base parameterized by its element type, exposed as ELEMENT_TYPE.
template <typename T>
class OpKernel : public OpKernelBase {
 public:
  // Element type this kernel computes on.
  using ELEMENT_TYPE = T;
};

class OperatorWithKernel : public OperatorBase {
 public:
Y
Yu Yang 已提交
566 567
  struct OpKernelKey {
    platform::Place place_;
Y
Yu Yang 已提交
568
    DataType data_type_;
Q
Qiao Longfei 已提交
569

Y
Yu Yang 已提交
570 571 572 573 574
    OpKernelKey(DataType data_type, platform::Place place)
        : place_(place), data_type_(data_type) {}

    OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
        : place_(dev_ctx.GetPlace()), data_type_(data_type) {}
Y
Yu Yang 已提交
575

Q
qijun 已提交
576
    bool operator==(const OpKernelKey& o) const {
Y
Yu Yang 已提交
577 578
      return platform::places_are_same_class(place_, o.place_) &&
             data_type_ == o.data_type_;
Q
qijun 已提交
579
    }
Y
Yu Yang 已提交
580 581 582
  };

  struct OpKernelHash {
Y
Yu Yang 已提交
583
    std::hash<int> hash_;
Y
Yu Yang 已提交
584
    size_t operator()(const OpKernelKey& key) const {
Y
Yu Yang 已提交
585 586
      int place = key.place_.which();
      int data_type = static_cast<int>(key.data_type_);
Y
Yu Yang 已提交
587 588
      int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
                     (place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
Y
Yu Yang 已提交
589
      return hash_(pre_hash);
Y
Yu Yang 已提交
590 591 592 593
    }
  };

  using OpKernelMap =
Y
Yu Yang 已提交
594 595
      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
                         OpKernelHash>;
Q
Qiao Longfei 已提交
596

Y
Yu Yang 已提交
597 598
  OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
                     const VariableNameMap& outputs, const AttributeMap& attrs)
Y
Yu Yang 已提交
599 600
      : OperatorBase(type, inputs, outputs, attrs) {}

Y
Yu Yang 已提交
601
  void Run(const Scope& scope,
Y
Yu Yang 已提交
602
           const platform::DeviceContext& dev_ctx) const final {
Y
Yu Yang 已提交
603
    VLOG(3) << "Running operator " << this->Type();
Y
Yu Yang 已提交
604 605 606
    RuntimeInferShapeContext infer_shape_ctx(*this, scope);
    this->InferShape(&infer_shape_ctx);

Y
Yu Yang 已提交
607
    ExecutionContext ctx(*this, scope, dev_ctx);
608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626

    // check if op[type] has kernel registered.
    auto& all_op_kernels = AllOpKernels();
    auto kernels_iter = all_op_kernels.find(type_);
    if (kernels_iter == all_op_kernels.end()) {
      PADDLE_THROW("op[%s] has no kernel", type_);
    }

    // check if op[type] have kernel for kernel_key
    OpKernelMap& kernels = kernels_iter->second;
    auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
    auto kernel_iter = kernels.find(kernel_key);

    if (kernel_iter == kernels.end()) {
      PADDLE_THROW("op[%s] has no kernel with kernel_key[%s]", type_,
                   kernel_key);
    }

    kernel_iter->second->Compute(ctx);
Q
Qiao Longfei 已提交
627 628
  }

Y
Yu Yang 已提交
629 630 631 632
  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
Y
Yu Yang 已提交
633
  }
Y
Yan Chunwei 已提交
634

635
  bool SupportGPU() const override {
Y
Yu Yang 已提交
636 637 638 639 640
    auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
    return std::any_of(op_kernels.begin(), op_kernels.end(),
                       [](OpKernelMap::const_reference kern_pair) {
                         return platform::is_gpu_place(kern_pair.first.place_);
                       });
641 642
  }

643
  virtual void InferShape(InferShapeContext* ctx) const = 0;
Y
Yu Yang 已提交
644

Q
qiaolongfei 已提交
645
 protected:
Y
Yu Yang 已提交
646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662
  // indicate kernel DataType by input data. Defaultly all input data must be
  // same.
  virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
    auto& scope = ctx.scope();
    int data_type = -1;
    for (auto& input : this->inputs_) {
      for (auto& ipt_name : input.second) {
        auto* var = scope.FindVar(ipt_name);
        if (var != nullptr) {
          const Tensor* t = nullptr;
          if (var->IsType<Tensor>()) {
            t = &var->Get<Tensor>();
          } else if (var->IsType<LoDTensor>()) {
            t = &var->Get<LoDTensor>();
          }
          if (t != nullptr) {
            int tmp = static_cast<int>(ToDataType(t->type()));
Y
Yu Yang 已提交
663
            VLOG(3) << "Input " << ipt_name << " with data_type " << tmp;
Y
Yu Yang 已提交
664
            PADDLE_ENFORCE(tmp == data_type || data_type == -1,
Y
Yu Yang 已提交
665
                           "DataType of Paddle Op %s must be same.", Type());
Y
Yu Yang 已提交
666 667 668 669 670 671 672 673
            data_type = tmp;
          }
        }
      }
    }
    PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
    return static_cast<DataType>(data_type);
  }
Q
Qiao Longfei 已提交
674 675
};

676 677 678
/// Writes a textual form of `kernel_key` to `os`; defined out of line.
std::ostream& operator<<(std::ostream& os,
                         const OperatorWithKernel::OpKernelKey& kernel_key);

/// Whether operator type `op_type` supports GPU; defined out of line.
bool OpSupportGPU(const std::string& op_type);

}  // namespace framework
}  // namespace paddle