Unverified · Commit dc96ebc0 · Authored by: Z zhangbo9674 · Committed by: GitHub

[IR] Support GC and TraceRun for NewIr InterpreterCore (#55772)

* add interface

* add code

* add code

* add code

* add code

* fix bug

* fix bug

* add var prefix

* add code

* add code

* add code

* fix compile bug

* fix bug

* refine code

* refine code

* refine code

* refine code

* fix bug

* add code

* add code

* fix bug

* add code

* add code

* refine code

* refine code

* fix bug
Parent 6f53d3b2
......@@ -37,6 +37,19 @@ InterpreterCoreEventGarbageCollector::InterpreterCoreEventGarbageCollector(
}
}
// Overload for the new IR (InstructionBase) execution path: sets up a
// single-threaded work queue for asynchronous garbage collection and one
// device event per instruction, created on that instruction's place.
InterpreterCoreEventGarbageCollector::InterpreterCoreEventGarbageCollector(
const std::vector<std::unique_ptr<InstructionBase>>& vec_instruction) {
// One spinning worker is enough: GC tasks are small (event checks + frees),
// and per-task tracking is not needed.
WorkQueueOptions options(/*name*/ "GarbageCollector",
/*num_threads*/ 1,
/*allow_spinning*/ true,
/*track_task*/ false);
queue_ = CreateSingleThreadedWorkQueue(options);
for (auto& instruc : vec_instruction) {
// gc_event_[i] pairs with vec_instruction[i]; Add() later indexes this
// vector by the instruction's Id().
gc_event_.emplace_back(instruc->DeviceContext().GetPlace(),
platform::GenerateDeviceEventFlag());
}
}
InterpreterCoreEventGarbageCollector::~InterpreterCoreEventGarbageCollector() {
  // Tear the worker queue down explicitly so the GC thread is stopped
  // before the remaining members of this collector are destroyed.
  queue_.reset();
}
......@@ -53,6 +66,18 @@ void InterpreterCoreEventGarbageCollector::Add(Variable* var,
Add(var, &gc_event_.at(instr.Id()), &instr.DeviceContext());
}
// Deferred-release entry point for the new IR path: validates the
// instruction id against the per-instruction event table built in the
// constructor, then forwards to the event-based Add() overload with the
// matching event slot and the instruction's device context.
void InterpreterCoreEventGarbageCollector::Add(Variable* var,
const InstructionBase* instr) {
// gc_event_ was sized from the instruction list, so Id() must be in range.
PADDLE_ENFORCE_LT(instr->Id(),
gc_event_.size(),
platform::errors::OutOfRange(
"The index should be less than the size of gc event "
", but got index is %d and size is %d",
instr->Id(),
gc_event_.size()));
Add(var, &gc_event_.at(instr->Id()), &instr->DeviceContext());
}
void InterpreterCoreEventGarbageCollector::Add(
Variable* var,
platform::DeviceEvent* event,
......
......@@ -26,9 +26,16 @@ class InterpreterCoreEventGarbageCollector
public:
InterpreterCoreEventGarbageCollector(
const std::vector<Instruction>& vec_instruction);
InterpreterCoreEventGarbageCollector(
const std::vector<std::unique_ptr<InstructionBase>>& vec_instruction);
~InterpreterCoreEventGarbageCollector();
void Add(Variable* var, const Instruction& instruction) override;
void Add(Variable* var, const InstructionBase* instruction) override;
private:
void Add(Variable* var,
platform::DeviceEvent* event,
......
......@@ -22,6 +22,11 @@ void InterpreterCoreFastGarbageCollector::Add(Variable* var,
Add(var);
}
// New IR path: fast GC releases the variable immediately and does not use
// any per-instruction information, so the InstructionBase argument is
// intentionally unnamed and ignored.
void InterpreterCoreFastGarbageCollector::Add(Variable* var,
const InstructionBase*) {
Add(var);
}
void InterpreterCoreFastGarbageCollector::Add(Variable* var) {
if (UNLIKELY(max_memory_size_ < 0) || var == nullptr) {
return;
......
......@@ -23,6 +23,8 @@ class InterpreterCoreFastGarbageCollector
public:
void Add(Variable* var, const Instruction& instr) override;
void Add(Variable* var, const InstructionBase* instr) override;
private:
void Add(Variable* var);
void Add(Garbage garbage);
......
......@@ -49,6 +49,36 @@ InterpreterCoreGarbageCollector::InterpreterCoreGarbageCollector() {
cur_memory_size_ = 0;
}
// Factory for the new IR (InstructionBase) execution path. Selects a GC
// strategy by place:
//  - GPU: fast GC when enabled, otherwise event-based GC.
//  - XPU: fast GC (single stream; see note below).
//  - IPU: no-event GC.
//  - other places (e.g. CPU): event-based GC.
// `vec_instruction` is only consumed by the event-based collector, which
// creates one device event per instruction.
std::unique_ptr<InterpreterCoreGarbageCollector>
CreateInterpreterCoreGarbageCollector(
    const platform::Place& place,
    const std::vector<std::unique_ptr<InstructionBase>>& vec_instruction) {
  if (platform::is_gpu_place(place)) {
    if (IsInterpretercoreFastGCEnabled()) {
      // std::make_unique<Derived>() converts implicitly to
      // std::unique_ptr<Base>; avoids the raw `new` of the manual form.
      return std::make_unique<InterpreterCoreFastGarbageCollector>();
    } else {
      return std::make_unique<InterpreterCoreEventGarbageCollector>(
          vec_instruction);
    }
  } else if (platform::is_xpu_place(place)) {
    // Because there is no multi-stream on XPU device, fast GC can
    // be used.
    // Previously, XPU used no_event GC. But `Wait` in no_event GC
    // may cause GC delayed, causing no enough memory problem.
    // TODO(pangyoki): Multi-stream allocator and multi-stream GC
    // are needed to be adapted for XPU.
    return std::make_unique<InterpreterCoreFastGarbageCollector>();
  } else if (platform::is_ipu_place(place)) {
    return std::make_unique<InterpreterCoreNoEventGarbageCollector>();
  } else {
    return std::make_unique<InterpreterCoreEventGarbageCollector>(
        vec_instruction);
  }
}
std::unique_ptr<InterpreterCoreGarbageCollector>
CreateInterpreterCoreGarbageCollector(
const platform::Place& place,
......
......@@ -15,6 +15,7 @@
#include <queue>
#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/memory/allocation/spin_lock.h"
#include "paddle/fluid/platform/device_event.h"
......@@ -34,6 +35,8 @@ class InterpreterCoreGarbageCollector {
virtual void Add(Variable* var, const Instruction& instruction) = 0;
virtual void Add(Variable* var, const InstructionBase* instruction) = 0;
DISABLE_COPY_AND_ASSIGN(InterpreterCoreGarbageCollector);
protected:
......@@ -50,5 +53,10 @@ CreateInterpreterCoreGarbageCollector(
const platform::Place& place,
const std::vector<Instruction>& vec_instruction);
std::unique_ptr<InterpreterCoreGarbageCollector>
CreateInterpreterCoreGarbageCollector(
const platform::Place& place,
const std::vector<std::unique_ptr<InstructionBase>>& vec_instruction);
} // namespace framework
} // namespace paddle
......@@ -36,6 +36,11 @@ void InterpreterCoreNoEventGarbageCollector::Add(Variable* var,
Add(var, &instr.DeviceContext());
}
// New IR path: forwards to the device-context-based Add() overload using
// the instruction's device context; this collector records no device events.
void InterpreterCoreNoEventGarbageCollector::Add(Variable* var,
const InstructionBase* instr) {
Add(var, &instr->DeviceContext());
}
void InterpreterCoreNoEventGarbageCollector::Add(
Variable* var, const platform::DeviceContext* ctx) {
if (UNLIKELY(max_memory_size_ < 0) || var == nullptr) {
......
......@@ -28,6 +28,8 @@ class InterpreterCoreNoEventGarbageCollector
~InterpreterCoreNoEventGarbageCollector();
void Add(Variable* var, const Instruction& instr) override;
void Add(Variable* var, const InstructionBase* instr) override;
private:
void Add(Variable* var, const platform::DeviceContext* ctx);
void Add(Garbage garbage, const platform::DeviceContext* ctx);
......
......@@ -246,6 +246,7 @@ PhiKernelInstruction::PhiKernelInstruction(
kernel_context_.SetDeviceContext(phi::DeviceContextPool::Instance().Get(
phi::TransToPhiPlace(kernel_key.backend())));
VLOG(6) << "finish process kernel context";
SetDeviceContext(
ParseDeviceContext(op,
phi::DeviceContextPool::Instance().Get(
......
......@@ -16,6 +16,7 @@
#include <memory>
#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
#include "paddle/fluid/framework/new_executor/interpreter_base_impl.h"
#include "paddle/ir/core/value.h"
namespace ir {
class Program;
......@@ -86,6 +87,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {
std::string GetNameById(int id) const;
int GetIdByName(const std::string& name) const;
private:
// build graph
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
......@@ -93,7 +96,11 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void BuildAndCacheInstructionCtx(Instruction* instr_node);
void BuildSkipShareLoDInfo();
void UpdateSyncOpNum();
void AnalyseExecuteOrderForTrace();
void AnalyseExecuteOrderForTrace(
std::map<size_t, std::set<size_t>> op_downstream_map,
InstructionSchedulingPriorityLess compare);
void ConstructEventForJitInput();
void CalculateLastLiveOps();
// inplace
void BuildInplace();
......@@ -201,10 +208,31 @@ class NewIRInterpreter : public InterpreterBaseImpl {
/// ======================== ///
std::string DebugValueInfo();
void PreAnalysis();
void BuildInstruction();
void BuildInstructionDependences();
void NewIrLoopRunImpl();
void BetaRunImpl();
void TraceInstructionList(
const std::vector<std::unique_ptr<InstructionBase>>& vec_instr);
void RunInstructionBase(InstructionBase* instr_node);
void RecordMemcpyD2H(InstructionBase* instr_node);
::ir::Value GetValueByName(const std::string& var_name);
void CheckGC(InstructionBase* instr);
void RecordStreamForGC(InstructionBase* instr);
InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less;
std::unique_ptr<::ir::Program> ir_program_{nullptr};
std::vector<std::unique_ptr<InstructionBase>> vec_instruction_base_;
......@@ -218,6 +246,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {
std::vector<Variable*> variable_list_;
std::vector<int> var_ref_count_;
interpreter::NewIrDependencyBuilder ir_dependency_builder_;
interpreter::NewIrStreamAnalyzer ir_stream_analyzer_;
......
......@@ -70,23 +70,19 @@ TEST(StandaloneExecutor, run) {
ProgramDesc prog_desc;
InterpreterCore test_core(place, std::move(kernel_program), &scope);
VLOG(0) << "&test_core" << &test_core;
VLOG(0) << "&test_core.impl" << test_core.Impl();
VLOG(0) << "&test_core.impl.cast"
<< reinterpret_cast<NewIRInterpreter*>(
const_cast<InterpreterBaseImpl*>(test_core.Impl()));
test_core.BetaRun({});
std::stringstream os;
os << reinterpret_cast<NewIRInterpreter*>(
const_cast<InterpreterBaseImpl*>(test_core.Impl()));
std::string prefix_str = os.str();
std::string out_name = os.str() + "_inner_var_2";
test_core.SetSkipGcVars({out_name});
test_core.BetaRun({});
auto out_tensor =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_2")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_2")
->Get<phi::DenseTensor>();
? scope.FindVar(out_name)->Get<phi::DenseTensor>()
: test_core.local_scope()->FindVar(out_name)->Get<phi::DenseTensor>();
bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
......@@ -115,18 +111,19 @@ TEST(StandaloneExecutor, run_inplace_sqrt) {
auto place = platform::CPUPlace();
Scope scope;
InterpreterCore test_core(place, std::move(kernel_program), &scope);
test_core.BetaRun({});
std::stringstream os;
os << reinterpret_cast<NewIRInterpreter*>(
const_cast<InterpreterBaseImpl*>(test_core.Impl()));
std::string prefix_str = os.str();
std::string out_name = os.str() + "_inner_var_0";
test_core.SetSkipGcVars({out_name});
test_core.BetaRun({});
auto out_tensor =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_0")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_0")
->Get<phi::DenseTensor>();
? scope.FindVar(out_name)->Get<phi::DenseTensor>()
: test_core.local_scope()->FindVar(out_name)->Get<phi::DenseTensor>();
bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register