/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <atomic>
#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "glog/logging.h"  // For VLOG
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"

#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/utils/flat_hash_map.h"

namespace paddle {
namespace framework {
class OpInfo;
class Scope;
class Variable;
}  // namespace framework
}  // namespace paddle

namespace phi {
class KernelContext;
}

DECLARE_int32(inner_op_parallelism);

namespace paddle {
namespace framework {

/// If a variable is an empty variable, that name will be used.
constexpr char kEmptyVarName[] = "@EMPTY@";

/// If a variable is a temporary variable, its name will be set in Python,
/// but it will be converted to a unique name in scope after OpCreator.
constexpr char kTempVarName[] = "@TEMP@";

/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another variable.
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
constexpr char kGradVarSuffix[] = "@GRAD";

constexpr size_t kGradVarSuffixSize = 5U;

/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";

/// Variables with this suffix are the new Gradient.
constexpr char kNewGradSuffix[] = "@NEWGRAD@";

/// RuntimeContext is used to relate input/output names of Operator with
/// the corresponding variables in name scope.
/// If an Op has the attribute kEnableCacheRuntimeContext, its input/output
/// names do not change during execution within the same scope, so the
/// RuntimeContext can be created only at the first iteration of this Op's
/// execution to save the elapsed time.
constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@";

/// If an Op has this attribute, all its kernels should calculate the output
/// variable's shape in the corresponding Compute() function, and
/// OperatorWithKernel::RunImpl() would skip calling this Op's InferShape()
/// function at runtime for speedup.
/// TODO(luotao): Note that this temporary attribute would be deleted after
/// all ops contain it.
constexpr char kAllKernelsMustComputeRuntimeShape[] =
    "@ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@";

// define some kernel priority
/* Define multiple kernel type fallback order */
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;

inline std::string GradVarName(const std::string& var_name) {
  std::string result;
  result.reserve(var_name.size() + kGradVarSuffixSize);
  result += var_name;
  result += kGradVarSuffix;
  return result;
}

inline std::string GradOriginalVarName(const std::string& grad_var_name) {
  std::size_t pos = grad_var_name.rfind(kGradVarSuffix);
  if (pos == std::string::npos) {
    return grad_var_name;
  } else {
    return grad_var_name.substr(0, pos);
  }
}
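
// Illustrative note (not part of the framework API): the two helpers above are
// inverses with respect to the kGradVarSuffix ("@GRAD") naming convention, e.g.
//   GradVarName("x")              == "x@GRAD"
//   GradOriginalVarName("x@GRAD") == "x"
//   GradOriginalVarName("x")      == "x"   // no suffix, returned unchanged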

inline bool VarIsTensor(const Variable& var) {
  return var.IsType<phi::DenseTensor>() || var.IsType<phi::SelectedRows>();
}

const phi::DenseTensor* GetLoDTensorOrSelectedRowsValueFromVar(
    const Variable& var);
phi::DenseTensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);

class ExecutionContext;
class OperatorBase;

class RuntimeContext {
 public:
  RuntimeContext(const VariableNameMap& innames,
                 const VariableNameMap& outnames,
                 const Scope& scope);

  RuntimeContext(const VariableValueMap& invars,
                 const VariableValueMap& outvars)
      : inputs(invars), outputs(outvars) {}

  VariableValueMap inputs;
  VariableValueMap outputs;
};

class RuntimeInferShapeContext : public InferShapeContext {
 public:
  RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx);

  bool HasInput(const std::string& name) const override;

  bool HasOutput(const std::string& name) const override;

  bool HasAttr(const std::string& name) const override;

  bool HasInputs(const std::string& name) const override;

  bool HasOutputs(const std::string& name,
                  bool allow_null = false) const override;

  AttrReader Attrs() const override;

  std::vector<std::string> Inputs(const std::string& name) const override;

  std::vector<std::string> Outputs(const std::string& name) const override;

  std::string GetInputNameByIdx(size_t idx) const override;

  std::string GetOutputNameByIdx(size_t idx) const override;

  void ShareDim(const std::string& in,
                const std::string& out,
                size_t i = 0,
                size_t j = 0) override;

  void ShareAllLoD(const std::string& in,
                   const std::string& out) const override;

  void ShareLoD(const std::string& in,
                const std::string& out,
                size_t i = 0,
                size_t j = 0) const override;

  int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override;

  void SetLoDLevel(const std::string& out,
                   int32_t lod_level,
                   size_t j = 0) const override;

  bool IsRuntime() const override;

  bool IsRunMKLDNNKernel() const override;

  // TODO(paddle-dev): Can this be template?
  paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
  GetInputVarPtrs(const std::string& name) const override;

  paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
  GetOutputVarPtrs(const std::string& name) const override;

  DDim GetInputDim(const std::string& name) const override;

  std::vector<DDim> GetInputsDim(const std::string& name) const override;

  proto::VarType::Type GetInputVarType(const std::string& name) const override;

  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string& name) const override;

  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string& name) const override;

  void SetOutputDim(const std::string& name, const DDim& dim) override;

  void SetOutputsDim(const std::string& name,
                     const std::vector<DDim>& dims) override;

  const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override;

  const phi::KernelSignature* GetPhiDefaultKernelSignature() const override;

  void SetSkipLoD(bool skip);

 protected:
  DDim GetDim(Variable* var) const;

  std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const;

  std::vector<DDim> GetRepeatedDims(const std::string& name) const override;

  void SetDim(Variable* var, const DDim& dim);

  void SetDims(const std::vector<Variable*>& vars,
               const std::vector<DDim>& dims);

  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override;

  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<Variable*>& vars) const;

  proto::VarType::Type GetVarType(Variable* var) const;

 private:
  const std::vector<Variable*>& InputVars(const std::string& name) const;

  const std::vector<Variable*>& OutputVars(const std::string& name) const;

  const OperatorBase& op_;
  const RuntimeContext& ctx_;
  bool can_skip_lod_{false};
};

/**
 * OperatorBase has the basic elements that Net will call to do computation.
 * Only CreateOperator from OpRegistry will create an Operator directly. The
 * user should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
  OperatorBase(const std::string& type,
               const VariableNameMap& inputs,
               const VariableNameMap& outputs,
               const AttributeMap& attrs);

  virtual ~OperatorBase() {}

  /// Executor will call this interface function to Run an op.
  //  The implementation should be written in RunImpl.
  void Run(const Scope& scope, const platform::Place& place);

  // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
  virtual void Stop() {}

  /// if scope is not null, also show dimensions of arguments
  virtual std::string DebugStringEx(const Scope* scope) const;
  std::string DebugString() const { return DebugStringEx(nullptr); }

  virtual bool SupportGPU() const { return false; }
  virtual bool SupportNPU() const { return false; }
  virtual bool SupportMLU() const { return false; }
  virtual bool SupportXPU() const { return false; }

  const std::string& Type() const { return type_; }

  bool HasAttr(const std::string& name) const {
    return attrs_.count(name) || runtime_attrs_.count(name);
  }
  template <typename T>
  inline const T& Attr(const std::string& name) const {
    auto it = attrs_.find(name);
    if (it == attrs_.end()) {
      it = runtime_attrs_.find(name);
      PADDLE_ENFORCE_NE(
          it,
          runtime_attrs_.end(),
          platform::errors::NotFound(
              "(%s) is not found in AttributeMap and RuntimeAttributeMap.",
              name));
    }
    return PADDLE_GET_CONST(T, it->second);
  }
  void SetAttr(const std::string& name, const Attribute& v) {
    PADDLE_ENFORCE_EQ(
        HasAttr(name),
        true,
        platform::errors::NotFound(
            "The attribute %s is not found in operator %s", name, Type()));

    attrs_[name] = v;
  }
  const AttributeMap& Attrs() const { return attrs_; }
  const AttributeMap& RuntimeAttrs() const { return runtime_attrs_; }
  void SetRuntimeAttributeMap(const AttributeMap& runtime_attrs) {
    runtime_attrs_ = runtime_attrs;
  }

  const VariableNameMap& Inputs() const { return inputs_; }
  const VariableNameMap& Outputs() const { return outputs_; }
  VariableNameMap& Inputs() { return inputs_; }
  VariableNameMap& Outputs() { return outputs_; }

  const OpInfo& Info() const {
    PADDLE_ENFORCE_NOT_NULL(
        info_,
        platform::errors::NotFound("OpInfo of operator (%s) is not found.",
                                   type_));
    return *info_;
  }

  bool HasInputs(const std::string& name) const;
  //! Get an input with argument's name described in `op_proto`
  std::string Input(const std::string& name) const;
  //! Get an input which has multiple variables.
  const std::vector<std::string>& Inputs(const std::string& name) const;
  //! Get all input variable names
  std::vector<std::string> InputVars() const;

  bool HasOutputs(const std::string& name) const;
  //! Get an output with argument's name described in `op_proto`
  std::string Output(const std::string& name) const;
  //! Get an output which has multiple variables.
  //! TODO: add a vector_view to prevent memory copy.
  const std::vector<std::string>& Outputs(const std::string& name) const;
  //! Get all output variable names
  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;

  void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }

  virtual void RuntimeInferShape(const Scope& scope,
                                 const platform::Place& place,
                                 const RuntimeContext& ctx) const {}

  virtual platform::Place GetExecutionPlace(
      const platform::Place& place) const {
    return place;
  }

  uint64_t Id() const { return id_; }

  void SetId(uint64_t id) { id_ = id; }

 protected:
  std::string type_;
  // NOTE: in case of OpGrad, inputs_ contains:
  // I (Inputs)
  // O (Outputs)
  // OG (Output Gradients)
  VariableNameMap inputs_;

  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
  VariableNameMap outputs_;
  AttributeMap attrs_;
  // NOTE: runtime_attrs_ contains the attributes which are used for
  // dispatching the kernel (use_mkldnn, use_cudnn, ...) or for passing
  // additional configuration to special heterogeneous kernels
  // (workspace_size_MB, ...).
  // The attributes in runtime_attrs_ are set by the framework (such as PASS),
  // not through the Python API.
  AttributeMap runtime_attrs_;

  // OpInfo
  const OpInfo* info_;

  // OpDesc Id
  uint64_t id_ = UINT64_MAX;

  // Whether this operator executes in an Executor.
  bool run_by_executor_{true};

 private:
  void GenerateTemporaryNames();
  void CheckAllInputOutputSet() const;
  virtual void RunImpl(const Scope& scope,
                       const platform::Place& place) const = 0;
};

class ExecutionContext : public phi::KernelContext {
 public:
  ExecutionContext(const OperatorBase& op,
                   const Scope& scope,
                   const platform::DeviceContext& device_context,
                   const RuntimeContext& ctx)
      : op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {}
  virtual ~ExecutionContext() {}

  virtual std::string InputName(const std::string& name) const {
    return op_.Input(name);
  }
  virtual std::vector<std::string> InputNames(const std::string& name) const {
    return op_.Inputs(name);
  }
  virtual std::string OutputName(const std::string& name) const {
    return op_.Output(name);
  }

  virtual std::vector<std::string> OutputNames(const std::string& name) const {
    return op_.Outputs(name);
  }

  virtual bool HasAttr(const std::string& name) const {
    return op_.HasAttr(name);
  }
  virtual const AttributeMap& Attrs() const { return op_.Attrs(); }

  const std::string& Type() const { return op_.Type(); }

  const Scope& scope() const { return scope_; }

  template <typename T>
  inline const T& Attr(const std::string& name) const {
    return PADDLE_GET_CONST(T, GetAttr(name));
  }

  virtual const Attribute& GetAttr(const std::string& name) const {
    auto iter = op_.Attrs().find(name);
    if (iter == op_.Attrs().end()) {
      iter = op_.RuntimeAttrs().find(name);
      PADDLE_ENFORCE_NE(
          iter,
          op_.RuntimeAttrs().end(),
          platform::errors::NotFound("(%s) is not found in AttributeMap and "
                                     "RuntimeAttributeMap of (%s) operator.",
                                     name,
                                     op_.Type()));
    }
    return iter->second;
  }

  virtual bool HasInput(const std::string& name) const;

  virtual bool HasInputs(const std::string& name) const;

  virtual bool HasOutput(const std::string& name) const;

  virtual size_t InputSize(const std::string& name) const {
    return op_.Inputs(name).size();
  }

  virtual size_t OutputSize(const std::string& name) const {
    return op_.Outputs(name).size();
  }

  virtual const Variable* InputVar(const std::string& name) const;

  virtual Variable* OutputVar(const std::string& name) const;

  virtual const std::vector<Variable*> MultiInputVar(
      const std::string& name) const {
    LogVarUsageIfUnusedVarCheckEnabled(name);

    auto it = ctx_.inputs.find(name);
    if (it == ctx_.inputs.end()) {
      return {};
    }
    return {it->second.begin(), it->second.end()};
  }

  virtual std::vector<Variable*> MultiOutputVar(const std::string& name) const {
    auto it = ctx_.outputs.find(name);
    if (it == ctx_.outputs.end()) {
      return {};
    }
    return it->second;
  }

  virtual paddle::small_vector<const std::string*> InNameList() const {
    paddle::small_vector<const std::string*> vec_temp;
    vec_temp.reserve(ctx_.inputs.size());

    for (auto& input : ctx_.inputs) {
      vec_temp.push_back(&input.first);
    }

    return vec_temp;
  }

  template <typename T>
  const T* Input(const std::string& name) const {
    auto* var = InputVar(name);
    return var == nullptr ? nullptr : &var->Get<T>();
  }

  template <typename T>
  T* Output(const std::string& name) const {
    auto var = OutputVar(name);
    return var == nullptr ? nullptr : var->GetMutable<T>();
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    LogVarUsageIfUnusedVarCheckEnabled(name);

    auto vars = MultiInputVar(name);
    if (vars.size() == 0) {
      return {};
    }
    std::vector<const T*> res;
    res.reserve(vars.size());
    std::transform(vars.begin(),
                   vars.end(),
                   std::back_inserter(res),
                   [&](const Variable* var) -> const T* {
                     return var == nullptr ? nullptr : &var->Get<T>();
                   });
    return res;
  }

  template <typename T>
  std::vector<T*> MultiOutput(const std::string& name) const {
    auto vars = MultiOutputVar(name);

    if (vars.size() == 0) {
      return {};
    }

    std::vector<T*> res;
    res.reserve(vars.size());
    std::transform(vars.begin(),
                   vars.end(),
                   std::back_inserter(res),
                   [&](Variable* var) -> T* {
                     return var == nullptr ? nullptr : var->GetMutable<T>();
                   });

    return res;
  }

  platform::Place GetPlace() const { return device_context_.GetPlace(); }

  template <typename DeviceContextType>
  const DeviceContextType& device_context() const {
    return *reinterpret_cast<const DeviceContextType*>(&device_context_);
  }

  const platform::DeviceContext& device_context() const {
    return device_context_;
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  const inline phi::GPUContext& cuda_device_context() const {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(device_context_.GetPlace()),
                      true,
                      platform::errors::PreconditionNotMet(
                          "Current device context place is not GPUPlace."));
    return *reinterpret_cast<const phi::GPUContext*>(&device_context_);
  }
#endif

  template <typename T, typename DevContext>
  phi::DenseTensor AllocateTmpTensor(const framework::DDim& dim,
                                     const DevContext& dev_ctx) const {
    phi::DenseTensor tmp;
    tmp.Resize(dim);
    dev_ctx.template Alloc<T>(&tmp);
    return tmp;
  }

  const RuntimeContext Context() const { return ctx_; }

  std::string DebugString() const { return op_.DebugString(); }
  const OperatorBase& GetOp() const { return op_; }

 private:
  const OperatorBase& op_;
  const Scope& scope_;
  const platform::DeviceContext& device_context_;
  const RuntimeContext& ctx_;
};

// TODO(chenweihang): split impl based on OpProto or Dygraph if needed
class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
 public:
  explicit ExecutionArgumentMappingContext(const ExecutionContext& ctx)
      : ctx_(ctx) {}

  bool HasInput(const std::string& name) const override {
    return ctx_.HasInputs(name);
  }

  bool HasOutput(const std::string& name) const override {
    return ctx_.HasOutput(name);
  }

  bool HasAttr(const std::string& name) const override {
    return ctx_.HasAttr(name);
  }

  paddle::any Attr(const std::string& name) const override {
    auto& attr = ctx_.GetAttr(name);
    return GetAttrValue(attr);
  }

  size_t InputSize(const std::string& name) const override {
    return ctx_.MultiInputVar(name).size();
  }

  size_t OutputSize(const std::string& name) const override {
    return ctx_.MultiOutputVar(name).size();
  }

  bool IsDenseTensorInput(const std::string& name) const override {
    const auto* var = ctx_.InputVar(name);
    return var->IsType<phi::DenseTensor>();
  }

  bool IsDenseTensorInputs(const std::string& name) const override {
    auto vars = ctx_.MultiInputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::DenseTensor>();
    });
  }

  bool IsSelectedRowsInputs(const std::string& name) const override {
    auto vars = ctx_.MultiInputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::SelectedRows>();
    });
  }

  bool IsSelectedRowsInput(const std::string& name) const override {
    const auto* var = ctx_.InputVar(name);
    return var->IsType<phi::SelectedRows>();
  }

  bool IsDenseTensorVectorInput(const std::string& name) const override {
    auto vars = ctx_.MultiInputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<framework::LoDTensorArray>();
    });
  }

  bool IsSparseCooTensorInput(const std::string& name) const override {
    const auto* var = ctx_.InputVar(name);
    return var->IsType<phi::SparseCooTensor>();
  }

  bool IsSparseCooTensorOutput(const std::string& name) const override {
    auto vars = ctx_.MultiOutputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::SparseCooTensor>();
    });
  }

  bool IsSparseCsrTensorInput(const std::string& name) const override {
    const auto* var = ctx_.InputVar(name);
    return var->IsType<phi::SparseCsrTensor>();
  }

  bool IsDenseTensorOutput(const std::string& name) const override {
    auto vars = ctx_.MultiOutputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::DenseTensor>();
    });
  }

  bool IsSelectedRowsOutput(const std::string& name) const override {
    auto vars = ctx_.MultiOutputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::SelectedRows>();
    });
  }

  bool IsForInferShape() const override { return false; }

 private:
  const ExecutionContext& ctx_;
};

template <>
const std::vector<const phi::DenseTensor*>
ExecutionContext::MultiInput<phi::DenseTensor>(const std::string& name) const;

template <>
std::vector<phi::DenseTensor*> ExecutionContext::MultiOutput<phi::DenseTensor>(
    const std::string& name) const;

class OpKernelBase {
 public:
  /**
   * ExecutionContext is the only parameter of the Kernel Run function.
   * Run will get input/output variables, state such as momentum and
   * device resources such as CUDA streams, cublas handles, etc. from
   * ExecutionContext. User should construct it before running the Operator.
   */

  virtual void Compute(const ExecutionContext& context) const = 0;

  virtual ~OpKernelBase() = default;
};

template <typename T>
class OpKernel : public OpKernelBase {
 public:
  using ELEMENT_TYPE = T;
};
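
// A minimal sketch (illustrative only, not declared in this header) of how a
// concrete kernel is usually written against the interfaces above: derive from
// OpKernel<T> and implement OpKernelBase::Compute(), reading inputs, outputs
// and attributes through the ExecutionContext. The kernel class name and the
// argument names ("X", "Out", "scale") are hypothetical.
//
//   template <typename T>
//   class MyScaleKernel : public OpKernel<T> {
//    public:
//     void Compute(const ExecutionContext& ctx) const override {
//       const auto* x = ctx.Input<phi::DenseTensor>("X");
//       auto* out = ctx.Output<phi::DenseTensor>("Out");
//       auto scale = ctx.Attr<float>("scale");
//       // ... allocate out on ctx.GetPlace() and compute out = scale * x ...
//     }
//   };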

class OperatorWithKernel : public OperatorBase {
 public:
  using OpKernelFunc = std::function<void(const ExecutionContext&)>;
  using OpKernelMap =
      std::unordered_map<OpKernelType, OpKernelFunc, OpKernelType::Hash>;

  OperatorWithKernel(const std::string& type,
                     const VariableNameMap& inputs,
                     const VariableNameMap& outputs,
                     const AttributeMap& attrs);

  virtual ~OperatorWithKernel();

  static paddle::flat_hash_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static paddle::flat_hash_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
  }

  bool SupportGPU() const override;

  bool SupportNPU() const override;

  bool SupportMLU() const override {
    // TODO(zhiqiu): support phi if needed?
    auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
    return std::any_of(op_kernels.begin(),
                       op_kernels.end(),
                       [](OpKernelMap::const_reference kern_pair) {
                         return platform::is_mlu_place(kern_pair.first.place_);
                       });
  }

  bool SupportXPU() const override;

  bool SupportsMKLDNN(phi::DataType data_type) const;

  bool SupportsCUDNN(phi::DataType data_type) const;

  bool SupportsKernelType(const OpKernelType& kernel_type,
                          const ExecutionContext& exe_ctx) const;

  bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
                       phi::DataType data_type) const;

  bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
                       proto::VarType::Type data_type) const;

  bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
                      phi::DataType data_type) const;

  bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
                      proto::VarType::Type data_type) const;

  virtual void InferShape(InferShapeContext* ctx) const;

  void RuntimeInferShape(const Scope& scope,
                         const platform::Place& place,
                         const RuntimeContext& ctx) const override;

  proto::VarType::Type IndicateVarDataType(const ExecutionContext& ctx,
                                           const std::string& name) const;

  proto::VarType::Type IndicateOrPromoteVarDataTypes(
      const ExecutionContext& ctx,
      const std::string& name1,
      const std::string& name2) const;

  virtual phi::KernelKey GetExpectedKernelType(
      const ExecutionContext& ctx) const;

  // change this to public so that in dygraph mode we can call it to check
  // if we need to transform data
  virtual phi::KernelKey GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const phi::KernelKey& expected_kernel_type) const;

  platform::Place GetExecutionPlace(
      const platform::Place& platform) const override {
    return kernel_type_->place_;
  }

  /* member functions for adapting to phi lib */
  /** In the phi::DenseTensor calculation library, the new Kernel adopts a
   * clearer and more streamlined design. The arguments of the Kernel and the
   * input and output arguments registered in the original OpMaker do not match
   * in some cases, so we use a map to record the arguments required by the
   * kernel. When selecting a Kernel during Op execution, the arguments of the
   * original Op are selected according to the arguments returned by
   * GetExpectedPhiKernelArgs.
   */
  phi::KernelSignature GetExpectedPhiKernelArgs(
      const ExecutionContext& ctx) const;

  /* member functions for adapting to phi lib */
  phi::KernelKey ChoosePhiKernel(const ExecutionContext& ctx) const;

  void ChooseKernel(const ExecutionContext& ctx) const;

  void BuildPhiKernelContext(const RuntimeContext& ctx,
                             platform::DeviceContext* dev_ctx,
                             phi::KernelContext* phi_kernel_context) const;

  phi::KernelSignature* PhiKernelSignature() const {
    return kernel_signature_.get();
  }

  phi::Kernel* PhiKernel() const { return phi_kernel_.get(); }

  void ResetPhiKernel(phi::Kernel* kernel) const {
    return phi_kernel_.reset(kernel);
  }

  const OpKernelType* kernel_type() const { return kernel_type_.get(); }
  const OpKernelFunc* kernel_func() const { return kernel_func_.get(); }

  void ResetKernelType(OpKernelType* kernel_type) {
    kernel_type_.reset(kernel_type);
  }

  bool DnnFallback() const { return dnn_fallback_; }

  void SetDnnFallback(bool dnn_fallback) const { dnn_fallback_ = dnn_fallback; }

 private:
  void RunImpl(const Scope& scope, const platform::Place& place) const final;
  void RunImpl(const Scope& scope,
               const platform::Place& place,
               RuntimeContext* runtime_ctx) const;

  /**
   * Transfer data from scope to a transferred scope. If there is no data that
   * needs to be transferred, it returns nullptr.
   *
   * transfered_inplace_vars is an output vector.
   */
  Scope* PrepareData(const Scope& scope,
                     const phi::KernelKey& expected_kernel_key,
                     std::vector<std::string>* transfered_inplace_vars,
                     RuntimeContext* ctx,
                     const phi::Place& place) const;

  void CheckWhetherPreparePhiData(const VariableNameMap& innames,
                                  const VariableNameMap& outnames,
                                  const Scope& scope) const;

  void TransferInplaceVarsBack(const Scope& scope,
                               const std::vector<std::string>& inplace_vars,
                               const Scope& exec_scope) const;

  OpKernelType InnerGetExpectedKernelType(const ExecutionContext& ctx) const;

  void HandleComplexGradToRealGrad(const Scope& scope,
                                   RuntimeContext* ctx) const;

  /* Inner assist methods */
  // indicate kernel DataType by input data.
  // By default all input data must be the same.
  proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
  // used for IndicateDataType
  void ParseInputDataType(const Variable* vars,
                          const std::string& name,
                          proto::VarType::Type* data_type) const;
  void ParseMultiInputDataType(const std::vector<Variable*>& vars,
                               const std::string& name,
                               proto::VarType::Type* data_type) const;
  // used for IndicateOrPromoteVarDataTypes
  phi::DenseTensor* GetTensorFormInputSafely(const ExecutionContext& ctx,
                                             const std::string& name) const;

 protected:
  mutable std::unique_ptr<OpKernelType> kernel_type_;
  mutable std::unique_ptr<OpKernelFunc> kernel_func_;
  mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
  mutable const Scope* pre_scope_ = nullptr;
  mutable bool need_prepare_data_ = true;
  mutable bool need_prepare_phi_data_ = false;
  mutable bool enable_cache_runtime_context_ = false;
  mutable bool all_kernels_must_compute_runtime_shape_ = false;
  mutable std::mutex cache_update_mutex_;
  mutable bool enable_cache_transfer_scope_ = false;
  // NOTE(jiahongyu): Whether to fall back to a plain kernel after calling
  // GetExpectedKernelType; use this bool flag to resolve the mkldnn and cudnn
  // hard code.
  mutable bool dnn_fallback_ = false;
  // NOTE(chenweihang): Similar op members are used to adapt to
  // the new phi kernel; if there is a better design in the future,
  // we may polish the implementation here.
  mutable bool run_phi_kernel_ = false;
  mutable bool run_kp_kernel = false;
  mutable std::unique_ptr<phi::KernelSignature> kernel_signature_;
  mutable std::unique_ptr<phi::Kernel> phi_kernel_;
  mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;

 private:
  struct CacheImpl;
  mutable std::unique_ptr<CacheImpl> impl_;
};

extern bool OpSupportGPU(const std::string& op_type);

}  // namespace framework
}  // namespace paddle