// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <atomic>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/event.h"
#include "paddle/phi/core/utils/rw_lock.h"

#define SCOPE_VARS_READER_LOCK AutoRDLock auto_lock(&vars_lock_);
#define SCOPE_VARS_WRITER_LOCK AutoWRLock auto_lock(&vars_lock_);
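
// NOTE: a minimal usage sketch, not from this file. The macros above expand to
// an RAII read/write lock named `auto_lock` over a member called `vars_lock_`,
// so they only make sense at the top of a member function of a class that
// declares such a lock (the method body below is hypothetical):
//
//   std::string VariableScope::GetNameById(int id) const {
//     SCOPE_VARS_READER_LOCK  // released at end of scope
//     ...
//   }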

namespace paddle {
namespace framework {

using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;

constexpr int kEmptyVarIndex = 0;

// stream types
constexpr const char* kCustomStream = "CustomStream";
constexpr const char* kDefaultStream = "DefaultStream";
constexpr const char* kD2HStream = "D2HStream";
constexpr const char* kH2DStream = "H2DStream";

enum class Priority { kLowest, kNormal };

class InterpretercoreInferShapeContext : public InferShapeContext {
 public:
  InterpretercoreInferShapeContext(const OperatorBase& op,
                                   const RuntimeContext& ctx);

  bool HasInput(const std::string& name) const override;

  bool HasOutput(const std::string& name) const override;

  bool HasAttr(const std::string& name) const override;

  bool HasInputs(const std::string& name) const override;

  bool HasOutputs(const std::string& name,
                  bool allow_null = false) const override;

  AttrReader Attrs() const override;

  std::vector<std::string> Inputs(const std::string& name) const override;

  std::vector<std::string> Outputs(const std::string& name) const override;

  std::string GetInputNameByIdx(size_t idx) const override;

  std::string GetOutputNameByIdx(size_t idx) const override;

  void ShareDim(const std::string& in,
                const std::string& out,
                size_t i = 0,
                size_t j = 0) override;

  void ShareAllLoD(const std::string& in,
                   const std::string& out) const override;

  void ShareLoD(const std::string& in,
                const std::string& out,
                size_t i = 0,
                size_t j = 0) const override;

  int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override;

  void SetLoDLevel(const std::string& out,
                   int32_t lod_level,
                   size_t j = 0) const override;

  bool IsRuntime() const override;

  bool IsRunMKLDNNKernel() const override;

  // TODO(paddle-dev): Can this be a template?
  paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
  GetInputVarPtrs(const std::string& name) const override;

  paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
  GetOutputVarPtrs(const std::string& name) const override;

  DDim GetInputDim(const std::string& name) const override;

  std::vector<DDim> GetInputsDim(const std::string& name) const override;

  proto::VarType::Type GetInputVarType(const std::string& name) const override;

  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string& name) const override;

  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string& name) const override;

  void SetOutputDim(const std::string& name, const DDim& dim) override;

  void SetOutputsDim(const std::string& name,
                     const std::vector<DDim>& dims) override;

  const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override;

  const phi::KernelSignature* GetPhiDefaultKernelSignature() const override;

  void SetSkipLoD(bool skip);

 protected:
  DDim GetDim(Variable* var) const;

  std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const;

  std::vector<DDim> GetRepeatedDims(const std::string& name) const override;

  void SetDim(Variable* var, const DDim& dim);

  void SetDims(const std::vector<Variable*>& vars,
               const std::vector<DDim>& dims);

  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override;

  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<Variable*>& vars) const;

  proto::VarType::Type GetVarType(Variable* var) const;

 private:
  const std::vector<Variable*>& InputVars(const std::string& name) const;

  const std::vector<Variable*>& OutputVars(const std::string& name) const;

  const OperatorBase& op_;
  const RuntimeContext& ctx_;
  bool can_skip_lod_;
};
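
// A minimal sketch of how this context is typically constructed, assuming the
// usual fluid pattern of resolving a RuntimeContext first; the surrounding
// code here is illustrative, not part of this header:
//
//   RuntimeContext runtime_ctx(op.Inputs(), op.Outputs(), scope);
//   InterpretercoreInferShapeContext infer_shape_ctx(op, runtime_ctx);
//   op.Info().infer_shape_(&infer_shape_ctx);  // run the op's InferShape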

struct OpKernelFunc {
  OpKernelComputeFunc compute_func_;
};

struct VariableMetaInfo {
  int var_ref_count_{0};
  framework::VarDesc* var_desc_{nullptr};
  bool skip_inplace_{false};

  VariableMetaInfo() {}
  VariableMetaInfo(int var_ref_count, framework::VarDesc* var_desc)
      : var_ref_count_(var_ref_count), var_desc_(var_desc) {}
};

class VariableScope {
 public:
  explicit VariableScope(Scope* scope);

  Scope* GetMutableScope() const;

  Scope* GetMutableLocalScope() const;

  void SetScope(Scope* scope);

  void SetLocalScope(Scope* local_scope);

  ~VariableScope();

  // Get variable id by name, return -1 if not found
  int GetIdByName(const std::string& name) const;

  // Get variable name by id, return "" if not found
  std::string GetNameById(int id) const;

  bool HasVar(const std::string& name) const;

  int VarId(const std::string& name) const;

  size_t VarSize() const;

  void AddVar(const std::string& name, VarDesc* var_desc);

  Variable* VarRef(int id) const;

  void SetVarDesc(const std::string& name, framework::VarDesc* var_desc);

  paddle::framework::VarDesc* VarDesc(const std::string& name) const;

  paddle::framework::VarDesc* VarDesc(int id) const;

  void CheckExist(int id) const;

  void CheckExist(const std::string& name) const;

  std::vector<VariableMetaInfo>& MutableVecMetaInfo() { return vec_meta_info_; }

  const std::vector<VariableMetaInfo>& VecMetaInfo() const {
    return vec_meta_info_;
  }

  const std::vector<std::pair<std::string, int>>& DataTransferAddedVars()
      const {
    return data_transfer_added_vars_;
  }

  std::vector<std::pair<std::string, int>>& MutableDataTransferAddedVars() {
    return data_transfer_added_vars_;
  }

  std::vector<Variable*>& MutableVarList() { return var_list_; }

  void SetVarSkipInplace(const std::string& name, bool skip);

  bool GetVarSkipInplace(int id) const;

 private:
  // not owned; ideally this should be removed, since all variables should be
  // accessed through Scope rather than VariableScope
  std::vector<Variable*> var_list_;

  std::map<std::string, int> name2id_;
  std::vector<VariableMetaInfo> vec_meta_info_;

  Scope* scope_{nullptr};
  // TODO(zhiqiu): find a better way to support local scope.
  Scope* local_scope_{nullptr};
  // mutable RWLock vars_lock_;

  // var_name -> var_type
  std::vector<std::pair<std::string, int>> data_transfer_added_vars_;
};
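
// A minimal usage sketch (variable names hypothetical): ids index into
// var_list_/vec_meta_info_, and name lookups fall back to the sentinel values
// documented above:
//
//   VariableScope var_scope(scope);
//   if (!var_scope.HasVar("x")) {
//     var_scope.AddVar("x", x_desc);       // x_desc: a VarDesc* for "x"
//   }
//   int id = var_scope.GetIdByName("x");   // -1 if the name is unknown
//   Variable* var = var_scope.VarRef(id);  // id must be valid, see CheckExist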

class NextInstructionList {
 public:
  void AddDirectRun(size_t id) { direct_run_.push_back(id); }

  void AddEventRun(size_t id) { event_wait_run_.push_back(id); }

  void AddSyncRun(size_t id) { synchronize_run_.push_back(id); }

  const std::vector<size_t>& DirectRunIds() const { return direct_run_; }

  const std::vector<size_t>& EventRunIds() const { return event_wait_run_; }

  const std::vector<size_t>& SyncRunIds() const { return synchronize_run_; }

 private:
  std::vector<size_t> direct_run_;
  std::vector<size_t> event_wait_run_;
  std::vector<size_t> synchronize_run_;
};
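
// Illustrative only: how the build phase might classify each successor of an
// instruction (the predicates are hypothetical; the real classification
// depends on device placement and event requirements):
//
//   auto& next = instructions[cur_id].NextInstructions();
//   if (CanRunInline(cur_id, next_id)) {
//     next.AddDirectRun(next_id);  // same stream, run directly
//   } else if (NeedsEventWait(cur_id, next_id)) {
//     next.AddEventRun(next_id);   // cross-stream, wait on a device event
//   } else {
//     next.AddSyncRun(next_id);    // requires full synchronization
//   }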

struct EventInter {
  explicit EventInter(size_t var_id,
                      std::shared_ptr<platform::DeviceEvent> event,
                      platform::DeviceType waiter_type)
      : var_id_(var_id), event_(event), waiter_type_(waiter_type) {}
  size_t var_id_;
  std::shared_ptr<platform::DeviceEvent> event_;
  platform::DeviceType waiter_type_;
};
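
// Illustrative only: before launching an instruction, its input events can be
// waited on so that cross-stream producers have finished (the loop is
// hypothetical; DeviceEvent exposes record/wait primitives):
//
//   for (const EventInter& e : instr.InputEvents()) {
//     e.event_->Wait(e.waiter_type_, &instr.DeviceContext());
//   }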

enum class OpFuncType {
  kQueueSync = 0,   // CPU kernel, blocks the host
  kQueueAsync = 1,  // GPU/XPU kernel, or d2h, h2d, send, recv, broadcast
};
class RuntimeInferShapeContext;
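
// Illustrative only: the two queue kinds are a scheduling hint. An interpreter
// loop might dispatch on them roughly as follows (helper names hypothetical):
//
//   if (instr.KernelType() == OpFuncType::kQueueSync) {
//     RunInHostThread(instr);   // blocking CPU work
//   } else {
//     RunInDeviceQueue(instr);  // async GPU/XPU kernel or memcpy
//   }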

struct OpFuncNode {
  // TODO(zhiqiu): Better make it unique_ptr
  std::shared_ptr<OperatorBase> operator_base_;
  std::string execution_stream_{kDefaultStream};
  std::map<std::string, std::vector<int>> input_index;
  std::map<std::string, std::vector<int>> output_index;
  std::unordered_set<int> no_data_transform_index;

  std::map<int, int> inplace_back_map;

  OpKernelComputeFunc kernel_func_;
  platform::DeviceContext* dev_ctx_;  // not owned

  // for phi kernels
  phi::Kernel* phi_kernel_{nullptr};  // not owned

  OpFuncType type_;
};
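
// A minimal build-phase sketch (field values hypothetical): an OpFuncNode is
// filled in while building the instruction list and then moved into an
// Instruction, declared below:
//
//   OpFuncNode op_func_node;
//   op_func_node.operator_base_ = op;             // shared OperatorBase
//   op_func_node.dev_ctx_ = dev_ctx;              // chosen device context
//   op_func_node.type_ = OpFuncType::kQueueSync;  // host-blocking kernel
//   Instruction instr(
//       id, std::move(op_func_node), *dev_ctx, Priority::kNormal);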

class Instruction {
 public:
  Instruction(size_t id,
              OpFuncNode&& op_func_node,
              const platform::DeviceContext& dev_ctx,
              const Priority priority);

  size_t Id() const;

  const std::map<std::string, std::vector<int>>& Inputs() const;

  const std::map<std::string, std::vector<int>>& Outputs() const;

  const std::unordered_set<int>& NoDataTransformVars() const;

  OpKernelComputeFunc KernelFunc() const;

  phi::Kernel* PhiKernel() const;

  OpFuncType KernelType() const;

  const std::map<int, int>& InplaceBackMap() const;

  OperatorBase* OpBase() const;

  NextInstructionList& NextInstructions();

  const NextInstructionList& NextInstructions() const;

  void AddGCCheckVar(size_t id);

  const std::vector<size_t>& GCCheckVars() const;

  void ResetContext(const VariableValueMap& in_vars,
                    const VariableValueMap& out_vars);

  void ResetContextWithScope(const VariableValueMap& in_vars,
                             const VariableValueMap& out_vars,
                             const framework::Scope& scope);

  std::shared_ptr<RuntimeContext> InnerRuntimeContext() const;

  std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext()
      const;

  std::shared_ptr<ExecutionContext> InnerExecutionContext() const;

  const platform::DeviceContext& DeviceContext() const;

  const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const;

  void AddInplace(Variable* in, Variable* out);

  void ClearInplace();

  const std::vector<EventInter>& InputEvents() const;

  const std::vector<EventInter>& OutputEvents() const;

  void AddInputEvent(size_t var_id,
                     std::shared_ptr<platform::DeviceEvent> event,
                     platform::DeviceType waiter_type);

  void AddOutputEvent(size_t var_id,
                      std::shared_ptr<platform::DeviceEvent> event,
                      platform::DeviceType waiter_type);

  Priority GetPriority() const { return priority_; }

 private:
  size_t id_;
  OpFuncNode op_func_node_;
  const platform::DeviceContext& dev_ctx_;  // not owned
  const Priority priority_;

  std::shared_ptr<RuntimeContext> runtime_ctx_;
  std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_;
  std::shared_ptr<ExecutionContext> execution_ctx_;

  std::vector<size_t> gc_check_vars_;

  NextInstructionList next_instruction_;

  std::vector<EventInter> input_events_;
  std::vector<EventInter> output_events_;

  std::vector<std::pair<Variable*, Variable*>> vec_inplace_in_to_out_;
};

namespace interpreter {
static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h";
static constexpr char kFetchVarName[] = "fetch";

// static_ref_ is the number of last-live ops of the variable, computed
// statically after the Instructions are built. dynamic_ref_ is the runtime
// counterpart, decreased by one after the execution of each op in the
// last-live-ops list. var_ is the related variable.

// dynamic_ref_ is initialized to static_ref_ and is decreased during the
// InterpreterCore's execution. After the InterpreterCore has run, every
// dynamic_ref_ is reset, i.e., dynamic_ref_ = static_ref_; see
// ResetAtomicGuard for details.
class VarRefInfo {
 public:
  explicit VarRefInfo(size_t ref, Variable* var)
      : static_ref_(ref), dynamic_ref_(ref), var_(var) {}
  size_t DynamicRef() { return dynamic_ref_; }
  Variable* Var() { return var_; }
  void ResetDynamicRef() {
    if (static_ref_ != 1) {
      dynamic_ref_ = static_ref_;
    }
  }
  void ResetVariable(Variable* new_var) { var_ = new_var; }
  bool CheckAndDecrease() {
    return static_ref_ == 1 || (dynamic_ref_.fetch_sub(1) == 1);
  }

 private:
  const size_t static_ref_;
  std::atomic<size_t> dynamic_ref_;
  Variable* var_;
};
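
// Illustrative only: a garbage-collection pass over the variables checked by a
// just-finished op might look like this (`gc_refs` and `gc` are hypothetical):
//
//   for (auto& ref : gc_refs) {  // one VarRefInfo per checked variable
//     if (ref->CheckAndDecrease()) {
//       gc->Add(ref->Var());     // last consumer finished: collectable
//     }
//   }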

// static_dep_ is the number of dependencies (ops that must run before it) of
// each op, computed statically. dynamic_dep_ is the runtime counterpart,
// decreased by one after the execution of each dependency op.

// dynamic_dep_ is initialized to static_dep_ and is decreased during the
// InterpreterCore's execution. After the InterpreterCore has run, every
// dynamic_dep_ is reset, i.e., dynamic_dep_ = static_dep_; see
// ResetAtomicGuard for details.

class OpDepInfo {
 public:
  explicit OpDepInfo(size_t dep) : static_dep_(dep), dynamic_dep_(dep) {}
  size_t DynamicDep() { return dynamic_dep_; }
  void ResetDynamicDep() {
    if (static_dep_ != 1) {
      dynamic_dep_ = static_dep_;
    }
  }
  bool CheckAndDecrease() {
    return static_dep_ == 1 || (dynamic_dep_.fetch_sub(1) == 1);
  }

 private:
  const size_t static_dep_;
  std::atomic<size_t> dynamic_dep_;
};
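
// Illustrative only: when an op finishes, each successor's dependency count is
// decreased, and a successor becomes runnable once all its dependencies have
// run (`deps` and `Schedule` are hypothetical):
//
//   for (size_t next_id : instr.NextInstructions().DirectRunIds()) {
//     if (deps[next_id]->CheckAndDecrease()) {
//       Schedule(next_id);  // all dependencies have now finished
//     }
//   }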

class ResetAtomicGuard {
 public:
  ResetAtomicGuard(std::vector<std::shared_ptr<OpDepInfo>>* deps,
                   std::vector<std::shared_ptr<VarRefInfo>>* refs)
      : deps_(deps), refs_(refs) {}

  ~ResetAtomicGuard() {
    VLOG(10) << "Reset DynamicDep";
    for (auto&& dep : *deps_) {
      dep->ResetDynamicDep();
    }
    VLOG(10) << "Reset DynamicRef";
    for (auto&& ref : *refs_) {
      ref->ResetDynamicRef();
    }
  }

 private:
  std::vector<std::shared_ptr<OpDepInfo>>* deps_;
  std::vector<std::shared_ptr<VarRefInfo>>* refs_;
};
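
// A minimal sketch of the intended use (names hypothetical): the guard wraps
// one run of the interpreter so that the atomic counters are restored for the
// next run when it goes out of scope:
//
//   {
//     ResetAtomicGuard guard(&op_deps, &var_refs);
//     RunInstructionList(instructions);
//   }  // dynamic_dep_ / dynamic_ref_ reset to their static values here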

}  // namespace interpreter

}  // namespace framework
}  // namespace paddle