/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <atomic>
#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "glog/logging.h"  // For VLOG
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"

#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/utils/flat_hash_map.h"

namespace paddle {
namespace framework {
class OpInfo;
class Scope;
class Variable;
}  // namespace framework
}  // namespace paddle

namespace phi {
class KernelContext;
}

DECLARE_int32(inner_op_parallelism);

namespace paddle {
namespace framework {

/// If a variable is an empty variable, that name will be used.
constexpr char kEmptyVarName[] = "@EMPTY@";

/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be converted to a unique name in the scope after OpCreator.
constexpr char kTempVarName[] = "@TEMP@";

/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another variable.
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
constexpr char kGradVarSuffix[] = "@GRAD";

constexpr size_t kGradVarSuffixSize = 5U;

/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";

/// Variables with this suffix are the new Gradient.
constexpr char kNewGradSuffix[] = "@NEWGRAD@";

/// RuntimeContext is used to relate the input/output names of an Operator
/// with the corresponding variables in the name scope.
/// If an Op has the attribute kEnableCacheRuntimeContext, the input/output
/// names of this Op do not change within the same name scope during
/// execution, so the RuntimeContext only needs to be created at the first
/// iteration of this Op's execution to save the elapsed time.
constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@";

/// If an Op has this attribute, all its kernels should calculate the output
/// variable's shape in the corresponding Compute() function, and
/// OperatorWithKernel::RunImpl() will skip calling this Op's InferShape()
/// function at runtime for speedup.
/// TODO(luotao): Note that this temporary attribute will be deleted after all
/// ops contain it.
constexpr char kAllKernelsMustComputeRuntimeShape[] =
    "@ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@";

// Define some kernel priority.
/* Define the multiple kernel type fallback order. */
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
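// For instance (a sketch only; the actual contents are defined in the
// implementation file, not in this header), the fallback order could be
// populated roughly as:
//   { {platform::CUDAPlace(0), LibraryType::kCUDNN},
//     {platform::CUDAPlace(0), LibraryType::kPlain},
//     {platform::CPUPlace(),   LibraryType::kMKLDNN},
//     {platform::CPUPlace(),   LibraryType::kPlain} }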

inline std::string GradVarName(const std::string& var_name) {
  std::string result;
  result.reserve(var_name.size() + kGradVarSuffixSize);
  result += var_name;
  result += kGradVarSuffix;
  return result;
}

inline std::string GradOriginalVarName(const std::string& grad_var_name) {
  std::size_t pos = grad_var_name.rfind(kGradVarSuffix);
  if (pos == std::string::npos) {
    return grad_var_name;
  } else {
    return grad_var_name.substr(0, pos);
  }
}
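// A quick illustration of the two helpers above (a sketch; "x" is a
// hypothetical variable name, not one defined in this header):
//   GradVarName("x")              -> "x@GRAD"
//   GradOriginalVarName("x@GRAD") -> "x"
//   GradOriginalVarName("x")      -> "x"  (no suffix, so returned unchanged)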

inline bool VarIsTensor(const Variable& var) {
  return var.IsType<phi::DenseTensor>() || var.IsType<phi::SelectedRows>();
}

const phi::DenseTensor* GetLoDTensorOrSelectedRowsValueFromVar(
    const Variable& var);
phi::DenseTensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);

class ExecutionContext;
class OperatorBase;

class RuntimeContext {
 public:
  RuntimeContext(const VariableNameMap& innames,
                 const VariableNameMap& outnames,
                 const Scope& scope);

  RuntimeContext(const VariableValueMap& invars,
                 const VariableValueMap& outvars)
      : inputs(invars), outputs(outvars) {}

  VariableValueMap inputs;
  VariableValueMap outputs;
};

class RuntimeInferShapeContext : public InferShapeContext {
 public:
  RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx);

  bool HasInput(const std::string& name) const override;

  bool HasOutput(const std::string& name) const override;

  bool HasAttr(const std::string& name) const override;

  bool HasInputs(const std::string& name) const override;

  bool HasOutputs(const std::string& name,
                  bool allow_null = false) const override;

  AttrReader Attrs() const override;

  std::vector<std::string> Inputs(const std::string& name) const override;

  std::vector<std::string> Outputs(const std::string& name) const override;

  std::string GetInputNameByIdx(size_t idx) const override;

  std::string GetOutputNameByIdx(size_t idx) const override;

  void ShareDim(const std::string& in,
                const std::string& out,
                size_t i = 0,
                size_t j = 0) override;

  void ShareAllLoD(const std::string& in,
                   const std::string& out) const override;

  void ShareLoD(const std::string& in,
                const std::string& out,
                size_t i = 0,
                size_t j = 0) const override;

  int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override;

  void SetLoDLevel(const std::string& out,
                   int32_t lod_level,
                   size_t j = 0) const override;

  bool IsRuntime() const override;

  bool IsRunMKLDNNKernel() const override;

  // TODO(paddle-dev): Can this be template?
  paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
  GetInputVarPtrs(const std::string& name) const override;

  paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
  GetOutputVarPtrs(const std::string& name) const override;

  DDim GetInputDim(const std::string& name) const override;

  std::vector<DDim> GetInputsDim(const std::string& name) const override;

  proto::VarType::Type GetInputVarType(const std::string& name) const override;

  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string& name) const override;

  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string& name) const override;

  void SetOutputDim(const std::string& name, const DDim& dim) override;

  void SetOutputsDim(const std::string& name,
                     const std::vector<DDim>& dims) override;

  const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override;

  const phi::KernelSignature* GetPhiDefaultKernelSignature() const override;

  void SetSkipLoD(bool skip);

  std::vector<LoD> GetOutputsLod(const std::string& out) const;

  std::vector<DDim> GetOutputsDim(const std::string& name) const;

 protected:
  DDim GetDim(Variable* var) const;

  std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const;

  std::vector<DDim> GetRepeatedDims(const std::string& name) const override;

  void SetDim(Variable* var, const DDim& dim);

  void SetDims(const std::vector<Variable*>& vars,
               const std::vector<DDim>& dims);

  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override;

  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<Variable*>& vars) const;

  proto::VarType::Type GetVarType(Variable* var) const;

 private:
  const std::vector<Variable*>& InputVars(const std::string& name) const;

  const std::vector<Variable*>& OutputVars(const std::string& name) const;

  const OperatorBase& op_;
  const RuntimeContext& ctx_;
  bool can_skip_lod_{false};
};

/**
 * OperatorBase has the basic elements that Net will call to do computation.
 * Only CreateOperator from OpRegistry should construct an Operator directly.
 * Users should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
  OperatorBase(const std::string& type,
               const VariableNameMap& inputs,
               const VariableNameMap& outputs,
               const AttributeMap& attrs);

  virtual ~OperatorBase() {}

  /// The Executor will call this interface function to run an op.
  /// The implementation should be written in RunImpl.
  void Run(const Scope& scope, const platform::Place& place);

  // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
  virtual void Stop() {}

  /// If scope is not null, also show the dimensions of arguments.
  virtual std::string DebugStringEx(const Scope* scope) const;
  std::string DebugString() const { return DebugStringEx(nullptr); }

  virtual bool SupportGPU() const { return false; }
  virtual bool SupportXPU() const { return false; }

  const std::string& Type() const { return type_; }

  bool HasAttr(const std::string& name) const {
    return attrs_.count(name) || runtime_attrs_.count(name);
  }
  template <typename T>
  inline const T& Attr(const std::string& name) const {
    auto it = attrs_.find(name);
    if (it == attrs_.end()) {
      it = runtime_attrs_.find(name);
      PADDLE_ENFORCE_NE(
          it,
          runtime_attrs_.end(),
          platform::errors::NotFound(
              "(%s) is not found in AttributeMap and RuntimeAttributeMap.",
              name));
    }
    return PADDLE_GET_CONST(T, it->second);
  }
  void SetAttr(const std::string& name, const Attribute& v) {
    PADDLE_ENFORCE_EQ(
        HasAttr(name),
        true,
        platform::errors::NotFound(
            "The attribute %s is not found in operator %s", name, Type()));

    attrs_[name] = v;
  }
  const AttributeMap& Attrs() const { return attrs_; }
  const AttributeMap& RuntimeAttrs() const { return runtime_attrs_; }
  void SetRuntimeAttributeMap(const AttributeMap& runtime_attrs) {
    runtime_attrs_ = runtime_attrs;
  }
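  // Illustrative attribute access (a sketch; "axis" is a hypothetical
  // attribute name used only for this example, not one guaranteed to exist
  // on any particular operator):
  //   if (op->HasAttr("axis")) {
  //     int axis = op->Attr<int>("axis");  // checks attrs_ first, then
  //   }                                    // falls back to runtime_attrs_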

  const VariableNameMap& Inputs() const { return inputs_; }
  const VariableNameMap& Outputs() const { return outputs_; }
  VariableNameMap& Inputs() { return inputs_; }
  VariableNameMap& Outputs() { return outputs_; }

  const OpInfo& Info() const {
    PADDLE_ENFORCE_NOT_NULL(
        info_,
        platform::errors::NotFound("OpInfo of operator (%s) is not found.",
                                   type_));
    return *info_;
  }

  bool HasInputs(const std::string& name) const;
  //! Get an input with the argument's name described in `op_proto`.
  std::string Input(const std::string& name) const;
  //! Get an input which has multiple variables.
  const std::vector<std::string>& Inputs(const std::string& name) const;
  //! Get all input variable names.
  std::vector<std::string> InputVars() const;

  bool HasOutputs(const std::string& name) const;
  //! Get an output with the argument's name described in `op_proto`.
  std::string Output(const std::string& name) const;
  //! Get an output which has multiple variables.
  //! TODO: add a vector_view to prevent memory copies.
  const std::vector<std::string>& Outputs(const std::string& name) const;
  //! Get all output variable names.
  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;

  void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }

  virtual void SetIsRuntimeInferShape(bool x) {}

  virtual void RuntimeInferShape(const Scope& scope,
                                 const platform::Place& place,
                                 const RuntimeContext& ctx) const {}

  virtual platform::Place GetExecutionPlace(
      const platform::Place& place) const {
    return place;
  }

  uint64_t Id() const { return id_; }

  void SetId(uint64_t id) { id_ = id; }

 protected:
  std::string type_;
  // NOTE: in case of OpGrad, inputs_ contains:
  // I (Inputs)
  // O (Outputs)
  // OG (Output Gradients)
  VariableNameMap inputs_;

  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
  VariableNameMap outputs_;
  AttributeMap attrs_;
  // NOTE: runtime_attrs_ contains the attributes which are used for
  // dispatching kernels (use_mkldnn, use_cudnn, ...) or for passing additional
  // configuration to special heterogeneous kernels (workspace_size_MB, ...).
  // The attributes in runtime_attrs_ are set by the framework (such as
  // passes), and are not exposed in the Python API.
  AttributeMap runtime_attrs_;

  // OpInfo
  const OpInfo* info_;

  // OpDesc Id
  uint64_t id_ = UINT64_MAX;

  // Whether this operator executes in an Executor.
  bool run_by_executor_{true};

 private:
  void GenerateTemporaryNames();
  void CheckAllInputOutputSet() const;
  virtual void RunImpl(const Scope& scope,
                       const platform::Place& place) const = 0;
};

class ExecutionContext : public phi::KernelContext {
 public:
  ExecutionContext(const OperatorBase& op,
                   const Scope& scope,
                   const platform::DeviceContext& device_context,
                   const RuntimeContext& ctx)
      : op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {}
  virtual ~ExecutionContext() {}

  virtual std::string InputName(const std::string& name) const {
    return op_.Input(name);
  }
  virtual std::vector<std::string> InputNames(const std::string& name) const {
    return op_.Inputs(name);
  }
  virtual std::string OutputName(const std::string& name) const {
    return op_.Output(name);
  }

  virtual std::vector<std::string> OutputNames(const std::string& name) const {
    return op_.Outputs(name);
  }

  virtual bool HasAttr(const std::string& name) const {
    return op_.HasAttr(name);
  }
  virtual const AttributeMap& Attrs() const { return op_.Attrs(); }

  const std::string& Type() const { return op_.Type(); }

  const Scope& scope() const { return scope_; }

  template <typename T>
  inline const T& Attr(const std::string& name) const {
    return PADDLE_GET_CONST(T, GetAttr(name));
  }

  virtual const Attribute& GetAttr(const std::string& name) const {
    auto iter = op_.Attrs().find(name);
    if (iter == op_.Attrs().end()) {
      iter = op_.RuntimeAttrs().find(name);
      PADDLE_ENFORCE_NE(
          iter,
          op_.RuntimeAttrs().end(),
          platform::errors::NotFound("(%s) is not found in AttributeMap and "
                                     "RuntimeAttributeMap of (%s) operator.",
                                     name,
                                     op_.Type()));
    }
    return iter->second;
  }

  virtual bool HasInput(const std::string& name) const;

  virtual bool HasInputs(const std::string& name) const;

  virtual bool HasOutput(const std::string& name) const;

  virtual size_t InputSize(const std::string& name) const {
    return op_.Inputs(name).size();
  }

  virtual size_t OutputSize(const std::string& name) const {
    return op_.Outputs(name).size();
  }

  virtual const Variable* InputVar(const std::string& name) const;

  virtual Variable* OutputVar(const std::string& name) const;

  virtual const std::vector<Variable*> MultiInputVar(
      const std::string& name) const {
    LogVarUsageIfUnusedVarCheckEnabled(name);

    auto it = ctx_.inputs.find(name);
    if (it == ctx_.inputs.end()) {
      return {};
    }
    return {it->second.begin(), it->second.end()};
  }

  virtual std::vector<Variable*> MultiOutputVar(const std::string& name) const {
    auto it = ctx_.outputs.find(name);
    if (it == ctx_.outputs.end()) {
      return {};
    }
    return it->second;
  }

  virtual paddle::small_vector<const std::string*> InNameList() const {
    paddle::small_vector<const std::string*> vec_temp;
    vec_temp.reserve(ctx_.inputs.size());

    for (auto& input : ctx_.inputs) {
      vec_temp.push_back(&input.first);
    }

    return vec_temp;
  }

  template <typename T>
  const T* Input(const std::string& name) const {
    auto* var = InputVar(name);
    return var == nullptr ? nullptr : &var->Get<T>();
  }

  template <typename T>
  T* Output(const std::string& name) const {
    auto var = OutputVar(name);
    return var == nullptr ? nullptr : var->GetMutable<T>();
  }
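  // Illustrative usage inside a kernel's Compute() (a sketch; "X" and "Out"
  // are hypothetical argument names taken from an operator's OpMaker):
  //   auto* x = ctx.Input<phi::DenseTensor>("X");
  //   auto* out = ctx.Output<phi::DenseTensor>("Out");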

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    LogVarUsageIfUnusedVarCheckEnabled(name);

    auto vars = MultiInputVar(name);
    if (vars.size() == 0) {
      return {};
    }
    std::vector<const T*> res;
    res.reserve(vars.size());
    std::transform(vars.begin(),
                   vars.end(),
                   std::back_inserter(res),
                   [&](const Variable* var) -> const T* {
                     return var == nullptr ? nullptr : &var->Get<T>();
                   });
    return res;
  }

  template <typename T>
  std::vector<T*> MultiOutput(const std::string& name) const {
    auto vars = MultiOutputVar(name);

    if (vars.size() == 0) {
      return {};
    }

    std::vector<T*> res;
    res.reserve(vars.size());
    std::transform(vars.begin(),
                   vars.end(),
                   std::back_inserter(res),
                   [&](Variable* var) -> T* {
                     return var == nullptr ? nullptr : var->GetMutable<T>();
                   });

    return res;
  }

  platform::Place GetPlace() const { return device_context_.GetPlace(); }

  template <typename DeviceContextType>
  const DeviceContextType& device_context() const {
    return *reinterpret_cast<const DeviceContextType*>(&device_context_);
  }

  const platform::DeviceContext& device_context() const {
    return device_context_;
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  const inline phi::GPUContext& cuda_device_context() const {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(device_context_.GetPlace()),
                      true,
                      platform::errors::PreconditionNotMet(
                          "Current device context place is not GPUPlace."));
    return *reinterpret_cast<const phi::GPUContext*>(&device_context_);
  }
#endif

  template <typename T, typename DevContext>
  phi::DenseTensor AllocateTmpTensor(const framework::DDim& dim,
                                     const DevContext& dev_ctx) const {
    phi::DenseTensor tmp;
    tmp.Resize(dim);
    dev_ctx.template Alloc<T>(&tmp);
    return tmp;
  }
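  // For example, a kernel might create scratch space like this (a sketch;
  // the shape values, the "dev_ctx" variable, and the phi::CPUContext device
  // context type are illustrative assumptions only):
  //   phi::DenseTensor tmp =
  //       ctx.AllocateTmpTensor<float, phi::CPUContext>({8, 16}, dev_ctx);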

  const RuntimeContext Context() const { return ctx_; }

  std::string DebugString() const { return op_.DebugString(); }
  const OperatorBase& GetOp() const { return op_; }

 private:
  const OperatorBase& op_;
  const Scope& scope_;
  const platform::DeviceContext& device_context_;
  const RuntimeContext& ctx_;
};

// TODO(chenweihang): split the impl based on OpProto or Dygraph if needed.
class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
 public:
  explicit ExecutionArgumentMappingContext(const ExecutionContext& ctx)
      : ctx_(ctx) {}

  bool HasInput(const std::string& name) const override {
    return ctx_.HasInputs(name);
  }

  bool HasOutput(const std::string& name) const override {
    return ctx_.HasOutput(name);
  }

  bool HasAttr(const std::string& name) const override {
    return ctx_.HasAttr(name);
  }

  paddle::any Attr(const std::string& name) const override {
    auto& attr = ctx_.GetAttr(name);
    return GetAttrValue(attr);
  }

  size_t InputSize(const std::string& name) const override {
    return ctx_.MultiInputVar(name).size();
  }

  size_t OutputSize(const std::string& name) const override {
    return ctx_.MultiOutputVar(name).size();
  }

  bool IsDenseTensorInput(const std::string& name) const override {
    const auto* var = ctx_.InputVar(name);
    return var->IsType<phi::DenseTensor>();
  }

  bool IsDenseTensorInputs(const std::string& name) const override {
    auto vars = ctx_.MultiInputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::DenseTensor>();
    });
  }

  bool IsSelectedRowsInputs(const std::string& name) const override {
    auto vars = ctx_.MultiInputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::SelectedRows>();
    });
  }

  bool IsSelectedRowsInput(const std::string& name) const override {
    const auto* var = ctx_.InputVar(name);
    return var->IsType<phi::SelectedRows>();
  }

  bool IsDenseTensorVectorInput(const std::string& name) const override {
    auto vars = ctx_.MultiInputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<framework::LoDTensorArray>();
    });
  }

  bool IsSparseCooTensorInput(const std::string& name) const override {
    const auto* var = ctx_.InputVar(name);
    return var->IsType<phi::SparseCooTensor>();
  }

  bool IsSparseCooTensorOutput(const std::string& name) const override {
    auto vars = ctx_.MultiOutputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::SparseCooTensor>();
    });
  }

  bool IsSparseCsrTensorInput(const std::string& name) const override {
    const auto* var = ctx_.InputVar(name);
    return var->IsType<phi::SparseCsrTensor>();
  }

  bool IsDenseTensorOutput(const std::string& name) const override {
    auto vars = ctx_.MultiOutputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::DenseTensor>();
    });
  }

  bool IsSelectedRowsOutput(const std::string& name) const override {
    auto vars = ctx_.MultiOutputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::SelectedRows>();
    });
  }

  bool IsForInferShape() const override { return false; }

 private:
  const ExecutionContext& ctx_;
};

template <>
const std::vector<const phi::DenseTensor*>
ExecutionContext::MultiInput<phi::DenseTensor>(const std::string& name) const;

template <>
std::vector<phi::DenseTensor*> ExecutionContext::MultiOutput<phi::DenseTensor>(
    const std::string& name) const;

class OpKernelBase {
 public:
  /**
   * ExecutionContext is the only parameter of the Kernel's Run function.
   * Run will get input/output variables, state such as momentum, and
   * device resources such as the CUDA stream and cuBLAS handle from the
   * ExecutionContext. Users should construct it before running the Operator.
   */

  virtual void Compute(const ExecutionContext& context) const = 0;

  virtual ~OpKernelBase() = default;
};

template <typename T>
class OpKernel : public OpKernelBase {
 public:
  using ELEMENT_TYPE = T;
};

class OperatorWithKernel : public OperatorBase {
 public:
  using OpKernelFunc = std::function<void(const ExecutionContext&)>;
  using OpKernelMap =
      std::unordered_map<OpKernelType, OpKernelFunc, OpKernelType::Hash>;

  OperatorWithKernel(const std::string& type,
                     const VariableNameMap& inputs,
                     const VariableNameMap& outputs,
                     const AttributeMap& attrs);

  virtual ~OperatorWithKernel();

  static paddle::flat_hash_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static paddle::flat_hash_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
  }

  bool SupportGPU() const override;

  bool SupportXPU() const override;

  bool SupportsMKLDNN(phi::DataType data_type) const;

  bool SupportsCUDNN(phi::DataType data_type) const;

  bool SupportsKernelType(const OpKernelType& kernel_type,
                          const ExecutionContext& exe_ctx) const;

  bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
                       phi::DataType data_type) const;

  bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
                       proto::VarType::Type data_type) const;

  bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
                      phi::DataType data_type) const;

  bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
                      proto::VarType::Type data_type) const;

  virtual void InferShape(InferShapeContext* ctx) const;

  void SetIsRuntimeInferShape(bool x) override {
    all_kernels_must_compute_runtime_shape_ = x;
  }

  void RuntimeInferShape(const Scope& scope,
                         const platform::Place& place,
                         const RuntimeContext& ctx) const override;

  proto::VarType::Type IndicateVarDataType(const ExecutionContext& ctx,
                                           const std::string& name) const;

  proto::VarType::Type IndicateOrPromoteVarDataTypes(
      const ExecutionContext& ctx,
      const std::string& name1,
      const std::string& name2) const;

  virtual phi::KernelKey GetExpectedKernelType(
      const ExecutionContext& ctx) const;

  // Changed to public so that in dygraph mode we can call it to check
  // whether we need to transform data.
  virtual phi::KernelKey GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const phi::KernelKey& expected_kernel_type) const;

  platform::Place GetExecutionPlace(
      const platform::Place& platform) const override {
    return kernel_type_->place_;
  }

  /* member functions for adapting to phi lib */
  /** In the phi::DenseTensor calculation library, the new Kernel adopts a
   * clearer and more streamlined design. The arguments of the Kernel and the
   * input and output arguments registered in the original OpMaker do not
   * match in some cases, so we use a map to record the arguments required by
   * the kernel. When selecting a Kernel during Op execution, the arguments of
   * the original Op are selected according to the arguments returned by
   * GetExpectedPhiKernelArgs.
   */
  phi::KernelSignature GetExpectedPhiKernelArgs(
      const ExecutionContext& ctx) const;
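  // For example (an illustrative sketch, not taken from this header), for a
  // scale-like op the returned signature could map the OpMaker arguments onto
  // a phi kernel roughly as:
  //   KernelSignature("scale", {"X"}, {"scale", "bias", "bias_after_scale"},
  //                   {"Out"})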

  /* member functions for adapting to phi lib */
  phi::KernelKey ChoosePhiKernel(const ExecutionContext& ctx) const;

  void ChooseKernel(const ExecutionContext& ctx) const;

  void BuildPhiKernelContext(const RuntimeContext& ctx,
                             platform::DeviceContext* dev_ctx,
                             phi::KernelContext* phi_kernel_context) const;

  phi::KernelSignature* PhiKernelSignature() const {
    return kernel_signature_.get();
  }

  phi::Kernel* PhiKernel() const { return phi_kernel_.get(); }

  void ResetPhiKernel(phi::Kernel* kernel) const {
    return phi_kernel_.reset(kernel);
  }

  const OpKernelType* kernel_type() const { return kernel_type_.get(); }
  const OpKernelFunc* kernel_func() const { return kernel_func_.get(); }

  void ResetKernelType(OpKernelType* kernel_type) {
    kernel_type_.reset(kernel_type);
  }

  bool DnnFallback() const { return dnn_fallback_; }

  void SetDnnFallback(bool dnn_fallback) const { dnn_fallback_ = dnn_fallback; }

 private:
  void RunImpl(const Scope& scope, const platform::Place& place) const final;
  void RunImpl(const Scope& scope,
               const platform::Place& place,
               RuntimeContext* runtime_ctx) const;

  /**
   * Transfer data from the scope to a transferred scope. If no data needs to
   * be transferred, it returns nullptr.
   *
   * transfered_inplace_vars is an output vector.
   */
  Scope* PrepareData(const Scope& scope,
                     const phi::KernelKey& expected_kernel_key,
                     std::vector<std::string>* transfered_inplace_vars,
                     RuntimeContext* ctx,
                     const phi::Place& place) const;

  void CheckWhetherPreparePhiData(const VariableNameMap& innames,
                                  const VariableNameMap& outnames,
                                  const Scope& scope) const;

  void TransferInplaceVarsBack(const Scope& scope,
                               const std::vector<std::string>& inplace_vars,
                               const Scope& exec_scope) const;

  OpKernelType InnerGetExpectedKernelType(const ExecutionContext& ctx) const;

  void HandleComplexGradToRealGrad(const Scope& scope,
                                   RuntimeContext* ctx) const;

  /* Inner assist methods */
  // Indicate the kernel DataType by input data.
  // By default, all input data must have the same type.
  proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
  // used for IndicateDataType
  void ParseInputDataType(const Variable* vars,
                          const std::string& name,
                          proto::VarType::Type* data_type) const;
  void ParseMultiInputDataType(const std::vector<Variable*>& vars,
                               const std::string& name,
                               proto::VarType::Type* data_type) const;
  // used for IndicateOrPromoteVarDataTypes
  phi::DenseTensor* GetTensorFormInputSafely(const ExecutionContext& ctx,
                                             const std::string& name) const;

 protected:
  mutable std::unique_ptr<OpKernelType> kernel_type_;
  mutable std::unique_ptr<OpKernelFunc> kernel_func_;
  mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
  mutable const Scope* pre_scope_ = nullptr;
  mutable bool need_prepare_data_ = true;
  mutable bool need_prepare_phi_data_ = false;
  mutable bool enable_cache_runtime_context_ = false;
  mutable bool all_kernels_must_compute_runtime_shape_ = false;
  mutable std::mutex cache_update_mutex_;
  mutable bool enable_cache_transfer_scope_ = false;
  // NOTE(jiahongyu): Whether to fall back to a plain kernel after calling
  // GetExpectedKernelType; this bool flag is used to avoid hard-coding the
  // mkldnn and cudnn choices.
  mutable bool dnn_fallback_ = false;
  // NOTE(chenweihang): Similar op members are used to adapt to
  // the new phi kernels; if there is a better design in the future,
  // we may polish the implementation here.
  mutable bool run_phi_kernel_ = false;
  mutable bool run_kp_kernel = false;
  mutable std::unique_ptr<phi::KernelSignature> kernel_signature_;
  mutable std::unique_ptr<phi::Kernel> phi_kernel_;
  mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;

 private:
  struct CacheImpl;
  mutable std::unique_ptr<CacheImpl> impl_;
};

extern bool OpSupportGPU(const std::string& op_type);

}  // namespace framework
}  // namespace paddle