operator.h 21.9 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

D
dongzhihong 已提交
17
#include <algorithm>
18
#include <atomic>
Q
Qiao Longfei 已提交
19 20 21 22
#include <string>
#include <unordered_map>
#include <vector>

Y
Yu Yang 已提交
23
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
24
#include "paddle/framework/attribute.h"
Q
qiaolongfei 已提交
25
#include "paddle/framework/block_desc.h"
Y
Yu Yang 已提交
26
#include "paddle/framework/data_type.h"
Y
Yu Yang 已提交
27
#include "paddle/framework/framework.pb.h"
28
#include "paddle/framework/lod_tensor.h"
Y
Yu Yang 已提交
29
#include "paddle/framework/op_info.h"
Q
qijun 已提交
30
#include "paddle/framework/scope.h"
Q
QI JUN 已提交
31
#include "paddle/framework/selected_rows.h"
Q
Qiao Longfei 已提交
32
#include "paddle/framework/shape_inference.h"
Q
qijun 已提交
33 34 35
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
Y
Yu Yang 已提交
36
#include "paddle/platform/variant.h"
Q
qijun 已提交
37
#include "paddle/utils/Error.h"
Q
Qiao Longfei 已提交
38 39 40 41

namespace paddle {
namespace framework {

42
/// Name given to a variable slot that is intentionally left empty.
constexpr char kEmptyVarName[] = "@EMPTY@";

/// Placeholder name for a temporary variable. It is set in Python and is
/// converted to a unique name in the scope after OpCreator.
constexpr char kTempVarName[] = "@TEMP@";

/// A variable whose name carries this suffix is the gradient of another
/// variable, e.g. variable "x@GRAD" is the gradient of variable "x".
constexpr char kGradVarSuffix[] = "@GRAD";

/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";

/// Builds the gradient-variable name of `var_name`: the name itself
/// followed by kGradVarSuffix.
inline std::string GradVarName(const std::string& var_name) {
  std::string grad_name;
  // sizeof includes the trailing '\0' of the suffix literal.
  grad_name.reserve(var_name.size() + sizeof(kGradVarSuffix) - 1);
  grad_name.append(var_name).append(kGradVarSuffix);
  return grad_name;
}

class OperatorBase;
class ExecutionContext;
63

Q
Qiao Longfei 已提交
64 65 66 67 68 69 70 71
/**
 * OperatorBase has the basic element that Net will call to do computation.
 * Only CreateOperator from OpRegistry will new Operator directly. User
 * should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
Y
Yu Yang 已提交
72 73
  OperatorBase(const std::string& type, const VariableNameMap& inputs,
               const VariableNameMap& outputs, const AttributeMap& attrs);
74

Q
Qiao Longfei 已提交
75 76 77
  virtual ~OperatorBase() {}

  template <typename T>
Y
Yu Yang 已提交
78
  inline const T& Attr(const std::string& name) const {
Q
Qiao Longfei 已提交
79 80 81 82 83
    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
                   name);
    return boost::get<T>(attrs_.at(name));
  }

84
  virtual std::string DebugString() const;
Q
Qiao Longfei 已提交
85 86

  /// Net will call this function to Run an op.
Y
Yu Yang 已提交
87
  virtual void Run(const Scope& scope,
Y
Yu Yang 已提交
88 89
                   const platform::DeviceContext& dev_ctx) const = 0;

Y
Yu Yang 已提交
90 91
  virtual bool IsNetOp() const { return false; }

92 93
  virtual bool SupportGPU() const { return false; }

D
dongzhihong 已提交
94 95 96
  /// rename inputs outputs name
  void Rename(const std::string& old_name, const std::string& new_name);

Y
Yu Yang 已提交
97 98
  const VariableNameMap& Inputs() const { return inputs_; }
  const VariableNameMap& Outputs() const { return outputs_; }
99

Y
Yu Yang 已提交
100
  //! Get a input with argument's name described in `op_proto`
101
  std::string Input(const std::string& name) const;
Y
Yu Yang 已提交
102
  //! Get a input which has multiple variables.
Y
Yu Yang 已提交
103
  const std::vector<std::string>& Inputs(const std::string& name) const;
Y
Yi Wang 已提交
104

Q
qijun 已提交
105 106
  std::vector<std::string> InputVars() const;

Y
Yu Yang 已提交
107
  //! Get a output with argument's name described in `op_proto`
108
  std::string Output(const std::string& name) const;
Y
Yu Yang 已提交
109 110
  //! Get an output which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yu Yang 已提交
111
  const std::vector<std::string>& Outputs(const std::string& name) const;
Y
Yan Chunwei 已提交
112

Y
Yu Yang 已提交
113
  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
114

Q
qiaolongfei 已提交
115
  const std::string& Type() const { return type_; }
Q
qiaolongfei 已提交
116
  void SetType(const std::string& type) { type_ = type; }
Y
Yi Wang 已提交
117 118
  const AttributeMap& Attrs() const { return attrs_; }

Y
Yu Yang 已提交
119
  // Return a new operator instance, which is as same as this.
Y
Yu Yang 已提交
120 121
  // Use unique_ptr to prevent caller forget to delete this pointer.
  virtual std::unique_ptr<OperatorBase> Clone() const = 0;
Y
Yu Yang 已提交
122

Q
qiaolongfei 已提交
123
 protected:
Q
Qiao Longfei 已提交
124
  std::string type_;
D
dongzhihong 已提交
125
  // NOTE: in case of OpGrad, inputs_ contains:
Y
Yu Yang 已提交
126
  // I (Inputs)opear
D
dongzhihong 已提交
127 128
  // O (Outputs)
  // OG (Output Gradients)
Y
Yu Yang 已提交
129
  VariableNameMap inputs_;
Y
Yu Yang 已提交
130

D
dongzhihong 已提交
131 132
  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
Y
Yu Yang 已提交
133
  VariableNameMap outputs_;
Q
Qiao Longfei 已提交
134
  AttributeMap attrs_;
135 136 137 138

 private:
  void GenerateTemporaryNames();
  void CheckAllInputOutputSet() const;
Y
Yan Chunwei 已提交
139 140
};

Y
Yu Yang 已提交
141 142
// Defines the `Clone` method for operator class `cls`: it copy-constructs a
// new `cls` from *this and hands it back as a unique_ptr to OperatorBase.
// Kernel operators get `Clone` when they are registered, so this macro is
// only needed for operators that do not get it that way.
#define DEFINE_OP_CLONE_METHOD(cls)                                         \
  std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final { \
    ::paddle::framework::OperatorBase* cloned = new cls(*this);             \
    return std::unique_ptr<::paddle::framework::OperatorBase>(cloned);      \
  }
Y
Yu Yang 已提交
148

Y
Yu Yang 已提交
149 150 151 152
// Defines a constructor for operator class `cls` that forwards its four
// arguments (type, inputs, outputs, attrs) to `parent_cls`.
// You can also use
//   using PARENT_CLASS::PARENT_CLASS;
// to inherit the parent's constructor instead.
// All framework types are fully qualified (leading `::`) so the expansion does
// not depend on which namespaces are visible where the macro is used.
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls)             \
  cls(const std::string& type,                             \
      const ::paddle::framework::VariableNameMap& inputs,  \
      const ::paddle::framework::VariableNameMap& outputs, \
      const ::paddle::framework::AttributeMap& attrs)      \
      : parent_cls(type, inputs, outputs, attrs) {}
Y
Yu Yang 已提交
159

160 161
class NOP : public OperatorBase {
 public:
162
  using OperatorBase::OperatorBase;
163 164
  void Run(const Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {}
165 166 167
  std::unique_ptr<OperatorBase> Clone() const override {
    return std::unique_ptr<OperatorBase>(new NOP(*this));
  }
168 169
};

170
class ExecutionContext {
Y
Yan Chunwei 已提交
171
 public:
172 173 174
  ExecutionContext(const OperatorBase& op, const Scope& scope,
                   const platform::DeviceContext& device_context)
      : op_(op), scope_(scope), device_context_(device_context) {}
175

Q
qiaolongfei 已提交
176 177 178 179
  const OperatorBase& op() const { return op_; }

  const Scope& scope() const { return scope_; }

Q
qiaolongfei 已提交
180
  template <typename T>
Y
Yu Yang 已提交
181 182
  inline const T& Attr(const std::string& name) const {
    return op_.Attr<T>(name);
Q
qiaolongfei 已提交
183 184
  }

Y
Yu Yang 已提交
185
  size_t InputSize(const std::string& name) const {
Y
Yu Yang 已提交
186
    return op_.Inputs(name).size();
Y
Yan Chunwei 已提交
187 188
  }

Y
Yu Yang 已提交
189
  size_t OutputSize(const std::string& name) const {
Y
Yu Yang 已提交
190
    return op_.Outputs(name).size();
Y
Yan Chunwei 已提交
191 192
  }

193
  const Variable* InputVar(const std::string& name) const {
194
    auto ipt = op_.Input(name);
Y
Yu Yang 已提交
195
    return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
Y
Yan Chunwei 已提交
196 197
  }

198
  Variable* OutputVar(const std::string& name) const {
199
    auto opt = op_.Output(name);
Y
Yu Yang 已提交
200
    return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
Y
Yan Chunwei 已提交
201 202
  }

203 204
  const std::vector<const Variable*> MultiInputVar(
      const std::string& name) const {
Y
Yan Chunwei 已提交
205 206
    auto names = op_.Inputs(name);
    std::vector<const Variable*> res;
207
    res.reserve(names.size());
208 209
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
210 211
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
212
                   });
Y
Yan Chunwei 已提交
213 214 215
    return res;
  }

216
  std::vector<Variable*> MultiOutputVar(const std::string& name) const {
Y
Yan Chunwei 已提交
217
    auto names = op_.Outputs(name);
218
    std::vector<Variable*> res;
219
    res.reserve(names.size());
220 221
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [this](const std::string& name) {
Y
Yu Yang 已提交
222 223
                     return name == kEmptyVarName ? nullptr
                                                  : scope_.FindVar(name);
224
                   });
Y
Yan Chunwei 已提交
225 226 227
    return res;
  }

228 229
  template <typename T>
  const T* Input(const std::string& name) const {
Y
Yu Yang 已提交
230
    auto* var = InputVar(name);
231
    return var == nullptr ? nullptr : &var->Get<T>();
232 233 234 235
  }

  template <typename T>
  T* Output(const std::string& name) const {
236
    auto var = OutputVar(name);
237
    return var == nullptr ? nullptr : var->GetMutable<T>();
238 239 240 241 242 243 244 245
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
246
                   [&](const std::string& sub_name) {
247
                     auto var = scope_.FindVar(sub_name);
248
                     return var == nullptr ? nullptr : &var->Get<T>();
249 250 251 252 253
                   });
    return res;
  }

  template <typename T>
254
  std::vector<T*> MultiOutput(const std::string& name) const {
255
    auto names = op_.Outputs(name);
256
    std::vector<T*> res;
257 258
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
259
                   [&](const std::string& sub_name) {
260
                     auto var = scope_.FindVar(sub_name);
261
                     return var == nullptr ? nullptr : var->GetMutable<T>();
262 263 264 265
                   });
    return res;
  }

266 267 268 269 270 271
  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const {
    PADDLE_ENFORCE_LT(i, InputSize(in));
    PADDLE_ENFORCE_LT(j, OutputSize(out));
    auto* in_var = MultiInputVar(in)[i];
    auto* out_var = MultiOutputVar(out)[j];
272
    if (!in_var->IsType<LoDTensor>()) return;
273
    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
274
                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
275 276 277
    auto in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_lod(in_tensor.lod());
278 279
  }

Q
qijun 已提交
280
  template <typename PlaceType,
281 282
            typename DeviceType = typename platform::EigenDeviceConverter<
                PlaceType>::EigenDeviceType>
283
  DeviceType& GetEigenDevice() const;
Q
qijun 已提交
284

285
  platform::Place GetPlace() const { return device_context_.GetPlace(); }
Q
qijun 已提交
286

287
  const platform::DeviceContext& device_context() const {
Q
qijun 已提交
288
    return device_context_;
Q
qijun 已提交
289
  }
Q
qijun 已提交
290

武毅 已提交
291 292 293 294 295 296 297 298 299
#ifdef PADDLE_WITH_CUDA
  const platform::CUDADeviceContext& cuda_device_context() const {
    PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
    auto cuda_ctx =
        reinterpret_cast<const platform::CUDADeviceContext*>(&device_context_);
    return *cuda_ctx;
  }
#endif

300
 private:
301 302
  const OperatorBase& op_;
  const Scope& scope_;
303
  const platform::DeviceContext& device_context_;
Q
Qiao Longfei 已提交
304 305
};

306 307 308 309 310 311 312 313 314 315 316 317 318 319
// Explicit specializations of ExecutionContext's typed accessors for Tensor.
// They are only declared in this header; the definitions are out of line.
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;

template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;

template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
    const std::string& name) const;

template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
    const std::string& name) const;

320
class CompileTimeInferShapeContext : public InferShapeContext {
Q
tmp  
qiaolongfei 已提交
321
 public:
Q
qiaolongfei 已提交
322 323
  CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
      : op_(op), block_(block) {}
Q
tmp  
qiaolongfei 已提交
324

Q
qiaolongfei 已提交
325
  bool HasInput(const std::string& name) const override {
Q
qiaolongfei 已提交
326
    const std::vector<std::string>& input_names = op_.Input(name);
327
    auto length = input_names.size();
328 329 330
    if (length == 0) {
      return false;
    }
331 332 333 334
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input(%s) should have only one value, "
                      "but it have %d now",
                      name, length);
335
    return block_.HasVarRecursive(input_names[0]);
Q
tmp  
qiaolongfei 已提交
336 337
  }

Q
qiaolongfei 已提交
338
  bool HasOutput(const std::string& name) const override {
Q
qiaolongfei 已提交
339
    const std::vector<std::string>& output_names = op_.Output(name);
340
    auto length = output_names.size();
341 342 343
    if (length == 0) {
      return false;
    }
344 345 346 347
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output(%s) should have only one value, "
                      "but it have %d now",
                      name, length);
348
    return block_.HasVarRecursive(output_names[0]);
Q
tmp  
qiaolongfei 已提交
349 350
  }

Q
qiaolongfei 已提交
351
  bool HasInputs(const std::string& name) const override {
Q
qiaolongfei 已提交
352
    const std::vector<std::string>& input_names = op_.Input(name);
353 354 355
    if (input_names.empty()) {
      return false;
    }
Q
qiaolongfei 已提交
356
    for (auto& input : input_names) {
357
      if (!block_.HasVarRecursive(input)) return false;
Q
tmp  
qiaolongfei 已提交
358 359 360 361
    }
    return true;
  }

Q
qiaolongfei 已提交
362
  bool HasOutputs(const std::string& name) const override {
Q
qiaolongfei 已提交
363
    const std::vector<std::string>& output_names = op_.Output(name);
364 365 366
    if (output_names.empty()) {
      return false;
    }
Q
qiaolongfei 已提交
367
    for (auto& output : output_names) {
368
      if (!block_.HasVarRecursive(output)) return false;
Q
tmp  
qiaolongfei 已提交
369 370 371 372
    }
    return true;
  }

Q
qiaolongfei 已提交
373
  DDim GetInputDim(const std::string& name) const override {
Q
qiaolongfei 已提交
374
    std::vector<DDim> ddims = GetInputsDim(name);
375 376 377 378 379
    auto length = ddims.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input(%s) should have 1 value, "
                      "but it has %d now",
                      name, length);
Q
qiaolongfei 已提交
380
    return ddims[0];
Q
tmp  
qiaolongfei 已提交
381 382
  }

Q
qiaolongfei 已提交
383
  void SetInputDim(const std::string& name, const DDim& dim) override {
Q
qiaolongfei 已提交
384
    SetInputsDim(name, {dim});
Q
tmp  
qiaolongfei 已提交
385 386
  }

Q
qiaolongfei 已提交
387
  DDim GetOutputDim(const std::string& name) const override {
Q
qiaolongfei 已提交
388
    std::vector<DDim> ddims = GetOutputsDim(name);
389 390 391 392 393
    auto length = ddims.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output(%s) should have 1 value, "
                      "but it has %d now",
                      name, length);
Q
qiaolongfei 已提交
394
    return ddims[0];
Q
tmp  
qiaolongfei 已提交
395 396
  }

Q
qiaolongfei 已提交
397
  void SetOutputDim(const std::string& name, const DDim& dim) override {
Q
qiaolongfei 已提交
398
    SetOutputsDim(name, {dim});
Q
tmp  
qiaolongfei 已提交
399 400
  }

Q
qiaolongfei 已提交
401
  AttrReader Attrs() const override { return AttrReader(op_.GetAttrMap()); }
Q
tmp  
qiaolongfei 已提交
402

Q
qiaolongfei 已提交
403 404
  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
Q
qiaolongfei 已提交
405
    return op_.Input(name);
Q
tmp  
qiaolongfei 已提交
406 407
  }

Q
qiaolongfei 已提交
408 409
  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
Q
qiaolongfei 已提交
410
    return op_.Output(name);
Q
tmp  
qiaolongfei 已提交
411 412 413
  }

 private:
Q
qiaolongfei 已提交
414
  DDim GetDim(const std::string& name) const override {
Y
Yu Yang 已提交
415 416 417
    auto var = block_.FindVarRecursive(name);
    PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
    return framework::make_ddim(var->Shape());
Q
tmp  
qiaolongfei 已提交
418 419
  }

Q
qiaolongfei 已提交
420
  void SetDim(const std::string& name, const DDim& dim) override {
421
    block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
Q
tmp  
qiaolongfei 已提交
422 423
  }

Q
qiaolongfei 已提交
424 425
  const OpDescBind& op_;
  const BlockDescBind& block_;
Q
tmp  
qiaolongfei 已提交
426 427
};

428
class RuntimeInferShapeContext : public InferShapeContext {
Q
Qiao Longfei 已提交
429 430 431 432
 public:
  RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
      : op_(op), scope_(scope) {}

Q
qiaolongfei 已提交
433
  bool HasInput(const std::string& name) const override {
434 435 436 437 438 439 440 441
    auto& ins = Inputs(name);
    size_t length = ins.size();
    if (length == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
                      name);
    auto ipt = ins[0];
Q
Qiao Longfei 已提交
442 443 444 445
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

Q
qiaolongfei 已提交
446
  bool HasOutput(const std::string& name) const override {
447 448 449 450 451 452 453 454
    auto& outs = Outputs(name);
    size_t length = outs.size();
    if (length == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
                      name);
    auto ipt = outs[0];
Q
Qiao Longfei 已提交
455 456 457 458
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

Q
qiaolongfei 已提交
459
  bool HasInputs(const std::string& name) const override {
460
    auto inputs = op_.Inputs(name);
Q
qiaolongfei 已提交
461
    if (inputs.empty()) {
462 463 464 465 466 467 468 469 470 471
      return false;
    }
    for (auto& input : inputs) {
      if (scope_.FindVar(input) == nullptr) {
        return false;
      }
    }
    return true;
  }

Q
qiaolongfei 已提交
472
  bool HasOutputs(const std::string& name) const override {
473
    auto outputs = op_.Outputs(name);
Q
qiaolongfei 已提交
474
    if (outputs.empty()) {
475 476 477 478 479 480 481 482 483 484
      return false;
    }
    for (auto& output : outputs) {
      if (scope_.FindVar(output) == nullptr) {
        return false;
      }
    }
    return true;
  }

Q
qiaolongfei 已提交
485
  DDim GetInputDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
486 487 488
    return GetDim(op_.Input(name));
  }

Q
qiaolongfei 已提交
489
  void SetInputDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
490 491 492
    SetDim(op_.Input(name), dim);
  }

Q
qiaolongfei 已提交
493
  DDim GetOutputDim(const std::string& name) const override {
Q
Qiao Longfei 已提交
494 495 496
    return GetDim(op_.Output(name));
  }

Q
qiaolongfei 已提交
497
  void SetOutputDim(const std::string& name, const DDim& dim) override {
Q
Qiao Longfei 已提交
498 499 500
    SetDim(op_.Output(name), dim);
  }

Q
qiaolongfei 已提交
501
  AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
Q
Qiao Longfei 已提交
502

Q
qiaolongfei 已提交
503 504
  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
Q
Qiao Longfei 已提交
505 506 507
    return op_.Inputs(name);
  }

Q
qiaolongfei 已提交
508 509
  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
Q
Qiao Longfei 已提交
510 511 512 513
    return op_.Outputs(name);
  }

 private:
Q
QI JUN 已提交
514 515 516 517 518 519
  DDim GetDim(const std::string& name) const override {
    Variable* var = scope_.FindVar(name);
    if (var->IsType<LoDTensor>()) {
      return var->Get<LoDTensor>().dims();
    } else if (var->IsType<SelectedRows>()) {
      return var->Get<SelectedRows>().GetCompleteDims();
Q
Qiao Longfei 已提交
520
    } else {
Q
QI JUN 已提交
521
      PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
Q
Qiao Longfei 已提交
522 523 524
    }
  }

Q
qiaolongfei 已提交
525
  void SetDim(const std::string& name, const DDim& dim) override {
Q
QI JUN 已提交
526 527 528 529 530 531 532 533
    Variable* var = scope_.FindVar(name);
    if (var->IsType<LoDTensor>()) {
      var->GetMutable<LoDTensor>()->Resize(dim);
    } else if (var->IsType<SelectedRows>()) {
      var->GetMutable<SelectedRows>()->set_height(dim[0]);
    } else {
      PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
    }
Q
Qiao Longfei 已提交
534 535 536 537 538 539
  }

  const OperatorBase& op_;
  const Scope& scope_;
};

Y
Yu Yang 已提交
540
class OpKernelBase {
Q
qijun 已提交
541
 public:
Q
qijun 已提交
542
  /**
543
   * ExecutionContext is the only parameter of Kernel Run function.
Q
qijun 已提交
544 545
   * Run will get input/output variables, state such as momentum and
   * device resource such as CUDA stream, cublas handle, etc. from
546
   * ExecutionContext. User should construct it before run the Operator.
Q
qijun 已提交
547 548
   */

549
  virtual void Compute(const ExecutionContext& context) const = 0;
Y
Yu Yang 已提交
550

Y
Yu Yang 已提交
551 552 553 554 555 556 557
  virtual ~OpKernelBase() = default;
};

/// Kernel base class that records its element type T as ELEMENT_TYPE.
template <typename T>
class OpKernel : public OpKernelBase {
 public:
  typedef T ELEMENT_TYPE;
};

Q
Qiao Longfei 已提交
560 561
class OperatorWithKernel : public OperatorBase {
 public:
Y
Yu Yang 已提交
562 563
  struct OpKernelKey {
    platform::Place place_;
Y
Yu Yang 已提交
564
    DataType data_type_;
Q
Qiao Longfei 已提交
565

Y
Yu Yang 已提交
566 567 568 569 570
    OpKernelKey(DataType data_type, platform::Place place)
        : place_(place), data_type_(data_type) {}

    OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
        : place_(dev_ctx.GetPlace()), data_type_(data_type) {}
Y
Yu Yang 已提交
571

Q
qijun 已提交
572
    bool operator==(const OpKernelKey& o) const {
Y
Yu Yang 已提交
573 574
      return platform::places_are_same_class(place_, o.place_) &&
             data_type_ == o.data_type_;
Q
qijun 已提交
575
    }
Y
Yu Yang 已提交
576 577 578
  };

  struct OpKernelHash {
Y
Yu Yang 已提交
579
    std::hash<int> hash_;
Y
Yu Yang 已提交
580
    size_t operator()(const OpKernelKey& key) const {
Y
Yu Yang 已提交
581 582
      int place = key.place_.which();
      int data_type = static_cast<int>(key.data_type_);
Y
Yu Yang 已提交
583 584
      int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
                     (place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
Y
Yu Yang 已提交
585
      return hash_(pre_hash);
Y
Yu Yang 已提交
586 587 588 589
    }
  };

  using OpKernelMap =
Y
Yu Yang 已提交
590 591
      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
                         OpKernelHash>;
Q
Qiao Longfei 已提交
592

Y
Yu Yang 已提交
593 594
  OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
                     const VariableNameMap& outputs, const AttributeMap& attrs)
Y
Yu Yang 已提交
595 596
      : OperatorBase(type, inputs, outputs, attrs) {}

Y
Yu Yang 已提交
597
  void Run(const Scope& scope,
Y
Yu Yang 已提交
598
           const platform::DeviceContext& dev_ctx) const final {
Y
Yu Yang 已提交
599
    VLOG(3) << "Running operator " << this->Type();
Y
Yu Yang 已提交
600 601 602
    RuntimeInferShapeContext infer_shape_ctx(*this, scope);
    this->InferShape(&infer_shape_ctx);

Y
Yu Yang 已提交
603
    ExecutionContext ctx(*this, scope, dev_ctx);
604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622

    // check if op[type] has kernel registered.
    auto& all_op_kernels = AllOpKernels();
    auto kernels_iter = all_op_kernels.find(type_);
    if (kernels_iter == all_op_kernels.end()) {
      PADDLE_THROW("op[%s] has no kernel", type_);
    }

    // check if op[type] have kernel for kernel_key
    OpKernelMap& kernels = kernels_iter->second;
    auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
    auto kernel_iter = kernels.find(kernel_key);

    if (kernel_iter == kernels.end()) {
      PADDLE_THROW("op[%s] has no kernel with kernel_key[%s]", type_,
                   kernel_key);
    }

    kernel_iter->second->Compute(ctx);
Q
Qiao Longfei 已提交
623 624
  }

Y
Yu Yang 已提交
625 626 627 628
  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
Y
Yu Yang 已提交
629
  }
Y
Yan Chunwei 已提交
630

631
  bool SupportGPU() const override {
Y
Yu Yang 已提交
632 633 634 635 636
    auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
    return std::any_of(op_kernels.begin(), op_kernels.end(),
                       [](OpKernelMap::const_reference kern_pair) {
                         return platform::is_gpu_place(kern_pair.first.place_);
                       });
637 638
  }

639 640 641
  virtual void InferShape(InferShapeContext* ctx) const {
    OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
  }
Y
Yu Yang 已提交
642

Q
qiaolongfei 已提交
643
 protected:
Y
Yu Yang 已提交
644 645 646 647 648 649 650 651 652 653 654 655 656 657
  // indicate kernel DataType by input data. Defaultly all input data must be
  // same.
  virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
    auto& scope = ctx.scope();
    int data_type = -1;
    for (auto& input : this->inputs_) {
      for (auto& ipt_name : input.second) {
        auto* var = scope.FindVar(ipt_name);
        if (var != nullptr) {
          const Tensor* t = nullptr;
          if (var->IsType<Tensor>()) {
            t = &var->Get<Tensor>();
          } else if (var->IsType<LoDTensor>()) {
            t = &var->Get<LoDTensor>();
Q
QI JUN 已提交
658 659
          } else if (var->IsType<SelectedRows>()) {
            t = &(var->Get<SelectedRows>().value());
Y
Yu Yang 已提交
660 661 662
          }
          if (t != nullptr) {
            int tmp = static_cast<int>(ToDataType(t->type()));
Y
Yu Yang 已提交
663
            VLOG(3) << "Input " << ipt_name << " with data_type " << tmp;
Y
Yu Yang 已提交
664
            PADDLE_ENFORCE(tmp == data_type || data_type == -1,
Y
Yu Yang 已提交
665
                           "DataType of Paddle Op %s must be same.", Type());
Y
Yu Yang 已提交
666 667 668 669 670 671 672 673
            data_type = tmp;
          }
        }
      }
    }
    PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
    return static_cast<DataType>(data_type);
  }
Q
Qiao Longfei 已提交
674 675
};

676 677 678
/// Stream-insertion operator for OpKernelKey (defined out of line).
std::ostream& operator<<(std::ostream& os,
                         const OperatorWithKernel::OpKernelKey& kernel_key);

/// Presumably reports whether operator type `op_type` supports GPU
/// (defined out of line). Namespace-scope function declarations already have
/// external linkage, so `extern` is not needed.
bool OpSupportGPU(const std::string& op_type);

Q
Qiao Longfei 已提交
681 682
}  // namespace framework
}  // namespace paddle