new_executor_defs.h
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/event.h"

// In inference scenarios, the scopes are never written by two threads at the
// same time, but a scope may be read by multiple threads concurrently, and the
// mutex would cause a serious performance issue.
// So the mutex is disabled when `ON_INFER` is defined.
#ifdef PADDLE_ON_INFERENCE
#define SCOPE_VARS_READER_LOCK
#define SCOPE_VARS_WRITER_LOCK
#else
#define SCOPE_VARS_READER_LOCK AutoRDLock auto_lock(&vars_lock_);
#define SCOPE_VARS_WRITER_LOCK AutoWRLock auto_lock(&vars_lock_);
#endif
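// These macros are presumably expanded inside member functions of a class that
// owns a `vars_lock_` RWLock; when PADDLE_ON_INFERENCE is defined they expand
// to nothing, so variable reads and writes proceed without locking.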

namespace paddle {
namespace framework {

using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
using OpKernelMap =
    std::unordered_map<OpKernelType, OpKernelComputeFunc, OpKernelType::Hash>;

constexpr int kEmptyVarIndex = 0;

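// InferShapeContext implementation used by the new executor; it answers shape
// queries for a single operator from a prebuilt RuntimeContext.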
class InterpretercoreInferShapeContext : public InferShapeContext {
 public:
  InterpretercoreInferShapeContext(const OperatorBase& op,
                                   const RuntimeContext& ctx);

  bool HasInput(const std::string& name) const override;

  bool HasOutput(const std::string& name) const override;

  bool HasInputs(const std::string& name) const override;

  bool HasOutputs(const std::string& name) const override;

  AttrReader Attrs() const override;

  std::vector<std::string> Inputs(const std::string& name) const override;

  std::vector<std::string> Outputs(const std::string& name) const override;

  std::string GetInputNameByIdx(size_t idx) const override;

  std::string GetOutputNameByIdx(size_t idx) const override;

  void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) override;

  void ShareAllLoD(const std::string& in,
                   const std::string& out) const override;

  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const override;

  int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override;

  void SetLoDLevel(const std::string& out, int32_t lod_level,
                   size_t j = 0) const override;

  bool IsRuntime() const override;

  // TODO(paddle-dev): Can this be template?
  std::vector<InferShapeVarPtr> GetInputVarPtrs(
      const std::string& name) override;

  std::vector<InferShapeVarPtr> GetOutputVarPtrs(
      const std::string& name) override;

  DDim GetInputDim(const std::string& name) const override;

  std::vector<DDim> GetInputsDim(const std::string& name) const override;

  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string& name) const override;

  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string& name) const override;

  void SetOutputDim(const std::string& name, const DDim& dim) override;

  void SetOutputsDim(const std::string& name,
                     const std::vector<DDim>& dims) override;

  void SetSkipLoD(bool skip);

 protected:
  DDim GetDim(Variable* var) const;

  std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const;

  std::vector<DDim> GetRepeatedDims(const std::string& name) const override;

  void SetDim(Variable* var, const DDim& dim);

  void SetDims(const std::vector<Variable*>& vars,
               const std::vector<DDim>& dims);

  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override;

  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<Variable*>& vars) const;

  proto::VarType::Type GetVarType(Variable* var) const;

 private:
  const std::vector<Variable*>& InputVars(const std::string& name) const;

  const std::vector<Variable*>& OutputVars(const std::string& name) const;

  const OperatorBase& op_;
  const RuntimeContext& ctx_;
  bool can_skip_lod_;
};

struct OpKernelFunc {
  OpKernelComputeFunc compute_func_;
};

struct VariableMetaInfo {
  int var_ref_count_{0};
  framework::VarDesc* var_desc_{nullptr};
  bool sikp_inplace_{false};

  VariableMetaInfo() {}
  VariableMetaInfo(int var_ref_count, framework::VarDesc* var_desc)
      : var_ref_count_(var_ref_count), var_desc_(var_desc) {}
};

class VariableScope;
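// Listener that keeps a VariableScope in sync with its underlying Scope: it is
// notified when variables are created, deleted or renamed and when scopes are
// created, deleted or cleared.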
class VariableScopeListener : public ScopeListener {
 public:
  explicit VariableScopeListener(VariableScope* var_scope_);
  void onCreateVariable(const std::string& name, Variable* v) override;
  void onDeleteVariable(const std::string& name) override;
  void onRenameVariable(const std::string& old_name,
                        const std::string& new_name) override;
  void onCreateScope(Scope* Scope) override;
  void onDeleteScope(Scope* Scope) override;
  void onClear() override;

 private:
  VariableScope* var_scope_;  // not owned
};

// TODO(zhiqiu): Maybe we need to add a rwlock for VariableScope?

// NOTE(xiongkun03): With Scope as a member of VariableScope, we don't need
// ScopeBase. Scope manages the variables, and VariableScope is just a quick
// access mechanism. ScopeListener is the callback used to sync changes from the
// original Scope; it could be made a member of VariableScope, but here we use
// inheritance.
class VariableScope : public ScopeBase {
 public:
  explicit VariableScope(Scope* scope);

  Scope* GetMutableScope() const;

  Scope* GetMutableLocalScope() const;

  void SetLocalScope(Scope* local_scope);

  Variable* FindVar(const std::string& name) const;

  ~VariableScope();

  // Get variable id by name, return -1 if not found
  int GetIdByName(const std::string& name) const;

  // Get variable name by id, return "" if not found
  std::string GetNameById(int id) const;

  bool HasVar(const std::string& name) const;

  int VarId(const std::string& name) const;

  Variable* Var(int id) const;

  Variable* Var(const std::string& name) const;

  size_t VarSize() const;

  void AddVar(const std::string& name, VarDesc* var_desc,
              bool local_scope = false);

  void AddVar(const std::string& name, const Variable& var);

  void SetVarDesc(const std::string& name, framework::VarDesc* var_desc);

  paddle::framework::VarDesc* VarDesc(const std::string& name) const;

  paddle::framework::VarDesc* VarDesc(int id) const;

  void CheckExist(int id) const;

  void CheckExist(const std::string& name) const;

  std::vector<VariableMetaInfo>& MutableVecMetaInfo() { return vec_meta_info_; }

  const std::vector<VariableMetaInfo>& VecMetaInfo() const {
    return vec_meta_info_;
  }

  const std::shared_ptr<VariableScopeListener>& Listener() const {
    return listener_;
  }

  void SetVarSikpInplace(const std::string& name, bool skip);

  bool GetVarSikpInplace(int id) const;

  friend class VariableScopeListener;

 private:
  std::vector<Variable*> var_list_;
  std::map<std::string, int> name2id_;
  std::vector<VariableMetaInfo> vec_meta_info_;
  Scope* scope_{nullptr};
  // TODO(zhiqiu): find a better way to support local scope.
  Scope* local_scope_{nullptr};
  // mutable RWLock vars_lock_;
  std::shared_ptr<VariableScopeListener> listener_{nullptr};
};
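
// A minimal usage sketch (illustrative only; `outer_scope` and `var_desc` are
// assumed to be provided by the caller):
//
//   VariableScope var_scope(&outer_scope);
//   var_scope.AddVar("x", var_desc);      // register "x" with its VarDesc
//   int id = var_scope.VarId("x");        // look up the dense id
//   Variable* x = var_scope.Var(id);      // fetch the underlying Variable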

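// Records the ids of the instructions that may run after the current one,
// split by how they are launched: directly, after waiting on a device event,
// or synchronously.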
class NextInstruction {
 public:
  void AddDirectRun(size_t id) { direct_run_.push_back(id); }

  void ADDEventRun(size_t id) { event_wait_run_.push_back(id); }

  void AddSyncRun(size_t id) { synchronize_run_.push_back(id); }

  const std::vector<size_t>& DirectRunIds() const { return direct_run_; }

  const std::vector<size_t>& EventRunIds() const { return event_wait_run_; }

  const std::vector<size_t>& SyncRunIds() const { return synchronize_run_; }

 private:
  std::vector<size_t> direct_run_;
  std::vector<size_t> event_wait_run_;
  std::vector<size_t> synchronize_run_;
};

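// Associates a device event with a variable id and the device type of the
// waiting side.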
struct EventInter {
  explicit EventInter(size_t var_id,
                      std::shared_ptr<platform::DeviceEvent> event,
                      platform::DeviceType waiter_type)
      : var_id_(var_id), event_(event), waiter_type_(waiter_type) {}
  size_t var_id_;
  std::shared_ptr<platform::DeviceEvent> event_;
  platform::DeviceType waiter_type_;
};

struct InstructionInfo {
  std::vector<size_t> dependecy_count_;
};

enum class OpFuncType {
  kQueueSync = 0,   // CPU kernel, block host
  kQueueAsync = 1,  // GPU Kernel or d2h, h2d, send, recv, broadcast
};
class RuntimeInferShapeContext;

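// Everything needed to run a single operator: the operator itself, the indices
// of its input/output variables in the VariableScope, the selected kernel
// function, its device context, and whether it is queued as a sync or async op.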
struct OpFuncNode {
  // TODO(zhiqiu): Better make it unique_ptr
  std::shared_ptr<OperatorBase> operator_base_;
  std::map<std::string, std::vector<int>> input_index;
  std::map<std::string, std::vector<int>> output_index;
  std::unordered_set<int> no_data_transform_index;

  OpKernelComputeFunc kernel_func_;
  platform::DeviceContext* dev_ctx_;  // not owned
  OpFuncType type_;
};
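
// An Instruction is the scheduling unit built from an OpFuncNode: it carries
// the cached RuntimeContext / InferShapeContext / ExecutionContext, the events
// to wait on and record, the variables to garbage-collect, and the set of
// instructions that may run next.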

class Instruction {
 public:
  Instruction(size_t id, OpFuncNode&& op_func_node,
              const platform::DeviceContext& dev_ctx);

  size_t Id() const;

  const std::map<std::string, std::vector<int>>& Inputs() const;

  const std::map<std::string, std::vector<int>>& Outputs() const;

  const std::unordered_set<int>& NoDataTransformVars() const;

  OpKernelComputeFunc KernelFunc() const;

  OpFuncType KernelType() const;

  OperatorBase* OpBase() const;

  NextInstruction& NextInstructions();

  const NextInstruction& NextInstructions() const;

  void AddGCCheckVar(size_t id);

  const std::vector<size_t>& GCCheckVars() const;

  void ResetContext(const VariableValueMap& in_vars,
                    const VariableValueMap& out_vars);

  std::shared_ptr<RuntimeContext> InnerRuntimeContext() const;

  std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext()
      const;

  std::shared_ptr<ExecutionContext> InnerExecutionContext() const;

  const platform::DeviceContext& DeviceContext() const;

  const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const;

  void AddInplace(Variable* in, Variable* out);

  const std::vector<EventInter>& InputEvents() const;

  const std::vector<EventInter>& OutputEvents() const;

  void AddInputEvent(size_t var_id,
                     std::shared_ptr<platform::DeviceEvent> event,
                     platform::DeviceType waiter_type);

  void AddOutputEvent(size_t var_id,
                      std::shared_ptr<platform::DeviceEvent> event,
                      platform::DeviceType waiter_type);

 private:
  size_t id_;
  OpFuncNode op_func_node_;
  const platform::DeviceContext& dev_ctx_;  // not owned

  std::shared_ptr<RuntimeContext> runtime_ctx_;
  std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_;
  std::shared_ptr<ExecutionContext> execution_ctx_;

  std::vector<size_t> gc_check_var_list_;
  NextInstruction next_instruction_;

  std::vector<EventInter> intput_events_;
  std::vector<EventInter> output_events_;

  std::vector<std::pair<Variable*, Variable*>> vec_inplace_in_to_out_;
};

namespace interpreter {
static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h";
static constexpr char kFetchVarName[] = "fetch";

static bool IsMemcpyH2D(const Instruction& instr) {
  return instr.OpBase()->Type() == kMemcpyH2D;
}

static bool IsMemcpyD2H(const Instruction& instr) {
  return instr.OpBase()->Type() == kMemcpyD2H;
}

static bool IsCpuOp(const Instruction& instr) {
  return platform::is_cpu_place(instr.DeviceContext().GetPlace());
}

}  // namespace interpreter

}  // namespace framework
}  // namespace paddle