// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/event.h"
#include "paddle/phi/core/utils/rw_lock.h"

// When in inference scenario, the scopes will not be written by two threads
// at the same time, but a scope may be read by multiple threads concurrently,
// and the mutex will cause serious performance issue.
// So the mutex is disabled when `ON_INFER`.
// NOTE: these macros expand to a local `auto_lock` guarding `vars_lock_`,
// so they may only be used inside a member function of a class that
// declares `vars_lock_`.
#ifdef PADDLE_ON_INFERENCE
#define SCOPE_VARS_READER_LOCK
#define SCOPE_VARS_WRITER_LOCK
#else
#define SCOPE_VARS_READER_LOCK AutoRDLock auto_lock(&vars_lock_);
#define SCOPE_VARS_WRITER_LOCK AutoWRLock auto_lock(&vars_lock_);
#endif

39 40 41 42 43 44 45
namespace paddle {
namespace framework {

using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
using OpKernelMap =
    std::unordered_map<OpKernelType, OpKernelComputeFunc, OpKernelType::Hash>;

W
wanghuancoder 已提交
46 47
constexpr int kEmptyVarIndex = 0;

48 49 50
class InterpretercoreInferShapeContext : public InferShapeContext {
 public:
  InterpretercoreInferShapeContext(const OperatorBase& op,
L
Leo Chen 已提交
51
                                   const RuntimeContext& ctx);
52

L
Leo Chen 已提交
53
  bool HasInput(const std::string& name) const override;
54

L
Leo Chen 已提交
55
  bool HasOutput(const std::string& name) const override;
56

57 58
  bool HasAttr(const std::string& name) const override;

L
Leo Chen 已提交
59
  bool HasInputs(const std::string& name) const override;
60

L
Leo Chen 已提交
61
  bool HasOutputs(const std::string& name) const override;
62

L
Leo Chen 已提交
63
  AttrReader Attrs() const override;
64

L
Leo Chen 已提交
65
  std::vector<std::string> Inputs(const std::string& name) const override;
66

L
Leo Chen 已提交
67
  std::vector<std::string> Outputs(const std::string& name) const override;
68

L
Leo Chen 已提交
69 70 71
  std::string GetInputNameByIdx(size_t idx) const override;

  std::string GetOutputNameByIdx(size_t idx) const override;
72 73

  void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
L
Leo Chen 已提交
74
                size_t j = 0) override;
75 76

  void ShareAllLoD(const std::string& in,
L
Leo Chen 已提交
77
                   const std::string& out) const override;
78 79

  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
L
Leo Chen 已提交
80
                size_t j = 0) const override;
81

L
Leo Chen 已提交
82
  int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override;
83 84

  void SetLoDLevel(const std::string& out, int32_t lod_level,
L
Leo Chen 已提交
85
                   size_t j = 0) const override;
86

L
Leo Chen 已提交
87
  bool IsRuntime() const override;
88

89 90
  bool IsRunMKLDNNKernel() const override;

91 92
  // TODO(paddle-dev): Can this be template?
  std::vector<InferShapeVarPtr> GetInputVarPtrs(
93
      const std::string& name) const override;
94 95

  std::vector<InferShapeVarPtr> GetOutputVarPtrs(
96
      const std::string& name) const override;
97

L
Leo Chen 已提交
98
  DDim GetInputDim(const std::string& name) const override;
99

L
Leo Chen 已提交
100
  std::vector<DDim> GetInputsDim(const std::string& name) const override;
101 102

  std::vector<proto::VarType::Type> GetInputsVarType(
L
Leo Chen 已提交
103
      const std::string& name) const override;
104 105

  std::vector<proto::VarType::Type> GetOutputsVarType(
L
Leo Chen 已提交
106
      const std::string& name) const override;
107

L
Leo Chen 已提交
108
  void SetOutputDim(const std::string& name, const DDim& dim) override;
109 110

  void SetOutputsDim(const std::string& name,
L
Leo Chen 已提交
111
                     const std::vector<DDim>& dims) override;
112

L
Leo Chen 已提交
113
  void SetSkipLoD(bool skip);
114 115

 protected:
L
Leo Chen 已提交
116
  DDim GetDim(Variable* var) const;
117

L
Leo Chen 已提交
118
  std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const;
119

L
Leo Chen 已提交
120
  std::vector<DDim> GetRepeatedDims(const std::string& name) const override;
121

L
Leo Chen 已提交
122
  void SetDim(Variable* var, const DDim& dim);
123 124

  void SetDims(const std::vector<Variable*>& vars,
L
Leo Chen 已提交
125
               const std::vector<DDim>& dims);
126 127

  void SetRepeatedDims(const std::string& name,
L
Leo Chen 已提交
128
                       const std::vector<DDim>& dims) override;
129 130

  std::vector<proto::VarType::Type> GetVarTypes(
L
Leo Chen 已提交
131
      const std::vector<Variable*>& vars) const;
132

L
Leo Chen 已提交
133
  proto::VarType::Type GetVarType(Variable* var) const;
134 135

 private:
L
Leo Chen 已提交
136
  const std::vector<Variable*>& InputVars(const std::string& name) const;
137

L
Leo Chen 已提交
138
  const std::vector<Variable*>& OutputVars(const std::string& name) const;
139 140 141 142 143 144

  const OperatorBase& op_;
  const RuntimeContext& ctx_;
  bool can_skip_lod_;
};

145 146 147 148 149
struct OpKernelFunc {
  OpKernelComputeFunc compute_func_;
};

// Per-variable bookkeeping for the interpreter: a reference count
// (presumably used for GC decisions — confirm in interpreter code), the
// VarDesc the variable was created from, and an inplace-skip flag.
struct VariableMetaInfo {
  int var_ref_count_{0};
  framework::VarDesc* var_desc_{nullptr};  // not owned
  // NOTE: "sikp" is a historical typo for "skip"; the name is kept because
  // VariableScope::SetVarSikpInplace/GetVarSikpInplace use the same spelling.
  bool sikp_inplace_{false};

  VariableMetaInfo() {}
  VariableMetaInfo(int var_ref_count, framework::VarDesc* var_desc)
      : var_ref_count_(var_ref_count), var_desc_(var_desc) {}
};

class VariableScope;
class VariableScopeListener : public ScopeListener {
 public:
  explicit VariableScopeListener(VariableScope* var_scope_);
163
  void onCreateVariable(const std::string& name, Variable* v) override;
L
Leo Chen 已提交
164 165 166 167 168 169 170 171 172
  void onDeleteVariable(const std::string& name) override;
  void onRenameVariable(const std::string& old_name,
                        const std::string& new_name) override;
  void onCreateScope(Scope* Scope) override;
  void onDeleteScope(Scope* Scope) override;
  void onClear() override;

 private:
  VariableScope* var_scope_;  // not owned
173 174
};

// TODO(zhiqiu): Maybe we need to add rwlock for VariableScope?

// NOTE(xiongkun03): Use scope as a member of VariableScope, we don't need
// ScopeBase. Scope manages the variables and VariableScope is just a quick
// access mechanism. ScopeListener is the callback to sync changes in the
// original Scope. We can make it a member of VariableScope. Here we use
// inheritance.
class VariableScope : public ScopeBase {
182
 public:
L
Leo Chen 已提交
183
  explicit VariableScope(Scope* scope);
184

185 186 187 188 189
  Scope* GetMutableScope() const;

  Scope* GetMutableLocalScope() const;

  void SetLocalScope(Scope* local_scope);
190

L
Leo Chen 已提交
191 192 193
  Variable* FindVar(const std::string& name) const;

  ~VariableScope();
194 195

  // Get variable id by name, return -1 if not found
L
Leo Chen 已提交
196
  int GetIdByName(const std::string& name) const;
197 198

  // Get variable name by id, return "" if not found
L
Leo Chen 已提交
199
  std::string GetNameById(int id) const;
200

L
Leo Chen 已提交
201
  bool HasVar(const std::string& name) const;
202

L
Leo Chen 已提交
203
  int VarId(const std::string& name) const;
204

L
Leo Chen 已提交
205
  Variable* Var(int id) const;
206

L
Leo Chen 已提交
207
  Variable* Var(const std::string& name) const;
208

L
Leo Chen 已提交
209
  size_t VarSize() const;
210

211 212
  void AddVar(const std::string& name, VarDesc* var_desc,
              bool local_scope = false);
213

L
Leo Chen 已提交
214
  void AddVar(const std::string& name, const Variable& var);
215

L
Leo Chen 已提交
216
  void SetVarDesc(const std::string& name, framework::VarDesc* var_desc);
217

L
Leo Chen 已提交
218
  paddle::framework::VarDesc* VarDesc(const std::string& name) const;
219

L
Leo Chen 已提交
220
  paddle::framework::VarDesc* VarDesc(int id) const;
221

L
Leo Chen 已提交
222 223 224
  void CheckExist(int id) const;

  void CheckExist(const std::string& name) const;
225

226 227 228 229 230 231
  std::vector<VariableMetaInfo>& MutableVecMetaInfo() { return vec_meta_info_; }

  const std::vector<VariableMetaInfo>& VecMetaInfo() const {
    return vec_meta_info_;
  }

232 233 234 235
  const std::shared_ptr<VariableScopeListener>& Listener() const {
    return listener_;
  }

236 237 238 239
  void SetVarSikpInplace(const std::string& name, bool skip);

  bool GetVarSikpInplace(int id) const;

L
Leo Chen 已提交
240 241
  friend class VariableScopeListener;

242
 private:
243 244
  std::vector<Variable*> var_list_;
  std::map<std::string, int> name2id_;
W
wanghuancoder 已提交
245
  std::vector<VariableMetaInfo> vec_meta_info_;
246 247 248
  Scope* scope_{nullptr};
  // TODO(zhiqiu): find a better way to support local scope.
  Scope* local_scope_{nullptr};
L
Leo Chen 已提交
249
  // mutable RWLock vars_lock_;
250
  std::shared_ptr<VariableScopeListener> listener_{nullptr};
251 252
};

// Records which instruction ids may run after the current one, bucketed by
// how the successor must be scheduled (directly, after an event wait, or
// after a full synchronization).
class NextInstruction {
 public:
  void AddDirectRun(size_t id) { direct_run_.push_back(id); }

  // NOTE: the "ADD" casing is a historical inconsistency; the name is kept
  // for compatibility with existing callers.
  void ADDEventRun(size_t id) { event_wait_run_.push_back(id); }

  void AddSyncRun(size_t id) { synchronize_run_.push_back(id); }

  const std::vector<size_t>& DirectRunIds() const { return direct_run_; }

  const std::vector<size_t>& EventRunIds() const { return event_wait_run_; }

  const std::vector<size_t>& SyncRunIds() const { return synchronize_run_; }

 private:
  std::vector<size_t> direct_run_;
  std::vector<size_t> event_wait_run_;
  std::vector<size_t> synchronize_run_;
};

273
struct EventInter {
274 275
  explicit EventInter(size_t var_id,
                      std::shared_ptr<platform::DeviceEvent> event,
276 277
                      platform::DeviceType waiter_type)
      : var_id_(var_id), event_(event), waiter_type_(waiter_type) {}
278
  size_t var_id_;
279
  std::shared_ptr<platform::DeviceEvent> event_;
280
  platform::DeviceType waiter_type_;
281
};
// Per-instruction dependency counters.
// NOTE: "dependecy" is a historical typo for "dependency"; the member name
// is kept for compatibility with existing callers.
struct InstructionInfo {
  std::vector<size_t> dependecy_count_;
};

// How the interpreter schedules an op's kernel.
enum class OpFuncType {
  kQueueSync = 0,   // CPU kernel, block host
  kQueueAsync = 1,  // GPU Kernel or d2h, h2d, send, recv, broadcast
};
class RuntimeInferShapeContext;

293
struct OpFuncNode {
L
Leo Chen 已提交
294 295
  // TODO(zhiqiu): Better make it unique_ptr
  std::shared_ptr<OperatorBase> operator_base_;
296 297 298 299 300 301
  std::map<std::string, std::vector<int>> input_index;
  std::map<std::string, std::vector<int>> output_index;
  std::unordered_set<int> no_data_transform_index;

  OpKernelComputeFunc kernel_func_;
  platform::DeviceContext* dev_ctx_;  // not owned
302

303
  // fit for phi kernel
304
  phi::Kernel* pt_kernel_{nullptr};  // not owned
305

306 307 308 309 310
  OpFuncType type_;
};

class Instruction {
 public:
L
Leo Chen 已提交
311 312
  Instruction(size_t id, OpFuncNode&& op_func_node,
              const platform::DeviceContext& dev_ctx);
313

L
Leo Chen 已提交
314
  size_t Id() const;
315

L
Leo Chen 已提交
316
  const std::map<std::string, std::vector<int>>& Inputs() const;
317

L
Leo Chen 已提交
318
  const std::map<std::string, std::vector<int>>& Outputs() const;
319

L
Leo Chen 已提交
320
  const std::unordered_set<int>& NoDataTransformVars() const;
321

L
Leo Chen 已提交
322
  OpKernelComputeFunc KernelFunc() const;
323

324
  phi::Kernel* PhiKernel() const;
325

L
Leo Chen 已提交
326
  OpFuncType KernelType() const;
327

L
Leo Chen 已提交
328
  OperatorBase* OpBase() const;
329

L
Leo Chen 已提交
330
  NextInstruction& NextInstructions();
331

L
Leo Chen 已提交
332
  const NextInstruction& NextInstructions() const;
333

L
Leo Chen 已提交
334
  void AddGCCheckVar(size_t id);
335

L
Leo Chen 已提交
336
  const std::vector<size_t>& GCCheckVars() const;
337 338

  void ResetContext(const VariableValueMap& in_vars,
L
Leo Chen 已提交
339
                    const VariableValueMap& out_vars);
340

L
Leo Chen 已提交
341
  std::shared_ptr<RuntimeContext> InnerRuntimeContext() const;
342 343

  std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext()
L
Leo Chen 已提交
344
      const;
345

L
Leo Chen 已提交
346
  std::shared_ptr<ExecutionContext> InnerExecutionContext() const;
347

L
Leo Chen 已提交
348
  const platform::DeviceContext& DeviceContext() const;
349

L
Leo Chen 已提交
350
  const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const;
351

L
Leo Chen 已提交
352
  void AddInplace(Variable* in, Variable* out);
353

L
Leo Chen 已提交
354
  const std::vector<EventInter>& InputEvents() const;
355

L
Leo Chen 已提交
356
  const std::vector<EventInter>& OutputEvents() const;
357 358 359

  void AddInputEvent(size_t var_id,
                     std::shared_ptr<platform::DeviceEvent> event,
L
Leo Chen 已提交
360
                     platform::DeviceType waiter_type);
361 362 363

  void AddOutputEvent(size_t var_id,
                      std::shared_ptr<platform::DeviceEvent> event,
L
Leo Chen 已提交
364
                      platform::DeviceType waiter_type);
365 366 367

 private:
  size_t id_;
L
Leo Chen 已提交
368
  OpFuncNode op_func_node_;
369 370
  const platform::DeviceContext& dev_ctx_;  // not owned

371
  std::shared_ptr<RuntimeContext> runtime_ctx_;
372
  std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_;
373
  std::shared_ptr<ExecutionContext> execution_ctx_;
374

375
  std::vector<size_t> gc_check_var_list_;
376
  NextInstruction next_instruction_;
377 378 379 380

  std::vector<EventInter> intput_events_;
  std::vector<EventInter> output_events_;

381
  std::vector<std::pair<Variable*, Variable*>> vec_inplace_in_to_out_;
382 383
};

384
namespace interpreter {
385 386
static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h";
387
static constexpr char kFetchVarName[] = "fetch";
388 389

static bool IsMemcpyH2D(const Instruction& instr) {
390
  return instr.OpBase()->Type() == kMemcpyH2D;
391 392 393
}

static bool IsMemcpyD2H(const Instruction& instr) {
394
  return instr.OpBase()->Type() == kMemcpyD2H;
395
}
L
Leo Chen 已提交
396 397 398 399 400

static bool IsCpuOp(const Instruction& instr) {
  return platform::is_cpu_place(instr.DeviceContext().GetPlace());
}

401
}  // namespace interpreter
402

403 404
}  // namespace framework
}  // namespace paddle