// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <atomic>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/operator.h"
22
#include "paddle/fluid/framework/variable_helper.h"
23
#include "paddle/fluid/platform/device_event_base.h"
24
#include "paddle/fluid/platform/event.h"
25
#include "paddle/phi/core/utils/rw_lock.h"
26

L
Leo Chen 已提交
27 28 29
#define SCOPE_VARS_READER_LOCK AutoRDLock auto_lock(&vars_lock_);
#define SCOPE_VARS_WRITER_LOCK AutoWRLock auto_lock(&vars_lock_);
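
// Illustrative use (a minimal sketch): these macros are intended for the top
// of a member function of a class declaring a `mutable RWLock vars_lock_`
// member, holding the lock for the duration of the call, e.g.
//
//   int SomeScope::GetId(const std::string& name) const {
//     SCOPE_VARS_READER_LOCK  // expands to AutoRDLock auto_lock(&vars_lock_);
//     return name2id_.at(name);
//   }
//
// SomeScope and name2id_ here are hypothetical, for illustration only.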

namespace paddle {
namespace framework {

using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;

constexpr int kEmptyVarIndex = 0;

class InterpretercoreInferShapeContext : public InferShapeContext {
 public:
  InterpretercoreInferShapeContext(const OperatorBase& op,
                                   const RuntimeContext& ctx);

  bool HasInput(const std::string& name) const override;

  bool HasOutput(const std::string& name) const override;

  bool HasAttr(const std::string& name) const override;

  bool HasInputs(const std::string& name) const override;

  bool HasOutputs(const std::string& name,
                  bool allow_null = false) const override;

  AttrReader Attrs() const override;

  std::vector<std::string> Inputs(const std::string& name) const override;

  std::vector<std::string> Outputs(const std::string& name) const override;

  std::string GetInputNameByIdx(size_t idx) const override;

  std::string GetOutputNameByIdx(size_t idx) const override;

  void ShareDim(const std::string& in,
                const std::string& out,
                size_t i = 0,
                size_t j = 0) override;

  void ShareAllLoD(const std::string& in,
                   const std::string& out) const override;

  void ShareLoD(const std::string& in,
                const std::string& out,
                size_t i = 0,
                size_t j = 0) const override;

  int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override;

  void SetLoDLevel(const std::string& out,
                   int32_t lod_level,
                   size_t j = 0) const override;

  bool IsRuntime() const override;

  bool IsRunMKLDNNKernel() const override;

  // TODO(paddle-dev): Can this be a template?
  paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
  GetInputVarPtrs(const std::string& name) const override;

  paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
  GetOutputVarPtrs(const std::string& name) const override;

  DDim GetInputDim(const std::string& name) const override;

  std::vector<DDim> GetInputsDim(const std::string& name) const override;

  proto::VarType::Type GetInputVarType(const std::string& name) const override;

  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string& name) const override;

  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string& name) const override;

  void SetOutputDim(const std::string& name, const DDim& dim) override;

  void SetOutputsDim(const std::string& name,
                     const std::vector<DDim>& dims) override;

  const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override;

  const phi::KernelSignature* GetPhiDefaultKernelSignature() const override;

  void SetSkipLoD(bool skip);

 protected:
  DDim GetDim(Variable* var) const;

  std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const;

  std::vector<DDim> GetRepeatedDims(const std::string& name) const override;

  void SetDim(Variable* var, const DDim& dim);

  void SetDims(const std::vector<Variable*>& vars,
               const std::vector<DDim>& dims);

  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override;

  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<Variable*>& vars) const;

  proto::VarType::Type GetVarType(Variable* var) const;

 private:
  const std::vector<Variable*>& InputVars(const std::string& name) const;

  const std::vector<Variable*>& OutputVars(const std::string& name) const;

  const OperatorBase& op_;
  const RuntimeContext& ctx_;
  bool can_skip_lod_;
};

struct OpKernelFunc {
  OpKernelComputeFunc compute_func_;
};

struct VariableMetaInfo {
  int var_ref_count_{0};
  framework::VarDesc* var_desc_{nullptr};
  bool sikp_inplace_{false};

  VariableMetaInfo() {}
  VariableMetaInfo(int var_ref_count, framework::VarDesc* var_desc)
      : var_ref_count_(var_ref_count), var_desc_(var_desc) {}
};

class VariableScope {
 public:
  explicit VariableScope(Scope* scope);

  Scope* GetMutableScope() const;

  Scope* GetMutableLocalScope() const;

  void SetScope(Scope* scope);

  void SetLocalScope(Scope* local_scope);

  ~VariableScope();

  // Get variable id by name, return -1 if not found
  int GetIdByName(const std::string& name) const;

  // Get variable name by id, return "" if not found
  std::string GetNameById(int id) const;

  bool HasVar(const std::string& name) const;

  int VarId(const std::string& name) const;

  size_t VarSize() const;

  void AddVar(const std::string& name, VarDesc* var_desc);

  Variable* VarRef(int id) const;

  void SetVarDesc(const std::string& name, framework::VarDesc* var_desc);

  paddle::framework::VarDesc* VarDesc(const std::string& name) const;

  paddle::framework::VarDesc* VarDesc(int id) const;

  void CheckExist(int id) const;

  void CheckExist(const std::string& name) const;

  std::vector<VariableMetaInfo>& MutableVecMetaInfo() { return vec_meta_info_; }

  const std::vector<VariableMetaInfo>& VecMetaInfo() const {
    return vec_meta_info_;
  }

  const std::vector<std::pair<std::string, int>>& DataTransferAddedVars()
      const {
    return data_transfer_added_vars_;
  }

  std::vector<std::pair<std::string, int>>& MutableDataTransferAddedVars() {
    return data_transfer_added_vars_;
  }

  std::vector<Variable*>& MutableVarList() { return var_list_; }

  void SetVarSikpInplace(const std::string& name, bool skip);

  bool GetVarSikpInplace(int id) const;

 private:
  // Not owned. Ideally this should be removed, since all vars should be
  // accessed through Scope rather than VariableScope.
  std::vector<Variable*> var_list_;

  std::map<std::string, int> name2id_;
  std::vector<VariableMetaInfo> vec_meta_info_;

  Scope* scope_{nullptr};
  // TODO(zhiqiu): find a better way to support local scope.
  Scope* local_scope_{nullptr};
  // mutable RWLock vars_lock_;

  // var_name -> var_type
  std::vector<std::pair<std::string, int>> data_transfer_added_vars_;
};
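
// Illustrative usage (a minimal sketch; assumes a VariableScope `var_scope`
// that was populated during the build phase):
//
//   int id = var_scope.GetIdByName("x");  // -1 if "x" is absent
//   if (id >= 0) {
//     Variable* var = var_scope.VarRef(id);
//     paddle::framework::VarDesc* desc = var_scope.VarDesc(id);
//     ...
//   }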

class NextInstructionList {
 public:
  void AddDirectRun(size_t id) { direct_run_.push_back(id); }

  void ADDEventRun(size_t id) { event_wait_run_.push_back(id); }

  void AddSyncRun(size_t id) { synchronize_run_.push_back(id); }

  const std::vector<size_t>& DirectRunIds() const { return direct_run_; }

  const std::vector<size_t>& EventRunIds() const { return event_wait_run_; }

  const std::vector<size_t>& SyncRunIds() const { return synchronize_run_; }

 private:
  std::vector<size_t> direct_run_;
  std::vector<size_t> event_wait_run_;
  std::vector<size_t> synchronize_run_;
};
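
// Illustrative consumption (a minimal sketch, not part of this header): a
// scheduler can launch each successor according to the list it sits in:
//
//   for (size_t next : instr.NextInstructions().DirectRunIds())
//     RunDirectly(next);         // hypothetical: same thread, no waiting
//   for (size_t next : instr.NextInstructions().EventRunIds())
//     RunAfterEventWait(next);   // hypothetical: wait on device events first
//   for (size_t next : instr.NextInstructions().SyncRunIds())
//     RunAfterDeviceSync(next);  // hypothetical: full device synchronization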

struct EventInter {
  explicit EventInter(size_t var_id,
                      std::shared_ptr<platform::DeviceEvent> event,
                      platform::DeviceType waiter_type)
      : var_id_(var_id), event_(event), waiter_type_(waiter_type) {}
  size_t var_id_;
  std::shared_ptr<platform::DeviceEvent> event_;
  platform::DeviceType waiter_type_;
};

enum class OpFuncType {
  kQueueSync = 0,   // CPU kernel; blocks the host
  kQueueAsync = 1,  // GPU/XPU kernel, or d2h, h2d, send, recv, broadcast
};
class RuntimeInferShapeContext;
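
// Illustrative dispatch (a minimal sketch, not part of this header): an
// executor can key its work queues off OpFuncType, e.g.
//
//   void Dispatch(const Instruction& instr) {
//     if (instr.KernelType() == OpFuncType::kQueueSync) {
//       RunOnHostQueue(instr);    // hypothetical host-blocking queue
//     } else {
//       RunOnDeviceQueue(instr);  // hypothetical async device queue
//     }
//   }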

struct OpFuncNode {
  // TODO(zhiqiu): Better make it a unique_ptr
  std::shared_ptr<OperatorBase> operator_base_;
  std::map<std::string, std::vector<int>> input_index;
  std::map<std::string, std::vector<int>> output_index;
  std::unordered_set<int> no_data_transform_index;

  std::map<int, int> inplace_back_map;

  OpKernelComputeFunc kernel_func_;
  platform::DeviceContext* dev_ctx_;  // not owned

  // for phi kernels
  phi::Kernel* phi_kernel_{nullptr};  // not owned

  OpFuncType type_;
};

class Instruction {
 public:
  Instruction(size_t id,
              OpFuncNode&& op_func_node,
              const platform::DeviceContext& dev_ctx);

  size_t Id() const;

  const std::map<std::string, std::vector<int>>& Inputs() const;

  const std::map<std::string, std::vector<int>>& Outputs() const;

  const std::unordered_set<int>& NoDataTransformVars() const;

  OpKernelComputeFunc KernelFunc() const;

  phi::Kernel* PhiKernel() const;

  OpFuncType KernelType() const;

  const std::map<int, int>& InplaceBackMap() const;

  OperatorBase* OpBase() const;

  NextInstructionList& NextInstructions();

  const NextInstructionList& NextInstructions() const;

  void AddGCCheckVar(size_t id);

  const std::vector<size_t>& GCCheckVars() const;

  void ResetContext(const VariableValueMap& in_vars,
                    const VariableValueMap& out_vars);

  void ResetContextWithScope(const VariableValueMap& in_vars,
                             const VariableValueMap& out_vars,
                             const framework::Scope& scope);

  std::shared_ptr<RuntimeContext> InnerRuntimeContext() const;

  std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext()
      const;

  std::shared_ptr<ExecutionContext> InnerExecutionContext() const;

  const platform::DeviceContext& DeviceContext() const;

  const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const;

  void AddInplace(Variable* in, Variable* out);

  void ClearInplace();

  const std::vector<EventInter>& InputEvents() const;

  const std::vector<EventInter>& OutputEvents() const;

  void AddInputEvent(size_t var_id,
                     std::shared_ptr<platform::DeviceEvent> event,
                     platform::DeviceType waiter_type);

  void AddOutputEvent(size_t var_id,
                      std::shared_ptr<platform::DeviceEvent> event,
                      platform::DeviceType waiter_type);

 private:
  size_t id_;
  OpFuncNode op_func_node_;
  const platform::DeviceContext& dev_ctx_;  // not owned

  std::shared_ptr<RuntimeContext> runtime_ctx_;
  std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_;
  std::shared_ptr<ExecutionContext> execution_ctx_;

  std::vector<size_t> gc_check_vars_;

  NextInstructionList next_instruction_;

  std::vector<EventInter> intput_events_;
  std::vector<EventInter> output_events_;

  std::vector<std::pair<Variable*, Variable*>> vec_inplace_in_to_out_;
};

namespace interpreter {
static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h";
static constexpr char kFetchVarName[] = "fetch";

static bool IsMemcpyH2D(const Instruction& instr) {
  return instr.OpBase()->Type() == kMemcpyH2D;
}

static bool IsMemcpyD2H(const Instruction& instr) {
  return instr.OpBase()->Type() == kMemcpyD2H;
}

static bool IsCpuOp(const Instruction& instr) {
  return platform::is_cpu_place(instr.DeviceContext().GetPlace());
}

// Whether the place is a supported heterogeneous (device) place.
static bool IsSupportedHetePlace(const phi::Place& place) {
  return platform::is_gpu_place(place) || platform::is_npu_place(place) ||
         platform::is_xpu_place(place) || platform::is_ipu_place(place) ||
         platform::is_custom_place(place);
}
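
// Illustrative use (a minimal sketch): these predicates can tell the executor
// where cross-stream synchronization may be required, e.g.
//
//   if (IsSupportedHetePlace(place) && !IsCpuOp(instr)) {
//     // record a DeviceEvent after the producer and wait on it before the
//     // consumer (see EventInter / AddInputEvent / AddOutputEvent above)
//   }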

// static_ref_ is the number of last-live ops of a variable, computed
// statically after `build` creates the Instructions. dynamic_ref_ is the
// runtime counterpart, which is decremented by one after the execution of
// each op in the last-live-ops list. var_ is the associated variable.

// dynamic_ref_ is initialized to static_ref_ and counts down during
// InterpreterCore's execution; after the run finishes, all dynamic_ref_
// values are reset, i.e., dynamic_ref_ = static_ref_. See ResetAtomicGuard
// for details.
class VarRefInfo {
 public:
  explicit VarRefInfo(size_t ref, Variable* var)
      : static_ref_(ref), dynamic_ref_(ref), var_(var) {}
  size_t DynamicRef() { return dynamic_ref_; }
  Variable* Var() { return var_; }
  void ResetDynamicRef() {
    if (static_ref_ != 1) {
      dynamic_ref_ = static_ref_;
    }
  }
  bool CheckAndDecrease() {
    return static_ref_ == 1 || (dynamic_ref_.fetch_sub(1) == 1);
  }

 private:
  const size_t static_ref_;
  std::atomic<size_t> dynamic_ref_;
  Variable* var_;
};
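
// Illustrative GC step (a minimal sketch; FreeVariable is a hypothetical
// deallocation hook): after an op finishes, decrement each checked variable
// and release it once the last reference is gone:
//
//   for (auto& ref : var_refs) {
//     if (ref->CheckAndDecrease()) {
//       FreeVariable(ref->Var());
//     }
//   }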

// static_dep_ is the number of dependencies (ops that must run before it) of
// each op, computed statically. dynamic_dep_ is the runtime counterpart,
// which is decremented by one after the execution of each dependency op.

// dynamic_dep_ is initialized to static_dep_ and counts down during
// InterpreterCore's execution; after the run finishes, all dynamic_dep_
// values are reset, i.e., dynamic_dep_ = static_dep_. See ResetAtomicGuard
// for details.

class OpDepInfo {
 public:
  explicit OpDepInfo(size_t dep) : static_dep_(dep), dynamic_dep_(dep) {}
  size_t DynamicDep() { return dynamic_dep_; }
  void ResetDynamicDep() {
    if (static_dep_ != 1) {
      dynamic_dep_ = static_dep_;
    }
  }
  bool CheckAndDecrease() {
    return static_dep_ == 1 || (dynamic_dep_.fetch_sub(1) == 1);
  }

 private:
  const size_t static_dep_;
  std::atomic<size_t> dynamic_dep_;
};

class ResetAtomicGuard {
 public:
  ResetAtomicGuard(std::vector<std::shared_ptr<OpDepInfo>>* deps,
                   std::vector<std::shared_ptr<VarRefInfo>>* refs)
      : deps_(deps), refs_(refs) {}

  ~ResetAtomicGuard() {
    VLOG(10) << "Reset DynamicDep";
    for (auto&& dep : *deps_) {
      dep->ResetDynamicDep();
    }
    VLOG(10) << "Reset DynamicRef";
    for (auto&& ref : *refs_) {
      ref->ResetDynamicRef();
    }
  }

 private:
  std::vector<std::shared_ptr<OpDepInfo>>* deps_;
  std::vector<std::shared_ptr<VarRefInfo>>* refs_;
};
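
// Illustrative usage (a minimal sketch; op_deps_ and var_refs_ are assumed
// executor members of type std::vector<std::shared_ptr<OpDepInfo>> and
// std::vector<std::shared_ptr<VarRefInfo>>):
//
//   {
//     ResetAtomicGuard guard(&op_deps_, &var_refs_);
//     RunInstructionList();  // hypothetical one-pass execution
//   }  // on scope exit, dynamic counters are restored to their static values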

}  // namespace interpreter

}  // namespace framework
}  // namespace paddle