/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <atomic>
#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "glog/logging.h"  // For VLOG
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"

#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/macros.h"
#include "paddle/utils/flat_hash_map.h"

namespace paddle {
namespace framework {
class OpInfo;
class Scope;
class Variable;
}  // namespace framework
}  // namespace paddle

namespace phi {
class KernelContext;
}

PHI_DECLARE_int32(inner_op_parallelism);

namespace paddle {
namespace framework {

/// If a variable is an empty variable, that name will be used.
constexpr char kEmptyVarName[] = "@EMPTY@";

/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be converted to a unique name in the scope after OpCreator.
constexpr char kTempVarName[] = "@TEMP@";

/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another variable.
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
constexpr char kGradVarSuffix[] = "@GRAD";

constexpr size_t kGradVarSuffixSize = 5U;

/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";

/// Variables with this suffix are the new Gradient.
constexpr char kNewGradSuffix[] = "@NEWGRAD@";

/// RuntimeContext is used to relate the input/output names of an Operator
/// with the corresponding variables in the name scope.
/// If an Op has the attribute kEnableCacheRuntimeContext, its input/output
/// names do not change within the same name scope during execution, so the
/// RuntimeContext can be created only at the first iteration of this Op's
/// execution to save elapsed time.
constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@";

/// If an Op has this attribute, all its kernels should calculate the output
/// variables' shapes in the corresponding Compute() function, and
/// OperatorWithKernel::RunImpl() will skip calling this Op's InferShape()
/// function at runtime for speedup.
/// TODO(luotao): Note that this temporary attribute will be deleted after all
/// ops contain it.
constexpr char kAllKernelsMustComputeRuntimeShape[] =
    "@ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@";

// define some kernel priority
/* Define multiple kernel type fallback order */
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;

inline std::string GradVarName(const std::string& var_name) {
  std::string result;
  result.reserve(var_name.size() + kGradVarSuffixSize);
  result += var_name;
  result += kGradVarSuffix;
  return result;
}

inline std::string GradOriginalVarName(const std::string& grad_var_name) {
  std::size_t pos = grad_var_name.rfind(kGradVarSuffix);
  if (pos == std::string::npos) {
    return grad_var_name;
  } else {
    return grad_var_name.substr(0, pos);
  }
}
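
// Example (illustrative): GradVarName("x") returns "x@GRAD", and
// GradOriginalVarName("x@GRAD") returns "x".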

inline bool VarIsTensor(const Variable& var) {
  return var.IsType<phi::DenseTensor>() || var.IsType<phi::SelectedRows>();
}

const phi::DenseTensor* GetLoDTensorOrSelectedRowsValueFromVar(
    const Variable& var);
phi::DenseTensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);

class ExecutionContext;
class OperatorBase;

class RuntimeContext {
 public:
  RuntimeContext(const VariableNameMap& innames,
                 const VariableNameMap& outnames,
                 const Scope& scope);

  RuntimeContext(const VariableValueMap& invars,
                 const VariableValueMap& outvars)
      : inputs(invars), outputs(outvars) {}

  VariableValueMap inputs;
  VariableValueMap outputs;
};

class RuntimeInferShapeContext : public InferShapeContext {
 public:
  RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx);

  bool HasInput(const std::string& name) const override;

  bool HasOutput(const std::string& name) const override;

  bool HasAttr(const std::string& name) const override;

  bool HasInputs(const std::string& name) const override;

  bool HasOutputs(const std::string& name,
                  bool allow_null = false) const override;

  AttrReader Attrs() const override;

  std::vector<std::string> Inputs(const std::string& name) const override;

  std::vector<std::string> Outputs(const std::string& name) const override;

  std::string GetInputNameByIdx(size_t idx) const override;

  std::string GetOutputNameByIdx(size_t idx) const override;

  void ShareDim(const std::string& in,
                const std::string& out,
                size_t i = 0,
                size_t j = 0) override;

  void ShareAllLoD(const std::string& in,
                   const std::string& out) const override;

  void ShareLoD(const std::string& in,
                const std::string& out,
                size_t i = 0,
                size_t j = 0) const override;

  int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override;

  void SetLoDLevel(const std::string& out,
                   int32_t lod_level,
                   size_t j = 0) const override;

  bool IsRuntime() const override;

  bool IsRunMKLDNNKernel() const override;

  // TODO(paddle-dev): Can this be a template?
  paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
  GetInputVarPtrs(const std::string& name) const override;

  paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
  GetOutputVarPtrs(const std::string& name) const override;

  DDim GetInputDim(const std::string& name) const override;

  std::vector<DDim> GetInputsDim(const std::string& name) const override;

  proto::VarType::Type GetInputVarType(const std::string& name) const override;

  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string& name) const override;

  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string& name) const override;

  void SetOutputDim(const std::string& name, const DDim& dim) override;

  void SetOutputsDim(const std::string& name,
                     const std::vector<DDim>& dims) override;

  const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override;

  const phi::KernelSignature* GetPhiDefaultKernelSignature() const override;

  void SetSkipLoD(bool skip);

  std::vector<LoD> GetOutputsLod(const std::string& out) const;

  std::vector<DDim> GetOutputsDim(const std::string& name) const;

 protected:
  DDim GetDim(Variable* var) const;

  std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const;

  std::vector<DDim> GetRepeatedDims(const std::string& name) const override;

  void SetDim(Variable* var, const DDim& dim);

  void SetDims(const std::vector<Variable*>& vars,
               const std::vector<DDim>& dims);

  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override;

  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<Variable*>& vars) const;

  proto::VarType::Type GetVarType(Variable* var) const;

 private:
  const std::vector<Variable*>& InputVars(const std::string& name) const;

  const std::vector<Variable*>& OutputVars(const std::string& name) const;

  const OperatorBase& op_;
  const RuntimeContext& ctx_;
  bool can_skip_lod_{false};
};

/**
 * OperatorBase has the basic elements that Net will call to do computation.
 * Only OpRegistry::CreateOp creates an Operator directly; users should
 * always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
  OperatorBase(const std::string& type,
               const VariableNameMap& inputs,
               const VariableNameMap& outputs,
               const AttributeMap& attrs);

  virtual ~OperatorBase() {}

  /// Executor will call this interface function to Run an op.
  //  The implementation should be written in RunImpl.
  void Run(const Scope& scope, const platform::Place& place);

  // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
  virtual void Stop() {}

  /// if scope is not null, also show dimensions of arguments
  virtual std::string DebugStringEx(const Scope* scope) const;
  std::string DebugString() const { return DebugStringEx(nullptr); }

  virtual bool SupportGPU() const { return false; }
  virtual bool SupportXPU() const { return false; }
  virtual bool SupportCustomDevice() const { return false; }

  const std::string& Type() const { return type_; }

  bool HasAttr(const std::string& name) const {
    return attrs_.count(name) || runtime_attrs_.count(name);
  }
  template <typename T>
  inline const T& Attr(const std::string& name) const {
    auto it = attrs_.find(name);
    if (it == attrs_.end()) {
      it = runtime_attrs_.find(name);
      PADDLE_ENFORCE_NE(
          it,
          runtime_attrs_.end(),
          platform::errors::NotFound(
              "(%s) is not found in AttributeMap and RuntimeAttributeMap.",
              name));
    }
    return PADDLE_GET_CONST(T, it->second);
  }
  void SetAttr(const std::string& name, const Attribute& v) {
    PADDLE_ENFORCE_EQ(
        HasAttr(name),
        true,
        platform::errors::NotFound(
            "The attribute %s is not found in operator %s", name, Type()));

    attrs_[name] = v;
  }
  const AttributeMap& Attrs() const { return attrs_; }
  const AttributeMap& RuntimeAttrs() const { return runtime_attrs_; }
  void SetRuntimeAttributeMap(const AttributeMap& runtime_attrs) {
    runtime_attrs_ = runtime_attrs;
  }

  const VariableNameMap& Inputs() const { return inputs_; }
  const VariableNameMap& Outputs() const { return outputs_; }
  VariableNameMap& Inputs() { return inputs_; }
  VariableNameMap& Outputs() { return outputs_; }

  const OpInfo& Info() const {
    PADDLE_ENFORCE_NOT_NULL(
        info_,
        platform::errors::NotFound("OpInfo of operator (%s) is not found.",
                                   type_));
    return *info_;
  }

  bool HasInputs(const std::string& name) const;
  //! Get an input with the argument's name described in `op_proto`
  std::string Input(const std::string& name) const;
  //! Get an input which has multiple variables.
  const std::vector<std::string>& Inputs(const std::string& name) const;
  //! Get all input variable names
  std::vector<std::string> InputVars() const;

  bool HasOutputs(const std::string& name) const;
  //! Get an output with the argument's name described in `op_proto`
  std::string Output(const std::string& name) const;
  //! Get an output which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
  const std::vector<std::string>& Outputs(const std::string& name) const;
  //! Get all output variable names
  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;

  void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }

  virtual void SetIsRuntimeInferShape(bool x UNUSED) {}

  virtual void RuntimeInferShape(const Scope& scope UNUSED,
                                 const platform::Place& place UNUSED,
                                 const RuntimeContext& ctx UNUSED) const {}

  virtual platform::Place GetExecutionPlace(
      const platform::Place& place) const {
    return place;
  }

  uint64_t Id() const { return id_; }

  void SetId(uint64_t id) { id_ = id; }

  using HookFunc = std::function<void(OperatorBase*, Scope*)>;
  void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) {
    hookfuncs_ = hookfuncs;
  }

 protected:
  std::string type_;
  // NOTE: in case of OpGrad, inputs_ contains:
  // I (Inputs)
  // O (Outputs)
  // OG (Output Gradients)
  VariableNameMap inputs_;

  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
  VariableNameMap outputs_;
  AttributeMap attrs_;
  // NOTE: runtime_attrs_ contains the attributes that are used for dispatching
  // kernels (use_mkldnn, use_cudnn, ...) or passing additional configuration
  // to special heterogeneous kernels (workspace_size_MB, ...).
  // The attributes in runtime_attrs_ are set by the framework (such as
  // passes), not by the Python API.
  AttributeMap runtime_attrs_;

  // OpInfo
  const OpInfo* info_;

  // OpDesc Id
  uint64_t id_ = UINT64_MAX;

  // Whether this operator executes in an Executor.
  bool run_by_executor_{true};

  std::vector<HookFunc> hookfuncs_;

 private:
  void GenerateTemporaryNames();
  void CheckAllInputOutputSet() const;
  virtual void RunImpl(const Scope& scope,
                       const platform::Place& place) const = 0;
};
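
// Minimal usage sketch (illustrative; assumes an already-built OpDesc
// `op_desc` and a Scope `scope`):
//   auto op = OpRegistry::CreateOp(op_desc);
//   op->Run(scope, platform::CPUPlace());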

class ExecutionContext : public phi::KernelContext {
 public:
  ExecutionContext(const OperatorBase& op,
                   const Scope& scope,
                   const platform::DeviceContext& device_context,
                   const RuntimeContext& ctx)
      : op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {}
  virtual ~ExecutionContext() {}

  virtual std::string InputName(const std::string& name) const {
    return op_.Input(name);
  }
  virtual std::vector<std::string> InputNames(const std::string& name) const {
    return op_.Inputs(name);
  }
  virtual std::string OutputName(const std::string& name) const {
    return op_.Output(name);
  }

  virtual std::vector<std::string> OutputNames(const std::string& name) const {
    return op_.Outputs(name);
  }

  virtual bool HasAttr(const std::string& name) const {
    return op_.HasAttr(name);
  }
  virtual const AttributeMap& Attrs() const { return op_.Attrs(); }

  const std::string& Type() const { return op_.Type(); }

  const Scope& scope() const { return scope_; }

  template <typename T>
  inline const T& Attr(const std::string& name) const {
    return PADDLE_GET_CONST(T, GetAttr(name));
  }

  virtual const Attribute& GetAttr(const std::string& name) const {
    auto iter = op_.Attrs().find(name);
    if (iter == op_.Attrs().end()) {
      iter = op_.RuntimeAttrs().find(name);
      PADDLE_ENFORCE_NE(
          iter,
          op_.RuntimeAttrs().end(),
          platform::errors::NotFound("(%s) is not found in AttributeMap and "
                                     "RuntimeAttributeMap of (%s) operator.",
                                     name,
                                     op_.Type()));
    }
    return iter->second;
  }

  virtual bool HasInput(const std::string& name) const;

  virtual bool HasInputs(const std::string& name) const;

  virtual bool HasOutput(const std::string& name) const;

  virtual size_t InputSize(const std::string& name) const {
    return op_.Inputs(name).size();
  }

  virtual size_t OutputSize(const std::string& name) const {
    return op_.Outputs(name).size();
  }

  virtual const Variable* InputVar(const std::string& name) const;

  virtual Variable* OutputVar(const std::string& name) const;

  virtual const std::vector<Variable*> MultiInputVar(
      const std::string& name) const {
    LogVarUsageIfUnusedVarCheckEnabled(name);

    auto it = ctx_.inputs.find(name);
    if (it == ctx_.inputs.end()) {
      return {};
    }
    return {it->second.begin(), it->second.end()};
  }

  virtual std::vector<Variable*> MultiOutputVar(const std::string& name) const {
    auto it = ctx_.outputs.find(name);
    if (it == ctx_.outputs.end()) {
      return {};
    }
    return it->second;
  }

  virtual paddle::small_vector<const std::string*> InNameList() const {
    paddle::small_vector<const std::string*> vec_temp;
    vec_temp.reserve(ctx_.inputs.size());

    for (auto& input : ctx_.inputs) {
      vec_temp.push_back(&input.first);
    }

    return vec_temp;
  }

  template <typename T>
  const T* Input(const std::string& name) const {
    auto* var = InputVar(name);
    return var == nullptr ? nullptr : &var->Get<T>();
  }

  template <typename T>
  T* Output(const std::string& name) const {
    auto var = OutputVar(name);
    return var == nullptr ? nullptr : var->GetMutable<T>();
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    LogVarUsageIfUnusedVarCheckEnabled(name);

    auto vars = MultiInputVar(name);
    if (vars.size() == 0) {
      return {};
    }
    std::vector<const T*> res;
    res.reserve(vars.size());
    std::transform(vars.begin(),
                   vars.end(),
                   std::back_inserter(res),
                   [&](const Variable* var) -> const T* {
                     return var == nullptr ? nullptr : &var->Get<T>();
                   });
    return res;
  }

  template <typename T>
  std::vector<T*> MultiOutput(const std::string& name) const {
    auto vars = MultiOutputVar(name);

    if (vars.size() == 0) {
      return {};
    }

    std::vector<T*> res;
    res.reserve(vars.size());
    std::transform(vars.begin(),
                   vars.end(),
                   std::back_inserter(res),
                   [&](Variable* var) -> T* {
                     return var == nullptr ? nullptr : var->GetMutable<T>();
                   });

    return res;
  }

  platform::Place GetPlace() const { return device_context_.GetPlace(); }

  template <typename DeviceContextType>
  const DeviceContextType& device_context() const {
    return *reinterpret_cast<const DeviceContextType*>(&device_context_);
  }

  const platform::DeviceContext& device_context() const {
    return device_context_;
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  const inline phi::GPUContext& cuda_device_context() const {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(device_context_.GetPlace()),
                      true,
                      platform::errors::PreconditionNotMet(
                          "Current device context place is not GPUPlace."));
    return *reinterpret_cast<const phi::GPUContext*>(&device_context_);
  }
#endif

  template <typename T, typename DevContext>
  phi::DenseTensor AllocateTmpTensor(const framework::DDim& dim,
                                     const DevContext& dev_ctx) const {
    phi::DenseTensor tmp;
    tmp.Resize(dim);
    dev_ctx.template Alloc<T>(&tmp);
    return tmp;
  }

  const RuntimeContext Context() const { return ctx_; }

  std::string DebugString() const { return op_.DebugString(); }
  const OperatorBase& GetOp() const { return op_; }

 private:
  const OperatorBase& op_;
  const Scope& scope_;
  const platform::DeviceContext& device_context_;
  const RuntimeContext& ctx_;
};
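
// Illustrative use inside a kernel's Compute() body (the variable and
// attribute names "X", "Out", and "scale" are hypothetical):
//   auto* x = ctx.Input<phi::DenseTensor>("X");
//   auto* out = ctx.Output<phi::DenseTensor>("Out");
//   auto scale = ctx.Attr<float>("scale");
//   auto& dev_ctx = ctx.device_context<phi::CPUContext>();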

// TODO(chenweihang): split impl based on OpProto or Dygraph if needed
class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
 public:
  explicit ExecutionArgumentMappingContext(const ExecutionContext& ctx)
      : ctx_(ctx) {}

  bool HasInput(const std::string& name) const override {
    return ctx_.HasInputs(name);
  }

  bool HasOutput(const std::string& name) const override {
    return ctx_.HasOutput(name);
  }

  bool HasAttr(const std::string& name) const override {
    return ctx_.HasAttr(name);
  }

  paddle::any Attr(const std::string& name) const override {
    auto& attr = ctx_.GetAttr(name);
    return GetAttrValue(attr);
  }

  size_t InputSize(const std::string& name) const override {
    return ctx_.MultiInputVar(name).size();
  }

  size_t OutputSize(const std::string& name) const override {
    return ctx_.MultiOutputVar(name).size();
  }

  bool IsDenseTensorInput(const std::string& name) const override {
    const auto* var = ctx_.InputVar(name);
    return var->IsType<phi::DenseTensor>();
  }

  bool IsDenseTensorInputs(const std::string& name) const override {
    auto vars = ctx_.MultiInputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::DenseTensor>();
    });
  }

  bool IsSelectedRowsInputs(const std::string& name) const override {
    auto vars = ctx_.MultiInputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::SelectedRows>();
    });
  }

  bool IsSelectedRowsInput(const std::string& name) const override {
    const auto* var = ctx_.InputVar(name);
    return var->IsType<phi::SelectedRows>();
  }

  bool IsDenseTensorVectorInput(const std::string& name) const override {
    auto vars = ctx_.MultiInputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<framework::LoDTensorArray>();
    });
  }

  bool IsSparseCooTensorInput(const std::string& name) const override {
    const auto* var = ctx_.InputVar(name);
    return var->IsType<phi::SparseCooTensor>();
  }

  bool IsSparseCooTensorOutput(const std::string& name) const override {
    auto vars = ctx_.MultiOutputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::SparseCooTensor>();
    });
  }

  bool IsSparseCsrTensorInput(const std::string& name) const override {
    const auto* var = ctx_.InputVar(name);
    return var->IsType<phi::SparseCsrTensor>();
  }

  bool IsDenseTensorOutput(const std::string& name) const override {
    auto vars = ctx_.MultiOutputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::DenseTensor>();
    });
  }

  bool IsSelectedRowsOutput(const std::string& name) const override {
    auto vars = ctx_.MultiOutputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::SelectedRows>();
    });
  }

  bool IsForInferShape() const override { return false; }

 private:
  const ExecutionContext& ctx_;
};

template <>
const std::vector<const phi::DenseTensor*>
ExecutionContext::MultiInput<phi::DenseTensor>(const std::string& name) const;

template <>
std::vector<phi::DenseTensor*> ExecutionContext::MultiOutput<phi::DenseTensor>(
    const std::string& name) const;

class OpKernelBase {
 public:
  /**
   * ExecutionContext is the only parameter of the kernel's Compute() function.
   * Compute() will get input/output variables, state such as momentum, and
   * device resources such as the CUDA stream and cublas handle from the
   * ExecutionContext. Users should construct it before running the Operator.
   */

  virtual void Compute(const ExecutionContext& context) const = 0;

  virtual ~OpKernelBase() = default;
};

template <typename T>
class OpKernel : public OpKernelBase {
 public:
  using ELEMENT_TYPE = T;
};
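
// A minimal sketch of a concrete kernel (a hypothetical "my_relu" op; real
// kernels live in the operators library and are registered through the
// framework's kernel registration macros):
//   template <typename T>
//   class MyReluKernel : public framework::OpKernel<T> {
//    public:
//     void Compute(const framework::ExecutionContext& ctx) const override {
//       auto* x = ctx.Input<phi::DenseTensor>("X");
//       auto* out = ctx.Output<phi::DenseTensor>("Out");
//       // ... elementwise computation from x into out ...
//     }
//   };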

class OperatorWithKernel : public OperatorBase {
 public:
  using OpKernelFunc = std::function<void(const ExecutionContext&)>;
  using OpKernelMap =
      std::unordered_map<OpKernelType, OpKernelFunc, OpKernelType::Hash>;

  OperatorWithKernel(const std::string& type,
                     const VariableNameMap& inputs,
                     const VariableNameMap& outputs,
                     const AttributeMap& attrs);

  virtual ~OperatorWithKernel();

  static paddle::flat_hash_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static paddle::flat_hash_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
  }

  bool SupportGPU() const override;

  bool SupportXPU() const override;

  bool SupportCustomDevice() const override;

  bool SupportsMKLDNN(phi::DataType data_type) const;

  bool SupportsCUDNN(phi::DataType data_type) const;

  bool SupportsKernelType(const OpKernelType& kernel_type,
                          const ExecutionContext& exe_ctx) const;

  bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
                       phi::DataType data_type) const;

  bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
                       proto::VarType::Type data_type) const;

  bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
                      phi::DataType data_type) const;

  bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
                      proto::VarType::Type data_type) const;

  virtual void InferShape(InferShapeContext* ctx) const;

  void SetIsRuntimeInferShape(bool x) override {
    all_kernels_must_compute_runtime_shape_ = x;
  }

  void RuntimeInferShape(const Scope& scope,
                         const platform::Place& place,
                         const RuntimeContext& ctx) const override;

  proto::VarType::Type IndicateVarDataType(const ExecutionContext& ctx,
                                           const std::string& name) const;

  proto::VarType::Type IndicateOrPromoteVarDataTypes(
      const ExecutionContext& ctx,
      const std::string& name1,
      const std::string& name2) const;

  virtual phi::KernelKey GetExpectedKernelType(
      const ExecutionContext& ctx) const;

  // change this to public so that in dygraph mode we can call it to check
  // whether we need to transform data
  virtual phi::KernelKey GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const phi::KernelKey& expected_kernel_type) const;

  platform::Place GetExecutionPlace(
      const platform::Place& platform UNUSED) const override {
    return kernel_type_->place_;
  }

  /* member functions for adapting to phi lib */
  /** In the phi::DenseTensor calculation library, the new Kernel adopts a
   * clearer and more streamlined design. The arguments of the Kernel and the
   * input and output arguments registered in the original OpMaker do not match
   * in some cases, so we use a map to record the arguments required by the
   * kernel. When selecting a Kernel during Op execution, the arguments of the
   * original Op are selected according to the arguments returned by
   * GetExpectedPhiKernelArgs.
   */
  phi::KernelSignature GetExpectedPhiKernelArgs(
      const ExecutionContext& ctx) const;

  /* member functions for adapting to phi lib */
  phi::KernelKey ChoosePhiKernel(const ExecutionContext& ctx) const;

  void ChooseKernel(const ExecutionContext& ctx) const;

  void BuildPhiKernelContext(const RuntimeContext& ctx,
                             platform::DeviceContext* dev_ctx,
                             phi::KernelContext* phi_kernel_context) const;

  phi::KernelSignature* PhiKernelSignature() const {
    return kernel_signature_.get();
  }

  phi::Kernel* PhiKernel() const { return phi_kernel_.get(); }

  void ResetPhiKernel(phi::Kernel* kernel) const {
    return phi_kernel_.reset(kernel);
  }

  const OpKernelType* kernel_type() const { return kernel_type_.get(); }
  const OpKernelFunc* kernel_func() const { return kernel_func_.get(); }

  void ResetKernelType(OpKernelType* kernel_type) {
    kernel_type_.reset(kernel_type);
  }

  bool DnnFallback() const { return dnn_fallback_; }

  void SetDnnFallback(bool dnn_fallback) const { dnn_fallback_ = dnn_fallback; }

 private:
  void RunImpl(const Scope& scope, const platform::Place& place) const final;
  void RunImpl(const Scope& scope,
               const platform::Place& place,
               RuntimeContext* runtime_ctx) const;

  /**
   * Transfer data from the scope to a transferred scope. If no data needs to
   * be transferred, it returns nullptr.
   *
   * transfered_inplace_vars is an output vector.
   */
  Scope* PrepareData(const Scope& scope,
                     const phi::KernelKey& expected_kernel_key,
                     std::vector<std::string>* transfered_inplace_vars,
                     RuntimeContext* ctx,
                     const phi::Place& place) const;

  void CheckWhetherPreparePhiData(const VariableNameMap& innames,
                                  const VariableNameMap& outnames,
                                  const Scope& scope) const;

  void TransferInplaceVarsBack(const Scope& scope,
                               const std::vector<std::string>& inplace_vars,
                               const Scope& exec_scope) const;

  OpKernelType InnerGetExpectedKernelType(const ExecutionContext& ctx) const;

  void HandleComplexGradToRealGrad(const Scope& scope,
                                   RuntimeContext* ctx) const;

  /* Inner assist methods */
  // Indicate kernel DataType by input data.
  // By default all input data must be the same.
  proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
  // used for IndicateDataType
  void ParseInputDataType(const Variable* vars,
                          const std::string& name,
                          proto::VarType::Type* data_type) const;
  void ParseMultiInputDataType(const std::vector<Variable*>& vars,
                               const std::string& name,
                               proto::VarType::Type* data_type) const;
  // used for IndicateOrPromoteVarDataTypes
  phi::DenseTensor* GetTensorFormInputSafely(const ExecutionContext& ctx,
                                             const std::string& name) const;

 protected:
  mutable std::unique_ptr<OpKernelType> kernel_type_;
  mutable std::unique_ptr<OpKernelFunc> kernel_func_;
  mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
  mutable const Scope* pre_scope_ = nullptr;
  mutable bool need_prepare_data_ = true;
  mutable bool need_prepare_phi_data_ = false;
  mutable bool enable_cache_runtime_context_ = false;
  mutable bool all_kernels_must_compute_runtime_shape_ = false;
  mutable std::mutex cache_update_mutex_;
  mutable bool enable_cache_transfer_scope_ = false;
  // NOTE(jiahongyu): Whether to fall back to a plain kernel after calling
  // GetExpectedKernelType; this bool flag is used to resolve the mkldnn and
  // cudnn hard code.
  mutable bool dnn_fallback_ = false;
  // NOTE(chenweihang): Similar op members are used to adapt to the new phi
  // kernel; if there is a better design in the future, we may polish the
  // implementation here.
  mutable bool run_phi_kernel_ = false;
  mutable bool run_kp_kernel = false;
  mutable std::unique_ptr<phi::KernelSignature> kernel_signature_;
  mutable std::unique_ptr<phi::Kernel> phi_kernel_;
  mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;

 private:
  struct CacheImpl;
  mutable std::unique_ptr<CacheImpl> impl_;
};
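
// Illustrative lookup in the static kernel registry above (a sketch; real
// kernel registration normally goes through the framework's registration
// macros rather than touching this map directly):
//   auto& all_kernels = OperatorWithKernel::AllOpKernels();
//   auto it = all_kernels.find(op_type);  // op_type: a std::string, e.g. "mul"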

extern bool OpSupportGPU(const std::string& op_type);

}  // namespace framework
}  // namespace paddle