Unverified · Commit 63b7fc80, authored by zhangbo9674, committed by GitHub

[IR] NewIr Interpreter Beta run regular (#55828)

* add interface

* add code

* add code

* add code

* add code

* fix bug

* fix bug

* add var prefix

* add code

* add code

* add code

* fix compile bug

* fix bug

* refine code

* refine code

* refine code

* refine code

* fix bug

* add code

* add code

* fix bug

* add code

* add code

* refine code

* refine code

* fix bug

* add code

* fix bug in phi_kernel_utils

* refine code

* fix bug

* open flag

* refine code

* fix bug

* fix bug

* refine code

* fix bug
Parent commit: e61d892a
...@@ -377,7 +377,7 @@ inline void RunProgramAPI(
   if (FLAGS_enable_new_ir_in_executor) {
     // build new ir program
     auto ir_program = paddle::framework::ConstructFowardIrProgram(
-        forward_global_block, backward_global_block, output_names, x);
+        forward_global_block, backward_global_block, output_names, x, params);

     interpreter_core =
         paddle::framework::CreateNewIRInterpreterCoreInfoToCache(
             std::move(ir_program),
...@@ -529,8 +529,12 @@ inline void RunProgramGradAPI(
   details::ShareTensorsIntoScope(out_grad, global_inner_scope);

   if (FLAGS_enable_new_ir_in_executor) {
-    auto res = paddle::framework::ConstructBackwardIrProgram(
-        backward_global_block, out_grad, x_grad, params_grad);
+    auto res =
+        paddle::framework::ConstructBackwardIrProgram(backward_global_block,
+                                                      out_grad,
+                                                      x_grad,
+                                                      params_grad,
+                                                      global_inner_scope);

     interpreter_core =
         paddle::framework::CreateNewIRInterpreterCoreInfoToCache(
......
...@@ -338,7 +338,7 @@ std::shared_ptr<InterpreterCore> CreateNewIRInterpreterCoreInfoToCache(
   std::shared_ptr<InterpreterCore> core = nullptr;

   core.reset(new InterpreterCore(
-      place, std::move(ir_program), scope, execution_config));
+      place, {}, std::move(ir_program), scope, execution_config));

   auto &cached_value =
       interpretercore_info_cache.GetMutable(program_id, is_grad);
...@@ -350,7 +350,8 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
     const paddle::framework::BlockDesc *forward_global_block,
     const paddle::framework::BlockDesc *backward_global_block,
     const std::vector<std::string> output_names,
-    const std::vector<paddle::Tensor> &x) {
+    const std::vector<paddle::Tensor> &x,
+    const std::vector<paddle::Tensor> &params) {
   auto ir_ctx = ::ir::IrContext::Instance();
   auto program = std::make_unique<::ir::Program>(ir_ctx);
...@@ -386,6 +387,20 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
     op_desc->SetOutput("out", {name});
   }

+  for (auto &param : params) {
+    auto name = param.name();
+    auto place = param.place().GetType();
+
+    auto op_desc = local_program.MutableBlock(0)->PrependOp();
+    op_desc->SetType("feed_with_place");
+    op_desc->SetAttr("index", 0);
+    // TODO(phlrain) : using tensor dtype
+    op_desc->SetAttr("dtype", 0);
+    op_desc->SetAttr("place", static_cast<int>(place));
+    op_desc->SetAttr("name", name);
+    op_desc->SetOutput("out", {name});
+  }
+
   std::set<std::string> set_parameter_names;
   for (auto op_desc : backward_global_block->Program()->Block(0).AllOps()) {
     for (const auto &n : op_desc->Inputs()) {
...@@ -426,31 +441,45 @@ std::unique_ptr<::ir::Program> ConstructBackwardIrProgram(
     const paddle::framework::BlockDesc *backward_global_block,
     const std::vector<paddle::Tensor> &out_grad,
     const std::vector<paddle::Tensor *> &x_grad,
-    const std::vector<paddle::Tensor *> &params_grad) {
+    const std::vector<paddle::Tensor *> &params_grad,
+    const paddle::framework::Scope *scope) {
   auto ir_ctx = ::ir::IrContext::Instance();
   auto program = std::make_unique<::ir::Program>(ir_ctx);

   auto local_program =
       paddle::framework::ProgramDesc(*(backward_global_block->Program()));
-  // add feed kernel
-  auto *block = local_program.MutableBlock(0);
-  for (auto &out_grad_t : out_grad) {
-    auto name = out_grad_t.name();
-    if (block->FindVarRecursive(name) == nullptr) {
-      continue;
-    }
-    auto place = out_grad_t.place().GetType();
-    if (name == "@EMPTY@") {
-      continue;
-    }
-    auto op_desc = block->PrependOp();
-    op_desc->SetType("feed_with_place");
-    op_desc->SetAttr("index", 0);
-    // TODO(phlrain) : using tensor dtype
-    op_desc->SetAttr("dtype", 0);
-    op_desc->SetAttr("place", static_cast<int>(place));
-    op_desc->SetAttr("name", name);
-    op_desc->SetOutput("out", {name});
+
+  // get feed with data
+  std::set<std::string> set_parameter_names;
+  for (auto op_desc : backward_global_block->Program()->Block(0).AllOps()) {
+    for (const auto &n : op_desc->Inputs()) {
+      const auto &input_var_names = n.second;
+      for (const auto &var_name : input_var_names) {
+        set_parameter_names.insert(var_name);
+      }
+    }
+  }
+
+  for (auto &var_name : set_parameter_names) {
+    if (scope->FindVar(var_name)) {
+      auto tensor = scope->FindVar(var_name)->Get<phi::DenseTensor>();
+      phi::AllocationType place(phi::AllocationType::UNDEFINED);
+      if (tensor.initialized()) {
+        place = tensor.place().GetType();
+      }
+      if (var_name == "@EMPTY@") {
+        continue;
+      }
+      auto op_desc = local_program.MutableBlock(0)->PrependOp();
+      op_desc->SetType("feed_with_place");
+      op_desc->SetAttr("index", 0);
+      // TODO(phlrain) : using tensor dtype
+      op_desc->SetAttr("dtype", 0);
+      op_desc->SetAttr("place", static_cast<int>(place));
+      op_desc->SetAttr("name", var_name);
+      op_desc->SetOutput("out", {var_name});
+    }
   }

   std::vector<std::string> param_grad_names;
......
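For orientation: `feed_with_place` is the bridge by which scope-resident tensors (forward parameters here, backward feeds above) become explicit feed ops the new IR translator can lower. A standalone toy of the attribute set each prepended op carries; names and the place value below are illustrative, not from the patch:

```cpp
#include <iostream>
#include <string>

// Toy stand-in for one prepended feed_with_place op; mirrors the attributes
// set by ConstructFowardIrProgram / ConstructBackwardIrProgram above.
struct FeedWithPlaceOp {
  int index;         // always 0 in this patch
  int dtype;         // 0 for now; the TODO is to derive it from the tensor
  int place;         // static_cast<int>(phi::AllocationType); UNDEFINED when
                     // the source tensor is uninitialized
  std::string name;  // variable name, also used as the "out" output name
};

int main() {
  // "fc_0.w_0" and place value 1 are made up for illustration.
  FeedWithPlaceOp op{0, 0, 1, "fc_0.w_0"};
  std::cout << "feed_with_place -> " << op.name << " (place " << op.place
            << ")\n";
  return 0;
}
```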
...@@ -241,13 +241,15 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
     const paddle::framework::BlockDesc* forward_global_block,
     const paddle::framework::BlockDesc* backward_global_block,
     const std::vector<std::string> output_names,
-    const std::vector<paddle::Tensor>& x);
+    const std::vector<paddle::Tensor>& x,
+    const std::vector<paddle::Tensor>& params);

 std::unique_ptr<::ir::Program> ConstructBackwardIrProgram(
     const paddle::framework::BlockDesc* backward_global_block,
     const std::vector<paddle::Tensor>& out_grad,
     const std::vector<paddle::Tensor*>& x_grad,
-    const std::vector<paddle::Tensor*>& params_grad);
+    const std::vector<paddle::Tensor*>& params_grad,
+    const paddle::framework::Scope* scope);

 }  // namespace framework
 }  // namespace paddle
...@@ -47,13 +47,15 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
       place, block, scope, execution_config);
 }

-InterpreterCore::InterpreterCore(const platform::Place& place,
-                                 std::unique_ptr<::ir::Program> ir_prog,
-                                 framework::Scope* scope,
-                                 const ExecutionConfig& execution_config) {
+InterpreterCore::InterpreterCore(
+    const platform::Place& place,
+    const std::vector<std::string>& fetch_var_names,
+    std::unique_ptr<::ir::Program> ir_prog,
+    framework::Scope* scope,
+    const ExecutionConfig& execution_config) {
   VLOG(4) << "InterpreterCore(): " << this << " on " << place;
   impl_ = std::make_unique<NewIRInterpreter>(
-      place, std::move(ir_prog), scope, execution_config);
+      place, fetch_var_names, std::move(ir_prog), scope, execution_config);
 }

 InterpreterCore::~InterpreterCore() {
......
...@@ -37,6 +37,7 @@ class InterpreterCore {
       const ExecutionConfig& execution_config = ExecutionConfig());

   // This constructor is for New IR.
   InterpreterCore(const platform::Place& place,
+                  const std::vector<std::string>& fetch_var_names,
                   std::unique_ptr<::ir::Program> ir_prog,
                   Scope* scope,
                   const ExecutionConfig& execution_config = ExecutionConfig());
...@@ -80,6 +81,8 @@ class InterpreterCore {
   DISABLE_COPY_AND_ASSIGN(InterpreterCore);

   std::unique_ptr<InterpreterBaseImpl> impl_;
+
+  std::vector<std::string> fetch_var_names_;
 };

 }  // namespace framework
......
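Call sites must now supply the fetch variable names up front; an empty list is legal when nothing is fetched. A sketch of the two call shapes, assuming kernel-dialect programs already produced by `PdOpLowerToKernelPass` (the helper name and header path here are illustrative, not from the patch):

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/new_executor/interpretercore.h"

// Sketch only: the two call shapes of the updated constructor.
void MakeCores(std::unique_ptr<::ir::Program> prog_a,
               std::unique_ptr<::ir::Program> prog_b,
               paddle::framework::Scope* scope) {
  paddle::platform::CPUPlace place;

  // No fetch targets: pass an empty list, exactly as the updated tests do.
  paddle::framework::InterpreterCore core_a(
      place, /*fetch_var_names=*/{}, std::move(prog_a), scope);

  // With fetch targets: names use the "<name>@fetch" convention that
  // StandaloneExecutor collects from pd.fetch ops (see below).
  std::vector<std::string> fetch_names = {"out@fetch"};
  paddle::framework::InterpreterCore core_b(
      place, fetch_names, std::move(prog_b), scope);
}
```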
...@@ -40,20 +40,29 @@
 #include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h"
 #include "paddle/ir/core/builtin_attribute.h"

+PHI_DECLARE_bool(enable_new_ir_in_executor);
+PHI_DECLARE_bool(enable_new_ir_in_executor_beta_run);
+PHI_DECLARE_bool(enable_new_ir_in_executor_loop_run);
+
 namespace paddle {
 namespace framework {

-NewIRInterpreter::NewIRInterpreter(const platform::Place& place,
-                                   std::unique_ptr<::ir::Program> ir_prog,
-                                   framework::Scope* scope,
-                                   const ExecutionConfig& execution_config)
+NewIRInterpreter::NewIRInterpreter(
+    const platform::Place& place,
+    const std::vector<std::string>& fetch_var_names,
+    std::unique_ptr<::ir::Program> ir_prog,
+    framework::Scope* scope,
+    const ExecutionConfig& execution_config)
     : place_(place),
       stream_analyzer_(place),
       execution_config_(execution_config),
       var_scope_(scope),
       scope_(scope),
       ir_program_(std::move(ir_prog)),
-      ir_stream_analyzer_(place) {
+      ir_stream_analyzer_(place),
+      fetch_var_names_(fetch_var_names) {
   VLOG(4) << "NewIRInterpreter(): " << this << " on " << place_;

   static_build_ = FLAGS_new_executor_static_build &&
                   !FLAGS_new_executor_use_cuda_graph &&
...@@ -188,6 +197,11 @@ FetchList NewIRInterpreter::Run(
 FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
                                 bool need_fetch) {
+  if (FLAGS_enable_new_ir_in_executor_beta_run) {
+    LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode.";
+    return BetaRun(feed_names, need_fetch);
+  }
+
   SetDeviceId(place_);
   CheckCUDAGraphBeforeRun(feed_names);
...@@ -196,7 +210,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
 #endif

   if (!is_build_) {
-    LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
+    LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in OldRun mode.";
     std::stringstream ss;
     ss << this;
     ::ir::BuildScope(*ir_program_->block(),
...@@ -205,7 +219,8 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
                      &value_2_var_name_,
                      &variable_2_var_name_,
                      &var_name_2_id_,
-                     &variable_list_);
+                     &variable_list_,
+                     &parameter_values_);
     VLOG(4) << DebugValueInfo();

     std::vector<paddle::framework::OpFuncNode> op_func_nodes;
...@@ -235,20 +250,36 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
   // return Fetch Tensors
   Scope* inner_scope = InnerScope();
-  auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
-  if (fetch_var && need_fetch) {
-    auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
-#ifdef PADDLE_WITH_CUDA
-    if (platform::IsCUDAGraphCapturing()) {
-      PADDLE_ENFORCE_EQ(fetch_list.empty(),
-                        true,
-                        platform::errors::InvalidArgument(
-                            "Cannot fetch data when using CUDA Graph."));
+  if (FLAGS_enable_new_ir_in_executor) {
+    framework::FetchList fetch_res;
+
+    if (need_fetch) {
+      for (auto& var_name : fetch_var_names_) {
+        auto* var = inner_scope->FindVar(var_name);
+        VLOG(0) << "fetch " << var_name << "[" << var << "]";
+        fetch_res.push_back(var->Get<phi::DenseTensor>());
+      }
     }
-#endif
-    return fetch_list;
+    VLOG(4) << "get fetch list size: " << fetch_res.size();
+    return fetch_res;
   } else {
-    return {};
+    auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
+    if (fetch_var && need_fetch) {
+      auto fetch_list =
+          std::move(*fetch_var->GetMutable<framework::FetchList>());
+#ifdef PADDLE_WITH_CUDA
+      if (platform::IsCUDAGraphCapturing()) {
+        PADDLE_ENFORCE_EQ(fetch_list.empty(),
+                          true,
+                          platform::errors::InvalidArgument(
+                              "Cannot fetch data when using CUDA Graph."));
+      }
+#endif
+      return fetch_list;
+    } else {
+      return {};
+    }
   }
 }
...@@ -1355,15 +1386,6 @@ void NewIRInterpreter::CheckGC(const Instruction& instr) {
   }
 }

-::ir::Value NewIRInterpreter::GetValueByName(const std::string& var_name) {
-  for (auto kv : value_2_var_name_) {
-    if (kv.second == var_name) {
-      return kv.first;
-    }
-  }
-  return nullptr;
-}
-
 void NewIRInterpreter::Prepare(
     const std::vector<std::string>& feed_names,
     const std::vector<phi::DenseTensor>& feed_tensors,
...@@ -1599,10 +1621,10 @@ void NewIRInterpreter::BuildInstruction() {
                       .at("op_name")
                       .dyn_cast<::ir::StrAttribute>()
                       .AsString();
-    if (op_name == "builtin.combine" || op_name == "builtin.slice" ||
-        op_name == "pd.feed" || op_name == "pd.fetch" ||
+    if (op_name == "builtin.combine" || op_name == "pd.feed" ||
         op_name == "builtin.set_parameter" ||
-        op_name == "builtin.get_parameter") {
+        op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
+        op_name == "pd.feed_with_place" || op_name == "pd.shaddow_output") {
       VLOG(6) << "skip process " << op_name;
       continue;
     }
...@@ -1777,12 +1799,17 @@ void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) {
     // persistable var will be ignored while GC
     ::ir::Value value = GetValueByName(GetNameById(var_id));
-    if (value && value.GetDefiningOp()->attributes().count("is_persisable") &&
-        value.GetDefiningOp()
-            ->attributes()
-            .at("is_persisable")
-            .dyn_cast<::ir::BoolAttribute>()
-            .data()) {
+    bool is_parameter = false;
+    if (value) {
+      for (auto item : parameter_values_) {
+        if (item == value) {
+          is_parameter = true;
+          break;
+        }
+      }
+    }
+    if (is_parameter) {
+      VLOG(4) << "value " << value.impl() << " is a parameter, skip gc";
       continue;
     }
...@@ -1830,14 +1857,20 @@ void NewIRInterpreter::CheckGC(InstructionBase* instr) {
     bool is_ready = refs_[var_id]->CheckAndDecrease();
     // ignore all persistable var while GC
     ::ir::Value value = GetValueByName(GetNameById(var_id));
-    if (value && value.GetDefiningOp()->attributes().count("is_persisable") &&
-        value.GetDefiningOp()
-            ->attributes()
-            .at("is_persisable")
-            .dyn_cast<::ir::BoolAttribute>()
-            .data()) {
+    bool is_parameter = false;
+    if (value) {
+      for (auto item : parameter_values_) {
+        if (item == value) {
+          is_parameter = true;
+          break;
+        }
+      }
+    }
+    if (is_parameter) {
+      VLOG(4) << "value " << value.impl() << " is a parameter, skip gc";
       continue;
     }
     if (is_ready) {
       VLOG(6) << "Async delete variable with name : " << GetNameById(var_id);
       gc_->Add(refs_[var_id]->Var(), instr);
...@@ -1857,7 +1890,10 @@ void NewIRInterpreter::CalculateLastLiveOps() {
         instr->Outputs();
     std::unordered_multimap<::ir::Value, std::vector<int>> ins_and_outs{
         ins.begin(), ins.end()};
-    ins_and_outs.insert(outs.begin(), outs.end());
+
+    if (instr->Name() != "pd.fetch") {
+      ins_and_outs.insert(outs.begin(), outs.end());
+    }

     for (auto& item : ins_and_outs) {
       for (auto var_id : item.second) {
...@@ -1989,7 +2025,9 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
                      &value_2_var_name_,
                      &variable_2_var_name_,
                      &var_name_2_id_,
-                     &variable_list_);
+                     &variable_list_,
+                     &parameter_values_);
+    VLOG(4) << "Done BuildScope";
     VLOG(4) << DebugValueInfo();

     BuildInstruction();
...@@ -1999,9 +2037,22 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
     VLOG(4) << "Done PreAnalysis";

     // Run
-    BetaRunImpl();
+    if (FLAGS_enable_new_ir_in_executor_loop_run) {
+      LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
+                              "with for_loop version.";
+      LoopRunImpl();
+    } else {
+      LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
+                              "with trace version.";
+      TraceRunImpl();
+    }
+    is_build_ = true;
   } else {
-    BetaRunImpl();
+    if (FLAGS_enable_new_ir_in_executor_loop_run) {
+      LoopRunImpl();
+    } else {
+      TraceRunImpl();
+    }
   }

   if (HasLocalScope()) {
...@@ -2010,31 +2061,52 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
   // return Fetch Tensors
   Scope* inner_scope = InnerScope();
-  auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
-  if (fetch_var && need_fetch) {
-    auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
-#ifdef PADDLE_WITH_CUDA
-    if (platform::IsCUDAGraphCapturing()) {
-      PADDLE_ENFORCE_EQ(fetch_list.empty(),
-                        true,
-                        platform::errors::InvalidArgument(
-                            "Cannot fetch data when using CUDA Graph."));
+  if (FLAGS_enable_new_ir_in_executor) {
+    framework::FetchList fetch_res;
+
+    if (need_fetch) {
+      for (auto& var_name : fetch_var_names_) {
+        auto* var = inner_scope->FindVar(var_name);
+        VLOG(0) << "fetch " << var_name << "[" << var << "]";
+        fetch_res.push_back(var->Get<phi::DenseTensor>());
+      }
     }
-#endif
-    return fetch_list;
+    VLOG(4) << "get fetch list size: " << fetch_res.size();
+    return fetch_res;
   } else {
-    return {};
+    auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
+    if (fetch_var && need_fetch) {
+      auto fetch_list =
+          std::move(*fetch_var->GetMutable<framework::FetchList>());
+#ifdef PADDLE_WITH_CUDA
+      if (platform::IsCUDAGraphCapturing()) {
+        PADDLE_ENFORCE_EQ(fetch_list.empty(),
+                          true,
+                          platform::errors::InvalidArgument(
+                              "Cannot fetch data when using CUDA Graph."));
+      }
+#endif
+      return fetch_list;
+    } else {
+      return {};
+    }
   }
 }

-void NewIRInterpreter::NewIrLoopRunImpl() {
-  for (size_t instr_id = 0; instr_id < vec_instruction_base_.size();
-       ++instr_id) {
-    vec_instruction_base_[instr_id]->Run();
+void NewIRInterpreter::LoopRunImpl() {
+  // lazy initialization of gc, do not create gc if the program only runs once
+  if (!gc_) {
+    gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_base_);
   }
+
+  interpreter::ResetAtomicGuard guard(&deps_, &refs_);
+  VLOG(4) << "Loop Instruction List";
+
+  LoopRunInstructionList(vec_instruction_base_);
+  VLOG(4) << "Done LoopRunImpl";
 }

-void NewIRInterpreter::BetaRunImpl() {
+void NewIRInterpreter::TraceRunImpl() {
   // lazy initialization of gc, do not create gc if the program only runs once
   if (!gc_) {
     gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_base_);
...@@ -2043,11 +2115,53 @@ void NewIRInterpreter::BetaRunImpl() {
   interpreter::ResetAtomicGuard guard(&deps_, &refs_);
   VLOG(4) << "Tracing Instruction List";

-  TraceInstructionList(vec_instruction_base_);
-  VLOG(4) << "Done BetaRunImpl";
+  TraceRunInstructionList(vec_instruction_base_);
+  VLOG(4) << "Done TraceRunImpl";
 }

-void NewIRInterpreter::TraceInstructionList(
+void NewIRInterpreter::LoopRunInstructionList(
+    const std::vector<std::unique_ptr<InstructionBase>>& vec_instr) {
+  unfinished_op_number_ = vec_instr.size();
+  if (unfinished_op_number_ == 0) {
+    VLOG(4) << "No op to run, return";
+    return;
+  }
+
+  exception_holder_.Clear();
+
+  for (size_t i = 0; i < dependecy_count_.size(); ++i) {
+    if (dependecy_count_[i] == 0) {
+      // NOTE(zhiqiu): hot fix for jit input var
+      RecordMemcpyD2H(vec_instr.at(i).get());
+    }
+  }
+
+  for (size_t idx = 0; idx < vec_instr.size(); idx++) {
+    InstructionBase* instr_node = vec_instr[idx].get();
+
+    VLOG(6) << "Run InstructionBase " << idx;
+    RunInstructionBase(instr_node);
+
+    if (UNLIKELY(exception_holder_.IsCaught())) {
+      VLOG(4) << "Exception caught";
+      break;
+    }
+  }
+
+  if (UNLIKELY(exception_holder_.IsCaught())) {
+    VLOG(1) << "Exception caught " << exception_holder_.Type();
+    PADDLE_ENFORCE_EQ(
+        main_thread_blocker_.Clear(),
+        0,
+        platform::errors::PreconditionNotMet(
+            "main_thread_blocker_.Clear() return -1, clear failed"));
+    VLOG(4) << "clear ok";
+    exception_holder_.ReThrow();
+  }
+  VLOG(4) << "Done LoopRunInstructionList";
+}
+
+void NewIRInterpreter::TraceRunInstructionList(
     const std::vector<std::unique_ptr<InstructionBase>>& vec_instr) {
   unfinished_op_number_ = vec_instr.size();
   if (unfinished_op_number_ == 0) {
...@@ -2087,7 +2201,7 @@ void NewIRInterpreter::TraceInstructionList(
     VLOG(4) << "clear ok";
     exception_holder_.ReThrow();
   }
-  VLOG(4) << "Done TraceInstructionList";
+  VLOG(4) << "Done TraceRunInstructionList";
 }

 void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) {
...@@ -2145,5 +2259,14 @@ void NewIRInterpreter::PreAnalysis() {
   VLOG(4) << "Done AnalyseExecuteOrderForTrace";
 }

+::ir::Value NewIRInterpreter::GetValueByName(const std::string& var_name) {
+  for (auto kv : value_2_var_name_) {
+    if (kv.second == var_name) {
+      return kv.first;
+    }
+  }
+  return nullptr;
+}
+
 }  // namespace framework
 }  // namespace paddle
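Net effect on the run path: with the defaults this change introduces (`beta_run=true`, `loop_run=false`), `Run()` immediately forwards to `BetaRun`, which traces the instruction list; flipping `loop_run` swaps in the plain sequential loop. A standalone toy of that dispatch (illustrative only, not Paddle code):

```cpp
#include <iostream>

// The booleans mirror FLAGS_enable_new_ir_in_executor_beta_run (default true)
// and FLAGS_enable_new_ir_in_executor_loop_run (default false) defined later
// in this change.
bool beta_run = true;
bool loop_run = false;

void Run() {
  if (beta_run) {
    if (loop_run) {
      std::cout << "BetaRun mode, for_loop version (LoopRunImpl)\n";
    } else {
      std::cout << "BetaRun mode, trace version (TraceRunImpl)\n";
    }
  } else {
    std::cout << "OldRun mode (legacy BuildScope + op_func_node path)\n";
  }
}

int main() { Run(); }
```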
...@@ -35,6 +35,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
  public:
   NewIRInterpreter(const platform::Place& place,
+                   const std::vector<std::string>& fetch_var_names,
                    std::unique_ptr<::ir::Program> ir_prog,
                    Scope* scope,
                    const ExecutionConfig& execution_config = ExecutionConfig());
...@@ -214,11 +215,14 @@ class NewIRInterpreter : public InterpreterBaseImpl {
   void BuildInstructionDependences();

-  void NewIrLoopRunImpl();
+  void LoopRunImpl();

-  void BetaRunImpl();
+  void TraceRunImpl();

-  void TraceInstructionList(
+  void TraceRunInstructionList(
+      const std::vector<std::unique_ptr<InstructionBase>>& vec_instr);
+
+  void LoopRunInstructionList(
       const std::vector<std::unique_ptr<InstructionBase>>& vec_instr);

   void RunInstructionBase(InstructionBase* instr_node);
...@@ -251,6 +255,12 @@ class NewIRInterpreter : public InterpreterBaseImpl {
   interpreter::NewIrDependencyBuilder ir_dependency_builder_;

   interpreter::NewIrStreamAnalyzer ir_stream_analyzer_;
+
+  std::vector<std::string> fetch_var_names_;
+
+  // Note(zhangbo): set_parameter_op's input and get_parameter_op's output
+  // belong to a parameter and cannot be GCed.
+  std::vector<::ir::Value> parameter_values_;
 };

 }  // namespace framework
......
...@@ -65,10 +65,37 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
   if (FLAGS_enable_new_ir_in_executor) {
     VLOG(6) << "begin to translate" << std::endl;
     auto base_program = paddle::TranslateLegacyProgramToProgram(*program);
+
+    auto block = base_program->block();
+    for (auto it = block->begin(); it != block->end(); ++it) {
+      if ((*it)->name() == "pd.fetch") {
+        size_t index = (*it)
+                           ->attributes()
+                           .at("col")
+                           .dyn_cast<ir::Int32Attribute>()
+                           .data();
+        if (fetch_var_names_.size() < index + 1) {
+          fetch_var_names_.resize(index + 1);
+        }
+
+        fetch_var_names_[index] = (*it)
+                                      ->attributes()
+                                      .at("name")
+                                      .dyn_cast<ir::StrAttribute>()
+                                      .AsString() +
+                                  "@fetch";
+      }
+    }
+
     auto kernel_program =
         paddle::dialect::PdOpLowerToKernelPass(base_program.get(), place);
-    interpretercores_.emplace_back(std::make_shared<InterpreterCore>(
-        place_, std::move(kernel_program), scope_, execution_config));
+    interpretercores_.emplace_back(
+        std::make_shared<InterpreterCore>(place_,
+                                          fetch_var_names_,
+                                          std::move(kernel_program),
+                                          scope_,
+                                          execution_config));
   } else {
     interpretercores_.emplace_back(
         std::make_shared<InterpreterCore>(place_,
...@@ -130,11 +157,22 @@ paddle::framework::FetchList StandaloneExecutor::Run(
   }

   // return Fetch Tensors
-  auto* fetch_var = scope_->FindVar(interpreter::kFetchVarName);
-  if (fetch_var) {
-    return std::move(*fetch_var->GetMutable<framework::FetchList>());
+  if (FLAGS_enable_new_ir_in_executor) {
+    framework::FetchList fetch_res;
+
+    for (auto& var_name : fetch_var_names_) {
+      auto* var = scope_->FindVar(var_name);
+      fetch_res.push_back(var->Get<phi::DenseTensor>());
+    }
+
+    return fetch_res;
   } else {
-    return {};
+    auto* fetch_var = scope_->FindVar(interpreter::kFetchVarName);
+    if (fetch_var) {
+      return std::move(*fetch_var->GetMutable<framework::FetchList>());
+    } else {
+      return {};
+    }
   }
 }
......
...@@ -50,6 +50,8 @@ class StandaloneExecutor {
   std::vector<std::shared_ptr<InterpreterCore>> interpretercores_;

   Scope* scope_;
+
+  std::vector<std::string> fetch_var_names_;
 };

 }  // namespace framework
......
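Worth making the fetch protocol explicit: each `pd.fetch` op carries a `name` string attribute and a `col` index; `HandleForSpecialOp` (below) creates a scope variable named `<name>@fetch`, and the executor returns the fetched tensors ordered by `col`. A standalone illustration of the name-collection loop (same logic as above, not Paddle code):

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Each (name, col) pair models one pd.fetch op's attributes; the result
// mirrors fetch_var_names_.
std::vector<std::string> CollectFetchNames(
    const std::vector<std::pair<std::string, size_t>>& fetch_ops) {
  std::vector<std::string> names;
  for (const auto& [name, col] : fetch_ops) {
    if (names.size() < col + 1) {
      names.resize(col + 1);  // keep slot order aligned with "col"
    }
    names[col] = name + "@fetch";  // scope variable written by pd.fetch
  }
  return names;
}

int main() {
  auto names = CollectFetchNames({{"mean_out", 0}, {"loss", 1}});
  assert(names[0] == "mean_out@fetch");
  assert(names[1] == "loss@fetch");
  return 0;
}
```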
...@@ -69,7 +69,8 @@ class PhiKernelAdaptor {
                      &value_2_var_name,
                      &variable_2_var_name,
                      &var_name_2_id,
-                     &variable_list);
+                     &variable_list,
+                     nullptr);

     ir::IrContext* ctx = ir::IrContext::Instance();
     ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
......
...@@ -56,9 +56,11 @@ void AddNewData(ir::Value value,
                 std::vector<paddle::framework::Variable*>* variable_list) {
   value_2_var_name->emplace(value, name);
   variable_2_var_name->emplace(var, name);
-  auto id = var_name_2_id->size();
-  var_name_2_id->emplace(name, id);
-  variable_list->push_back(var);
+  if (var_name_2_id->count(name) == 0) {
+    auto id = var_name_2_id->size();
+    var_name_2_id->emplace(name, id);
+    variable_list->push_back(var);
+  }
   PADDLE_ENFORCE_EQ(
       variable_list->size(),
       var_name_2_id->size(),
...@@ -104,11 +106,8 @@ paddle::framework::Variable* CreateVar(
     std::vector<paddle::framework::Variable*>* variable_list) {
   Operation* def_op = value.GetDefiningOp();
   bool is_persisable = false;
-  if (def_op->attributes().count("is_persisable")) {
-    is_persisable = def_op->attributes()
-                        .at("is_persisable")
-                        .dyn_cast<ir::BoolAttribute>()
-                        .data();
+  if (def_op->name() == "builtin.set_parameter") {
+    is_persisable = true;
   }

   paddle::framework::Variable* var = nullptr;
...@@ -218,7 +217,8 @@ void HandleForSpecialOp(
     std::unordered_map<const paddle::framework::Variable*, std::string>*
         variable_2_var_name,
     std::map<std::string, int>* var_name_2_id,
-    std::vector<paddle::framework::Variable*>* variable_list) {
+    std::vector<paddle::framework::Variable*>* variable_list,
+    std::vector<::ir::Value>* parameter_values) {
   std::string op_name = op->name();
   if (op->attributes().count("op_name")) {
     op_name =
...@@ -227,13 +227,21 @@ void HandleForSpecialOp(

   if (op_name == "pd.fetch") {
     // fetch is a very special op, with no output
-    auto var = const_cast<paddle::framework::Scope*>(inner_scope->root())
-                   ->Var("fetch");
-    VLOG(6) << "Create var: fetch in scope " << inner_scope->root();
-    auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
-    int index =
-        op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
-    fetch_list->resize(index + 1);
+    auto fetch_src_name =
+        op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
+
+    auto fetch_var_name = fetch_src_name + "@fetch";
+    auto* var = const_cast<paddle::framework::Scope*>(inner_scope->root())
+                    ->Var(fetch_var_name);
+    var->GetMutable<phi::DenseTensor>();
+    auto value = op->result(0);
+
+    AddNewData(value,
+               fetch_var_name,
+               var,
+               value_2_var_name,
+               variable_2_var_name,
+               var_name_2_id,
+               variable_list);
   }

   if (op_name == "pd.feed") {
...@@ -262,6 +270,7 @@ void HandleForSpecialOp(
         op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
     auto value = op->result(0);
+
     paddle::framework::Variable* var = inner_scope->FindVar(var_name);
     PADDLE_ENFORCE(var,
                    paddle::platform::errors::InvalidArgument(
...@@ -322,6 +331,8 @@ void HandleForSpecialOp(
     if (inner_scope->root()->FindVar(param_name) == nullptr) {
       const_cast<paddle::framework::Scope*>(inner_scope->root())
           ->Rename(orig_name, param_name);
+      VLOG(6) << "set_parameter rename var: " << orig_name << " -> "
+              << param_name;
     }
     RenameData(value,
                param_name,
...@@ -329,6 +340,10 @@ void HandleForSpecialOp(
                value_2_var_name,
                variable_2_var_name,
                var_name_2_id);
+
+    if (parameter_values) {
+      parameter_values->push_back(value);
+    }
   }

   if (op_name == "pd.shadow_output") {
...@@ -359,6 +374,7 @@ void HandleForSpecialOp(
             .dyn_cast<ir::StrAttribute>()
             .AsString();
     auto value = op->result(0);
+
     paddle::framework::Variable* var = inner_scope->FindVar(param_name);
     AddNewData(value,
                param_name,
...@@ -367,6 +383,10 @@ void HandleForSpecialOp(
                variable_2_var_name,
                var_name_2_id,
                variable_list);
+
+    if (parameter_values) {
+      parameter_values->push_back(value);
+    }
   }

   if (op_name == "builtin.slice") {
...@@ -452,7 +472,8 @@ void BuildScope(const ir::Block& block,
                 std::unordered_map<const paddle::framework::Variable*,
                                    std::string>* variable_2_var_name,
                 std::map<std::string, int>* var_name_2_id,
-                std::vector<paddle::framework::Variable*>* variable_list) {
+                std::vector<paddle::framework::Variable*>* variable_list,
+                std::vector<::ir::Value>* parameter_values) {
   VLOG(4) << "***** [before build] scope"
           << "(" << inner_scope << ") ******\n"
           << paddle::framework::GenScopeTreeDebugInfo(
...@@ -480,7 +501,8 @@ void BuildScope(const ir::Block& block,
                        value_2_var_name,
                        variable_2_var_name,
                        var_name_2_id,
-                       variable_list);
+                       variable_list,
+                       parameter_values);
       continue;
     }
......
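One subtlety in `AddNewData`: with `pd.fetch` results now registered like ordinary variables, the same name can reach `AddNewData` more than once, so the new `count(name) == 0` guard keeps `var_name_2_id` and `variable_list` in lockstep for the `PADDLE_ENFORCE_EQ` that follows. A standalone model of that invariant (the exact re-registration trigger is my reading of the patch, not stated in it):

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Registering the same name twice must keep the id map and the variable list
// the same size.
struct Registry {
  std::map<std::string, int> var_name_2_id;
  std::vector<int> variable_list;  // stands in for Variable* entries

  void AddNewData(const std::string& name, int var) {
    if (var_name_2_id.count(name) == 0) {  // the guard added by this patch
      var_name_2_id.emplace(name, static_cast<int>(var_name_2_id.size()));
      variable_list.push_back(var);
    }
    assert(variable_list.size() == var_name_2_id.size());  // PADDLE_ENFORCE_EQ
  }
};

int main() {
  Registry r;
  r.AddNewData("fc_0.w_0", 1);
  r.AddNewData("fc_0.w_0", 1);  // second registration is now a no-op
  assert(r.variable_list.size() == 1);
  return 0;
}
```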
...@@ -49,7 +49,8 @@ void BuildScope(const ir::Block& block,
                 std::unordered_map<const paddle::framework::Variable*,
                                    std::string>* variable_2_var_name,
                 std::map<std::string, int>* var_name_2_id,
-                std::vector<paddle::framework::Variable*>* variable_list);
+                std::vector<paddle::framework::Variable*>* variable_list,
+                std::vector<::ir::Value>* parameter_values);

 void BuildRuntimeContext(
     ir::Operation* op,
...@@ -285,47 +286,36 @@ void BuildPhiContext(ir::Operation* op,
   }

   // TODO(phlrain): use var type instead of op name
-  if (op->attributes().count("op_name") &&
-      (op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString() ==
-       "pd.fetch")) {
-    // process fetch op
-    auto fetch_var = inner_scope->FindVar("fetch");
-    auto* fetch_list = fetch_var->GetMutable<paddle::framework::FetchList>();
-    int index =
-        op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
-    auto* out_tensor = &(PADDLE_GET(phi::DenseTensor, fetch_list->at(index)));
-    ctx->EmplaceBackOutput(out_tensor);
-  } else {
-    for (size_t i = 0; i < op->num_results(); ++i) {
-      ir::Value out_ptr = op->result(i);
-      auto name = name_map.at(out_ptr);
-      VLOG(6) << "ctx->EmplaceBackOutput: " << name;
-      auto out_type = out_ptr.type();
-      if (!out_type) {
-        phi::DenseTensor* ptr = nullptr;
-        OutType out_ptr(ptr);
-        ctx->EmplaceBackOutput(out_ptr);
-      } else if (out_type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
-        ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
-            &(inner_scope->FindVar(name)->Get<phi::DenseTensor>()))));
-      } else if (out_type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
-        ctx->EmplaceBackOutput(OutType(const_cast<phi::SelectedRows*>(
-            &(inner_scope->FindVar(name)->Get<phi::SelectedRows>()))));
-      } else if (out_type.isa<ir::VectorType>()) {
-        OutListType outputs;
-        auto& variable_array =
-            scope->FindVar(name)->Get<paddle::framework::VariableRefArray>();
-        for (size_t i = 0; i < variable_array.size(); ++i) {
-          outputs.emplace_back(OutType(const_cast<phi::DenseTensor*>(
-              &(variable_array[i]->Get<phi::DenseTensor>()))));
-        }
-        ctx->EmplaceBackOutputs(outputs);
-      } else {
-        PADDLE_THROW(
-            phi::errors::Unimplemented("only support DenseTensor and vector "));
+  for (size_t i = 0; i < op->num_results(); ++i) {
+    ir::Value out_ptr = op->result(i);
+    auto name = name_map.at(out_ptr);
+    VLOG(6) << "ctx->EmplaceBackOutput: " << name;
+    auto out_type = out_ptr.type();
+    if (!out_type) {
+      phi::DenseTensor* ptr = nullptr;
+      OutType out_ptr(ptr);
+      ctx->EmplaceBackOutput(out_ptr);
+    } else if (out_type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
+      ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
+          &(inner_scope->FindVar(name)->Get<phi::DenseTensor>()))));
+    } else if (out_type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
+      ctx->EmplaceBackOutput(OutType(const_cast<phi::SelectedRows*>(
+          &(inner_scope->FindVar(name)->Get<phi::SelectedRows>()))));
+    } else if (out_type.isa<ir::VectorType>()) {
+      OutListType outputs;
+      auto& variable_array =
+          scope->FindVar(name)->Get<paddle::framework::VariableRefArray>();
+      for (size_t i = 0; i < variable_array.size(); ++i) {
+        outputs.emplace_back(OutType(const_cast<phi::DenseTensor*>(
+            &(variable_array[i]->Get<phi::DenseTensor>()))));
       }
+      ctx->EmplaceBackOutputs(outputs);
+    } else {
+      PADDLE_THROW(
+          phi::errors::Unimplemented("only support DenseTensor and vector "));
     }
   }

   VLOG(6) << "Done build phi context";
 }
......
...@@ -71,11 +71,32 @@ class ConstantFoldingPattern : public ir::RewritePattern {
     ir::Program* program = op->GetParentProgram();
     auto temp_program = BuildProgramFromOperation(op);

+    std::vector<std::string> fetch_var_names;
+    auto block = temp_program->block();
+    for (auto it = block->begin(); it != block->end(); ++it) {
+      if ((*it)->name() == "pd.fetch") {
+        size_t index =
+            (*it)->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
+
+        if (fetch_var_names.size() < index + 1) {
+          fetch_var_names.resize(index + 1);
+        }
+
+        fetch_var_names[index] = (*it)
+                                     ->attributes()
+                                     .at("name")
+                                     .dyn_cast<ir::StrAttribute>()
+                                     .AsString() +
+                                 "@fetch";
+      }
+    }
+
     // Execute program
     paddle::framework::interpreter::ExecutionConfig exe_config;
     exe_config.create_local_scope = false;
     paddle::framework::InterpreterCore core(
         phi::CPUPlace{},
+        fetch_var_names,
         paddle::dialect::PdOpLowerToKernelPass(temp_program.get()),
         &scope_,
         exe_config);
......
...@@ -1280,3 +1280,28 @@ PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor,
 PHI_DEFINE_EXPORTED_bool(enable_new_ir_api,
                          false,
                          "Enable new IR API in Python");
+
+/**
+ * Using new IR in executor FLAG
+ * Name: enable_new_ir_in_executor_beta_run
+ * Since Version: 2.6.0
+ * Value Range: bool, default=true
+ * Example:
+ * Note: If True, executor will use new IR and run in beta version.
+ */
+PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor_beta_run,
+                         true,
+                         "Enable new IR in executor");
+
+/**
+ * Using new IR in executor FLAG
+ * Name: enable_new_ir_in_executor_loop_run
+ * Since Version: 2.6.0
+ * Value Range: bool, default=false
+ * Example:
+ * Note: If True, executor will use new IR and run in beta version with the
+ * for-loop variant.
+ */
+PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor_loop_run,
+                         false,
+                         "Enable new IR in executor");
...@@ -1093,7 +1093,7 @@ TEST(pattern_rewrite, Patterns) {
   ir::PassManager pm(ctx);
   pm.AddPass(std::make_unique<TestPass>());
-  pm.AddPass(ir::CreateConstantFoldingPass());
+  // pm.AddPass(ir::CreateConstantFoldingPass());
   pm.AddPass(ir::CreateDeadCodeEliminationPass());
   pm.EnablePassTiming();
   pm.EnableIRPrinting();
......
...@@ -69,7 +69,7 @@ TEST(StandaloneExecutor, run) {
   Scope scope;
   ProgramDesc prog_desc;
-  InterpreterCore test_core(place, std::move(kernel_program), &scope);
+  InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);

   std::stringstream os;
   os << reinterpret_cast<NewIRInterpreter*>(
...@@ -110,7 +110,7 @@ TEST(StandaloneExecutor, run_inplace_sqrt) {
   auto place = platform::CPUPlace();
   Scope scope;
-  InterpreterCore test_core(place, std::move(kernel_program), &scope);
+  InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);

   std::stringstream os;
   os << reinterpret_cast<NewIRInterpreter*>(
......