Unverified · Commit bd7f28bb · Authored by: Ruibiao Chen · Committed by: GitHub

Multiplex Workqueue for InterpreterCore (#43660)

* Multiplex Workqueue for InterpreterCore

* Delete ResetWorkQueueOptions

* Update code format
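The core of this change is that one InterpreterCore can now reuse another's AsyncWorkQueue instead of spawning its own thread pools. Below is a minimal usage sketch, not taken from the PR itself: the header path, helper name, and fetch names are assumptions, while CreateInterpreterCore and ShareWorkQueueFrom are the interfaces touched by this diff.

#include <memory>
#include "paddle/fluid/framework/new_executor/interpretercore.h"  // assumed path

namespace fw = paddle::framework;

// Build two cores for two programs and let the second one multiplex the first
// one's work queue; the queue itself is still created lazily on first use.
std::shared_ptr<fw::InterpreterCore> BuildSharedCores(
    const paddle::platform::Place& place,
    const fw::ProgramDesc& prog_a,
    const fw::ProgramDesc& prog_b,
    fw::VariableScope* scope_a,
    fw::VariableScope* scope_b) {
  auto core_a = fw::CreateInterpreterCore(place, prog_a, scope_a, {"out_a"});
  auto core_b = fw::CreateInterpreterCore(place, prog_b, scope_b, {"out_b"});
  core_b->ShareWorkQueueFrom(core_a);  // core_b shares core_a's AsyncWorkQueue
  return core_b;
}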
Parent: 66a28e13
@@ -110,7 +110,8 @@ if(WITH_GPU
       sgd_op
       squared_l2_norm_op
       memcpy_h2d_op
-      memcpy_d2h_op)
+      memcpy_d2h_op
+      fetch_v2_op)
   # All deps of the operators above, part of GLOB_OPERATOR_DEPS.
   set(OP_DEPS generator softmax selected_rows_functor jit_kernel_helper
......
@@ -29,9 +29,11 @@
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
-PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, true,
+PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace,
+                            true,
                             "Use inplace in new executor");
-PADDLE_DEFINE_EXPORTED_bool(new_executor_use_local_scope, true,
+PADDLE_DEFINE_EXPORTED_bool(new_executor_use_local_scope,
+                            true,
                             "Use local_scope in new executor(especially used "
                             "in UT), can turn off for better performance");
@@ -99,9 +101,8 @@ InterpreterCore::~InterpreterCore() {
   // cancle gc's thread
   gc_.reset(nullptr);
-  async_work_queue_.reset(nullptr);
+  async_work_queue_.reset();
-  VLOG(4) << "~InterpreterCore(): " << this;
-  VLOG(4) << " on" << place_;
+  VLOG(4) << "~InterpreterCore(): " << this << " on " << place_;
 #ifdef PADDLE_WITH_MKLDNN
   // Clear mkl-dnn cache,
@@ -110,8 +111,29 @@ InterpreterCore::~InterpreterCore() {
 #endif
 }
-void InterpreterCore::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
-  copy_program_ = prog;
+interpreter::CostInfo InterpreterCore::DryRun(
+    const std::vector<std::string>& feed_names,
+    const std::vector<framework::LoDTensor>& feed_tensors) {
+  global_scope_->SetLocalScope(local_scope_);
+  Prepare(feed_names, feed_tensors, true);
+  interpreter::CostInfo cost_info;
+  {
+    interpreter::ProfilerGuard(place_, &cost_info);
+
+    // For the program that only run once, it is no need to
+    // create work_queue, so the async_work_queue_ is created
+    // until the second step run.
+    async_work_queue_ = GetWorkQueue();
+
+    ExecuteInstructionList(vec_instruction_);
+    platform::DeviceContextPool::Instance().Get(place_)->Wait();
+  }
+
+  if (create_local_scope_) {
+    ClearLoDTensorArrayInLocalScope();
+  }
+
+  return cost_info;
 }
 
 paddle::framework::FetchList InterpreterCore::Run(
@@ -131,14 +153,7 @@ paddle::framework::FetchList InterpreterCore::Run(
     // For the program that only run once, it is no need to
     // create work_queue, so the async_work_queue_ is created
     // until the second step run.
-    if (async_work_queue_ == nullptr) {
-      async_work_queue_ = std::make_unique<interpreter::AsyncWorkQueue>(
-          kHostNumThreads, kDeviceNumThreads, &main_thread_blocker_);
-
-      // prepare for the first time.
-      async_work_queue_->PrepareAtomicDeps(dependecy_count_);
-      async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo());
-    }
+    async_work_queue_ = GetWorkQueue();
+
     ExecuteInstructionList(vec_instruction_);
   }
...@@ -172,11 +187,14 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -172,11 +187,14 @@ paddle::framework::FetchList InterpreterCore::Run(
// scope? // scope?
} }
global_scope_->SetLocalScope(local_scope_); global_scope_->SetLocalScope(local_scope_);
paddle::framework::interpreter::build_variable_scope(block_, global_scope_, paddle::framework::interpreter::build_variable_scope(
create_local_scope_); block_, global_scope_, create_local_scope_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes; std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list( paddle::framework::interpreter::build_op_func_list(place_,
place_, block_, skip_gc_vars_, &op_func_nodes, global_scope_, block_,
skip_gc_vars_,
&op_func_nodes,
global_scope_,
create_local_scope_); create_local_scope_);
is_build_ = true; is_build_ = true;
SetFeedVarsInplaceSkip(feed_names); SetFeedVarsInplaceSkip(feed_names);
@@ -190,13 +208,7 @@ paddle::framework::FetchList InterpreterCore::Run(
     // For the program that only run once, it is no need to
    // create work_queue, so the async_work_queue_ is created
     // until the second step run.
-    if (async_work_queue_ == nullptr) {
-      async_work_queue_ = std::make_unique<interpreter::AsyncWorkQueue>(
-          kHostNumThreads, kDeviceNumThreads, &main_thread_blocker_);
-
-      // prepare for the first time.
-      async_work_queue_->PrepareAtomicDeps(dependecy_count_);
-      async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo());
-    }
+    async_work_queue_ = GetWorkQueue();
+
     ExecuteInstructionList(vec_instruction_);
   }
...@@ -213,15 +225,115 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -213,15 +225,115 @@ paddle::framework::FetchList InterpreterCore::Run(
return std::move(*fetch_var->GetMutable<framework::FetchList>()); return std::move(*fetch_var->GetMutable<framework::FetchList>());
} }
// At the end of each step, the holder of Tensor in LoDTensorArray is null. void InterpreterCore::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
// Clear these Tensors and leave LoDTensorArray empty, otherwise an exception copy_program_ = prog;
// will occur in the next step }
void InterpreterCore::ClearLoDTensorArrayInLocalScope() {
auto vars = local_scope_->LocalVars(); void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) {
for (auto var : vars) { async_work_queue_ = src->GetWorkQueue();
if (var->IsType<LoDTensorArray>()) { VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << &src
auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>(); << ") to InterpreterCore(" << this << ")";
lod_tensor_arr->clear(); }
bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) {
if (!global_scope_->VarDesc(var_index)) {
return input_var2op_info_.at(var_index).size() == 1;
} else {
int is_input_cnt = 0;
for (auto inst_id : input_var2op_info_.at(var_index)) {
OpInOutInfo info;
info.Build(vec_instruction_.at(inst_id).OpBase());
if (info.IsInArgBufferNeeded(global_scope_->VarDesc(var_index)->Name())) {
is_input_cnt++;
}
}
return is_input_cnt == 1;
}
}
std::shared_ptr<interpreter::AsyncWorkQueue> InterpreterCore::GetWorkQueue() {
if (async_work_queue_ == nullptr) {
async_work_queue_ = std::make_shared<interpreter::AsyncWorkQueue>(
kHostNumThreads, kDeviceNumThreads, &main_thread_blocker_);
}
return async_work_queue_;
}
void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
VariableValueMap ins_map;
for (auto& var_name_item : instr_node->Inputs()) {
std::vector<Variable*> input_vars;
input_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) {
input_vars.emplace_back(global_scope_->Var(id));
}
ins_map.emplace(var_name_item.first, std::move(input_vars));
}
VariableValueMap outs_map;
for (auto& var_name_item : instr_node->Outputs()) {
std::vector<Variable*> out_vars;
out_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) {
out_vars.emplace_back(global_scope_->Var(id));
}
outs_map.emplace(var_name_item.first, std::move(out_vars));
}
// set runtime_ctx and infershape_ctx_
if (instr_node->OpBase()->Type() == "cinn_launch") { // OP use scope in
// kernel
Scope* local_scope = create_local_scope_
? global_scope_->GetMutableLocalScope()
: global_scope_->GetMutableScope();
instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
} else {
instr_node->ResetContext(ins_map, outs_map);
}
}
void InterpreterCore::BuildInplace() {
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
auto& instr = vec_instruction_[i];
auto* op_base = instr.OpBase();
if (!op_base->Info().infer_inplace_) {
continue;
}
auto in_to_outs = op_base->Info().infer_inplace_(
platform::is_gpu_place(instr.DeviceContext().GetPlace()));
auto& inputs = instr.Inputs();
auto& outputs = instr.Outputs();
for (auto& pair : in_to_outs) {
auto iter = inputs.find(pair.first);
if (iter != inputs.end() && !iter->second.empty()) {
auto in_var_desc = global_scope_->VarDesc(iter->second[0]);
if (in_var_desc && in_var_desc->Persistable()) {
continue;
}
if (global_scope_->GetVarSikpInplace(iter->second[0])) {
continue;
}
if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) {
auto iterout = outputs.find(pair.second);
if (iterout != outputs.end() && !iterout->second.empty()) {
auto invar = global_scope_->Var(iter->second[0]);
auto outvar = global_scope_->Var(iterout->second[0]);
if (invar && outvar && invar->IsType<LoDTensor>() &&
outvar->IsType<LoDTensor>()) {
instr.AddInplace(invar, outvar);
VLOG(3) << "inplace " << vec_instruction_[i].OpBase()->Type()
<< " " << global_scope_->GetNameById(iter->second[0])
<< " -> "
<< global_scope_->GetNameById(iterout->second[0])
<< std::endl;
}
}
}
}
} }
} }
} }
...@@ -244,6 +356,19 @@ void InterpreterCore::BuildOperatorDependences() { ...@@ -244,6 +356,19 @@ void InterpreterCore::BuildOperatorDependences() {
} }
} }
// At the end of each step, the holder of Tensor in LoDTensorArray is null.
// Clear these Tensors and leave LoDTensorArray empty, otherwise an exception
// will occur in the next step
void InterpreterCore::ClearLoDTensorArrayInLocalScope() {
auto vars = local_scope_->LocalVars();
for (auto var : vars) {
if (var->IsType<LoDTensorArray>()) {
auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
lod_tensor_arr->clear();
}
}
}
void InterpreterCore::Convert( void InterpreterCore::Convert(
std::vector<paddle::framework::OpFuncNode>* op_func_nodes) { std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
auto& vec_meta_info = global_scope_->MutableVecMetaInfo(); auto& vec_meta_info = global_scope_->MutableVecMetaInfo();
...@@ -376,101 +501,18 @@ void InterpreterCore::Convert( ...@@ -376,101 +501,18 @@ void InterpreterCore::Convert(
if (FLAGS_new_executor_use_inplace && !inplaced) { if (FLAGS_new_executor_use_inplace && !inplaced) {
BuildInplace(); BuildInplace();
} }
}
bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) {
if (!global_scope_->VarDesc(var_index)) {
return input_var2op_info_.at(var_index).size() == 1;
} else {
int is_input_cnt = 0;
for (auto inst_id : input_var2op_info_.at(var_index)) {
OpInOutInfo info;
info.Build(vec_instruction_.at(inst_id).OpBase());
if (info.IsInArgBufferNeeded(global_scope_->VarDesc(var_index)->Name())) {
is_input_cnt++;
}
}
return is_input_cnt == 1;
}
}
void InterpreterCore::BuildInplace() {
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
auto& instr = vec_instruction_[i];
auto* op_base = instr.OpBase();
if (!op_base->Info().infer_inplace_) {
continue;
}
auto in_to_outs = op_base->Info().infer_inplace_( // prepare for the first time.
platform::is_gpu_place(instr.DeviceContext().GetPlace())); std::promise<std::unique_ptr<AtomicVectorSizeT>> deps_promise =
std::promise<std::unique_ptr<AtomicVectorSizeT>>();
auto& inputs = instr.Inputs(); atomic_deps_ = deps_promise.get_future();
auto& outputs = instr.Outputs(); deps_promise.set_value(interpreter::PrepareAtomicDeps(dependecy_count_));
for (auto& pair : in_to_outs) {
auto iter = inputs.find(pair.first); std::promise<std::unique_ptr<AtomicVectorSizeT>> var_ref_promise =
if (iter != inputs.end() && !iter->second.empty()) { std::promise<std::unique_ptr<AtomicVectorSizeT>>();
auto in_var_desc = global_scope_->VarDesc(iter->second[0]); atomic_var_ref_ = var_ref_promise.get_future();
if (in_var_desc && in_var_desc->Persistable()) { var_ref_promise.set_value(
continue; interpreter::PrepareAtomicVarRef(global_scope_->VecMetaInfo()));
}
if (global_scope_->GetVarSikpInplace(iter->second[0])) {
continue;
}
if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) {
auto iterout = outputs.find(pair.second);
if (iterout != outputs.end() && !iterout->second.empty()) {
auto invar = global_scope_->Var(iter->second[0]);
auto outvar = global_scope_->Var(iterout->second[0]);
if (invar && outvar && invar->IsType<LoDTensor>() &&
outvar->IsType<LoDTensor>()) {
instr.AddInplace(invar, outvar);
VLOG(3) << "inplace " << vec_instruction_[i].OpBase()->Type()
<< " " << global_scope_->GetNameById(iter->second[0])
<< " -> "
<< global_scope_->GetNameById(iterout->second[0])
<< std::endl;
}
}
}
}
}
}
}
void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
VariableValueMap ins_map;
for (auto& var_name_item : instr_node->Inputs()) {
std::vector<Variable*> input_vars;
input_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) {
input_vars.emplace_back(global_scope_->Var(id));
}
ins_map.emplace(var_name_item.first, std::move(input_vars));
}
VariableValueMap outs_map;
for (auto& var_name_item : instr_node->Outputs()) {
std::vector<Variable*> out_vars;
out_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) {
out_vars.emplace_back(global_scope_->Var(id));
}
outs_map.emplace(var_name_item.first, std::move(out_vars));
}
// set runtime_ctx and infershape_ctx_
if (instr_node->OpBase()->Type() == "cinn_launch") { // OP use scope in
// kernel
Scope* local_scope = create_local_scope_
? global_scope_->GetMutableLocalScope()
: global_scope_->GetMutableScope();
instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
} else {
instr_node->ResetContext(ins_map, outs_map);
}
} }
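In Convert() above, the first step's atomic dependency and variable-reference counters are produced through a std::promise whose value is set immediately, so the later atomic_deps_.get() call never blocks on the first run. A minimal sketch of that already-ready-future idiom, using simplified types rather than Paddle's AtomicVectorSizeT:

#include <future>
#include <memory>
#include <vector>

int main() {
  using Counters = std::vector<size_t>;
  std::promise<std::unique_ptr<Counters>> deps_promise;
  std::future<std::unique_ptr<Counters>> deps_future = deps_promise.get_future();
  // Fulfil the promise up front, like deps_promise.set_value(...) in Convert().
  deps_promise.set_value(std::make_unique<Counters>(Counters{1, 2, 0}));
  auto deps = deps_future.get();  // returns immediately; nothing to wait for
  return deps->size() == 3 ? 0 : 1;
}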
void InterpreterCore::BuildSkipShareLoDInfo() { void InterpreterCore::BuildSkipShareLoDInfo() {
...@@ -505,7 +547,9 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -505,7 +547,9 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
// If it is OperatorBase, InferShape do nothing. // If it is OperatorBase, InferShape do nothing.
if (op_with_kernel != nullptr) { if (op_with_kernel != nullptr) {
platform::RecordEvent infershape_event( platform::RecordEvent infershape_event(
"infer_shape", platform::TracerEventType::OperatorInner, 1, "infer_shape",
platform::TracerEventType::OperatorInner,
1,
platform::EventRole::kInnerOp); platform::EventRole::kInnerOp);
// see OperatorWithKernel::RunImpl in operator.cc for why // see OperatorWithKernel::RunImpl in operator.cc for why
...@@ -531,7 +575,9 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -531,7 +575,9 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
{ {
platform::RecordEvent compute_event( platform::RecordEvent compute_event(
"compute", platform::TracerEventType::OperatorInner, 1, "compute",
platform::TracerEventType::OperatorInner,
1,
platform::EventRole::kInnerOp); platform::EventRole::kInnerOp);
if (op_with_kernel == nullptr) { if (op_with_kernel == nullptr) {
instr_node.OpBase()->Run(*local_scope, place_); instr_node.OpBase()->Run(*local_scope, place_);
...@@ -588,7 +634,8 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -588,7 +634,8 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
if (op_with_kernel != nullptr && FLAGS_check_nan_inf) { if (op_with_kernel != nullptr && FLAGS_check_nan_inf) {
VLOG(4) << "Check nan/inf"; VLOG(4) << "Check nan/inf";
framework::details::CheckOpHasNanOrInf( framework::details::CheckOpHasNanOrInf(
*op, *global_scope_, *op,
*global_scope_,
place); // TODO(xiongkun03) change it to inner scope. place); // TODO(xiongkun03) change it to inner scope.
} }
} }
@@ -605,10 +652,11 @@ void InterpreterCore::ExecuteInstructionList(
       "PrepareAtomic", platform::TracerEventType::UserDefined, 1);
   // NOTE(zhiqiu): get the prepared deps from std::future, and async prepare
   // those for the next step
-  auto atomic_deps = async_work_queue_->AtomicDeps();
-  auto atomic_var_ref = async_work_queue_->AtomicVarRef();
-  async_work_queue_->PrepareAtomicDeps(dependecy_count_);
-  async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo());
+  auto atomic_deps = atomic_deps_.get();
+  auto atomic_var_ref = atomic_var_ref_.get();
+  atomic_deps_ = async_work_queue_->PrepareAtomicDeps(dependecy_count_);
+  atomic_var_ref_ =
+      async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo());
   record_prepare.End();
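The hunk above double-buffers the per-step counters: the futures consumed in this step were filled during the previous step, and fresh preparation tasks are queued before execution begins. A simplified sketch of the pattern, with std::async standing in for AddAwaitableTask on the Prepare queue (names and types here are illustrative, not Paddle's):

#include <future>
#include <memory>
#include <vector>

using Counters = std::vector<size_t>;

std::unique_ptr<Counters> PrepareCounters(std::vector<size_t> deps) {
  return std::make_unique<Counters>(std::move(deps));  // fresh counters per step
}

int main() {
  const std::vector<size_t> deps = {0, 1, 2};
  // First step: an already-prepared future (mirrors the promise/set_value in Convert()).
  std::future<std::unique_ptr<Counters>> next =
      std::async(std::launch::deferred, PrepareCounters, deps);
  for (int step = 0; step < 3; ++step) {
    auto counters = next.get();  // counters for the current step, prepared earlier
    // Immediately queue preparation for the next step, then run this step.
    next = std::async(std::launch::async, PrepareCounters, deps);
    // ... execute instructions using *counters ...
  }
  return 0;
}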
...@@ -617,16 +665,19 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -617,16 +665,19 @@ void InterpreterCore::ExecuteInstructionList(
for (size_t i = 0; i < dependecy_count_.size(); ++i) { for (size_t i = 0; i < dependecy_count_.size(); ++i) {
if (dependecy_count_[i] == 0) { if (dependecy_count_[i] == 0) {
async_work_queue_->AddTask(vec_instr.at(i).KernelType(), async_work_queue_->AddTask(vec_instr.at(i).KernelType(),
[this, i, atomic_deps = atomic_deps.get(), [this,
i,
atomic_deps = atomic_deps.get(),
atomic_var_ref = atomic_var_ref.get()] { atomic_var_ref = atomic_var_ref.get()] {
RunInstructionAsync(i, atomic_deps, RunInstructionAsync(
atomic_var_ref); i, atomic_deps, atomic_var_ref);
}); });
} }
} }
auto event_name = main_thread_blocker_.WaitEvent(); auto event_name = main_thread_blocker_.WaitEvent();
VLOG(1) << "event_name: " << event_name; VLOG(1) << "main_thread_blocker_(" << &main_thread_blocker_
<< ") got event_name: " << event_name;
if (UNLIKELY(exception_holder_.IsCaught())) { if (UNLIKELY(exception_holder_.IsCaught())) {
VLOG(1) << "Exception caught " << exception_holder_.Type(); VLOG(1) << "Exception caught " << exception_holder_.Type();
...@@ -637,7 +688,8 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -637,7 +688,8 @@ void InterpreterCore::ExecuteInstructionList(
} }
VLOG(4) << "Cancel ok"; VLOG(4) << "Cancel ok";
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
main_thread_blocker_.Clear(), 0, main_thread_blocker_.Clear(),
0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"main_thread_blocker_.Clear() return -1, clear failed")); "main_thread_blocker_.Clear() return -1, clear failed"));
VLOG(4) << "clear ok"; VLOG(4) << "clear ok";
...@@ -646,11 +698,12 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -646,11 +698,12 @@ void InterpreterCore::ExecuteInstructionList(
} }
void InterpreterCore::RunNextInstructions( void InterpreterCore::RunNextInstructions(
const Instruction& instr, std::queue<size_t>* reserved_next_ops, const Instruction& instr,
std::queue<size_t>* reserved_next_ops,
std::vector<std::atomic<size_t>>* atomic_deps, std::vector<std::atomic<size_t>>* atomic_deps,
std::vector<std::atomic<size_t>>* atomic_var_ref) { std::vector<std::atomic<size_t>>* atomic_var_ref) {
platform::RecordEvent record("RunNextInstructions", platform::RecordEvent record(
platform::TracerEventType::UserDefined, 10); "RunNextInstructions", platform::TracerEventType::UserDefined, 10);
VLOG(4) << "atomic 1:" << atomic_deps; VLOG(4) << "atomic 1:" << atomic_deps;
auto& next_instr = instr.NextInstructions(); auto& next_instr = instr.NextInstructions();
...@@ -716,7 +769,8 @@ void InterpreterCore::RunNextInstructions( ...@@ -716,7 +769,8 @@ void InterpreterCore::RunNextInstructions(
} }
void InterpreterCore::RunInstructionAsync( void InterpreterCore::RunInstructionAsync(
size_t instr_id, std::vector<std::atomic<size_t>>* atomic_deps, size_t instr_id,
std::vector<std::atomic<size_t>>* atomic_deps,
std::vector<std::atomic<size_t>>* atomic_var_ref) { std::vector<std::atomic<size_t>>* atomic_var_ref) {
std::queue<size_t> ready_ops; std::queue<size_t> ready_ops;
ready_ops.push(instr_id); ready_ops.push(instr_id);
...@@ -787,8 +841,8 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) { ...@@ -787,8 +841,8 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) {
instr.KernelType() != OpFuncType::kQueueAsync) { instr.KernelType() != OpFuncType::kQueueAsync) {
return; return;
} }
platform::RecordEvent record("RecordStreamForGC", platform::RecordEvent record(
platform::TracerEventType::UserDefined, 10); "RecordStreamForGC", platform::TracerEventType::UserDefined, 10);
gpuStream_t stream = reinterpret_cast<const platform::CUDADeviceContext&>( gpuStream_t stream = reinterpret_cast<const platform::CUDADeviceContext&>(
instr.DeviceContext()) instr.DeviceContext())
...@@ -880,8 +934,8 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) { ...@@ -880,8 +934,8 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) {
void InterpreterCore::CheckGC( void InterpreterCore::CheckGC(
const Instruction& instr, const Instruction& instr,
std::vector<std::atomic<size_t>>* atomic_var_ref) { std::vector<std::atomic<size_t>>* atomic_var_ref) {
platform::RecordEvent record("CheckGC", platform::RecordEvent record(
platform::TracerEventType::UserDefined, 10); "CheckGC", platform::TracerEventType::UserDefined, 10);
size_t instr_id = instr.Id(); size_t instr_id = instr.Id();
auto& var_scope = *global_scope_; auto& var_scope = *global_scope_;
...@@ -906,12 +960,14 @@ void InterpreterCore::CheckGC( ...@@ -906,12 +960,14 @@ void InterpreterCore::CheckGC(
} else { } else {
static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add( static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add(
var_scope.Var(var_id), &gc_event_.at(instr_id), var_scope.Var(var_id),
&gc_event_.at(instr_id),
&instr.DeviceContext()); &instr.DeviceContext());
} }
#else #else
static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add( static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add(
var_scope.Var(var_id), &gc_event_.at(instr_id), var_scope.Var(var_id),
&gc_event_.at(instr_id),
&instr.DeviceContext()); &instr.DeviceContext());
#endif #endif
} }
...@@ -920,20 +976,24 @@ void InterpreterCore::CheckGC( ...@@ -920,20 +976,24 @@ void InterpreterCore::CheckGC(
void InterpreterCore::Prepare( void InterpreterCore::Prepare(
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors, bool prepare_feed) { const std::vector<framework::LoDTensor>& feed_tensors,
PADDLE_ENFORCE_EQ(feed_names.size(), feed_tensors.size(), bool prepare_feed) {
PADDLE_ENFORCE_EQ(feed_names.size(),
feed_tensors.size(),
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Required feed_names.size() == feed_tensors.size(), " "Required feed_names.size() == feed_tensors.size(), "
"but received %d != %d", "but received %d != %d",
feed_names.size(), feed_tensors.size())); feed_names.size(),
feed_tensors.size()));
auto FeedInput = [&] { auto FeedInput = [&] {
VLOG(4) << "Feed inputs"; VLOG(4) << "Feed inputs";
for (size_t i = 0; i < feed_names.size(); ++i) { for (size_t i = 0; i < feed_names.size(); ++i) {
auto* feed_var = global_scope_->FindVar(feed_names[i]); auto* feed_var = global_scope_->FindVar(feed_names[i]);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
feed_var, platform::errors::NotFound( feed_var,
"Variable %s should not be nullptr.", feed_names[i])); platform::errors::NotFound("Variable %s should not be nullptr.",
feed_names[i]));
auto feed_tensor = feed_var->GetMutable<framework::LoDTensor>(); auto feed_tensor = feed_var->GetMutable<framework::LoDTensor>();
feed_tensor->ShareDataWith(feed_tensors[i]); feed_tensor->ShareDataWith(feed_tensors[i]);
...@@ -942,12 +1002,15 @@ void InterpreterCore::Prepare( ...@@ -942,12 +1002,15 @@ void InterpreterCore::Prepare(
}; };
if (!is_build_) { if (!is_build_) {
paddle::framework::interpreter::build_variable_scope(block_, global_scope_, paddle::framework::interpreter::build_variable_scope(
create_local_scope_); block_, global_scope_, create_local_scope_);
FeedInput(); FeedInput();
std::vector<paddle::framework::OpFuncNode> op_func_nodes; std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list( paddle::framework::interpreter::build_op_func_list(place_,
place_, block_, skip_gc_vars_, &op_func_nodes, global_scope_, block_,
skip_gc_vars_,
&op_func_nodes,
global_scope_,
create_local_scope_); create_local_scope_);
is_build_ = true; is_build_ = true;
SetFeedVarsInplaceSkip(feed_names); SetFeedVarsInplaceSkip(feed_names);
...@@ -962,37 +1025,6 @@ void InterpreterCore::Prepare( ...@@ -962,37 +1025,6 @@ void InterpreterCore::Prepare(
} }
} }
interpreter::CostInfo InterpreterCore::DryRun(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) {
global_scope_->SetLocalScope(local_scope_);
Prepare(feed_names, feed_tensors, true);
interpreter::CostInfo cost_info;
{
interpreter::ProfilerGuard(place_, &cost_info);
// For the program that only run once, it is no need to
// create work_queue, so the async_work_queue_ is created
// until the second step run.
if (async_work_queue_ == nullptr) {
async_work_queue_ = std::make_unique<interpreter::AsyncWorkQueue>(
kHostNumThreads, kDeviceNumThreads, &main_thread_blocker_);
// prepare for the first time.
async_work_queue_->PrepareAtomicDeps(dependecy_count_);
async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo());
}
ExecuteInstructionList(vec_instruction_);
platform::DeviceContextPool::Instance().Get(place_)->Wait();
}
if (create_local_scope_) {
ClearLoDTensorArrayInLocalScope();
}
return cost_info;
}
void InterpreterCore::SetFeedVarsInplaceSkip( void InterpreterCore::SetFeedVarsInplaceSkip(
const std::vector<std::string>& feed_names) { const std::vector<std::string>& feed_names) {
for (auto& feed_name : feed_names) { for (auto& feed_name : feed_names) {
...@@ -1001,8 +1033,10 @@ void InterpreterCore::SetFeedVarsInplaceSkip( ...@@ -1001,8 +1033,10 @@ void InterpreterCore::SetFeedVarsInplaceSkip(
} }
std::shared_ptr<InterpreterCore> CreateInterpreterCore( std::shared_ptr<InterpreterCore> CreateInterpreterCore(
const platform::Place& place, const ProgramDesc& prog, const platform::Place& place,
VariableScope* global_scope, const std::vector<std::string>& fetch_names, const ProgramDesc& prog,
VariableScope* global_scope,
const std::vector<std::string>& fetch_names,
const std::set<std::string>& skip_gc_vars) { const std::set<std::string>& skip_gc_vars) {
std::shared_ptr<InterpreterCore> core = nullptr; std::shared_ptr<InterpreterCore> core = nullptr;
// NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy // NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy
...@@ -1011,8 +1045,8 @@ std::shared_ptr<InterpreterCore> CreateInterpreterCore( ...@@ -1011,8 +1045,8 @@ std::shared_ptr<InterpreterCore> CreateInterpreterCore(
auto* block = new_prog->MutableBlock(0); auto* block = new_prog->MutableBlock(0);
interpreter::add_fetch(fetch_names, block); interpreter::add_fetch(fetch_names, block);
core = std::make_shared<InterpreterCore>(place, *block, skip_gc_vars, core = std::make_shared<InterpreterCore>(
global_scope); place, *block, skip_gc_vars, global_scope);
core->SetCopyProgram(new_prog); core->SetCopyProgram(new_prog);
return core; return core;
} }
......
...@@ -34,36 +34,44 @@ ...@@ -34,36 +34,44 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
class InterpreterCore { class InterpreterCore {
public: public:
InterpreterCore(const platform::Place& place, const BlockDesc& block, InterpreterCore(const platform::Place& place,
const BlockDesc& block,
const std::set<std::string>& skip_gc_vars, const std::set<std::string>& skip_gc_vars,
VariableScope* global_scope); VariableScope* global_scope);
~InterpreterCore(); ~InterpreterCore();
interpreter::CostInfo DryRun(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors);
paddle::framework::FetchList Run( paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors); const std::vector<framework::LoDTensor>& feed_tensors);
paddle::framework::FetchList Run(const std::vector<std::string>& feed_names); paddle::framework::FetchList Run(const std::vector<std::string>& feed_names);
interpreter::CostInfo DryRun( void ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src);
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors);
void SetCopyProgram(std::shared_ptr<ProgramDesc> prog); void SetCopyProgram(std::shared_ptr<ProgramDesc> prog);
private: private:
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes); bool BuildInplaceCheckVarIsOnlyInput(size_t var_index);
std::shared_ptr<interpreter::AsyncWorkQueue> GetWorkQueue();
void BuildAndCacheInstructionCtx(Instruction* instr_node); void BuildAndCacheInstructionCtx(Instruction* instr_node);
void BuildInplace(); void BuildInplace();
bool BuildInplaceCheckVarIsOnlyInput(size_t var_index); void BuildOperatorDependences();
void ClearLoDTensorArrayInLocalScope();
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
void RunInstruction(const Instruction& instr_node); void RunInstruction(const Instruction& instr_node);
...@@ -90,12 +98,8 @@ class InterpreterCore { ...@@ -90,12 +98,8 @@ class InterpreterCore {
void BuildSkipShareLoDInfo(); void BuildSkipShareLoDInfo();
void BuildOperatorDependences();
void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names); void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
void ClearLoDTensorArrayInLocalScope();
bool is_build_; bool is_build_;
const platform::Place& place_; const platform::Place& place_;
...@@ -123,7 +127,7 @@ class InterpreterCore { ...@@ -123,7 +127,7 @@ class InterpreterCore {
StreamAnalyzer stream_analyzer_; StreamAnalyzer stream_analyzer_;
EventsWaiter main_thread_blocker_; EventsWaiter main_thread_blocker_;
std::unique_ptr<interpreter::AsyncWorkQueue> async_work_queue_; std::shared_ptr<interpreter::AsyncWorkQueue> async_work_queue_;
details::ExceptionHolder exception_holder_; details::ExceptionHolder exception_holder_;
std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr}; std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr};
std::shared_ptr<EventsWaiter::EventNotifier> completion_notifier_{nullptr}; std::shared_ptr<EventsWaiter::EventNotifier> completion_notifier_{nullptr};
...@@ -132,10 +136,14 @@ class InterpreterCore { ...@@ -132,10 +136,14 @@ class InterpreterCore {
std::vector<paddle::platform::DeviceEvent> gc_event_; std::vector<paddle::platform::DeviceEvent> gc_event_;
bool create_local_scope_{true}; bool create_local_scope_{true};
Scope* local_scope_{nullptr}; // not owned Scope* local_scope_{nullptr}; // not owned
std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_deps_;
std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_var_ref_;
}; };
std::shared_ptr<InterpreterCore> CreateInterpreterCore( std::shared_ptr<InterpreterCore> CreateInterpreterCore(
const platform::Place& place, const ProgramDesc& prog, const platform::Place& place,
const ProgramDesc& prog,
VariableScope* global_scope, VariableScope* global_scope,
const std::vector<std::string>& fetch_names = {}, const std::vector<std::string>& fetch_names = {},
const std::set<std::string>& skip_gc_vars = {}); const std::set<std::string>& skip_gc_vars = {});
......
...@@ -32,12 +32,14 @@ ...@@ -32,12 +32,14 @@
// Program, while "serial_run" ensures that all Ops are scheduled in a singal // Program, while "serial_run" ensures that all Ops are scheduled in a singal
// thread. In standalone executor, "sequential_run" is also "serial_run", while // thread. In standalone executor, "sequential_run" is also "serial_run", while
// "serial_run" is not necessarily "sequential_run". // "serial_run" is not necessarily "sequential_run".
PADDLE_DEFINE_EXPORTED_bool(new_executor_sequential_run, false, PADDLE_DEFINE_EXPORTED_bool(new_executor_sequential_run,
false,
"Enable sequential execution for standalone " "Enable sequential execution for standalone "
"executor, only applied to GPU OPs."); "executor, only applied to GPU OPs.");
PADDLE_DEFINE_EXPORTED_bool( PADDLE_DEFINE_EXPORTED_bool(
new_executor_serial_run, false, new_executor_serial_run,
false,
"Enable serial execution for standalone executor, used for debug."); "Enable serial execution for standalone executor, used for debug.");
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
...@@ -46,8 +48,47 @@ namespace paddle { ...@@ -46,8 +48,47 @@ namespace paddle {
namespace framework { namespace framework {
namespace interpreter { namespace interpreter {
using VariableIdMap = std::map<std::string, std::vector<int>>;
constexpr size_t kPrepareWorkQueueIdx = 2; constexpr size_t kPrepareWorkQueueIdx = 2;
const std::vector<WorkQueueOptions> ConstructWorkQueueOptions(
size_t host_num_threads, size_t device_num_threads, EventsWaiter* waiter) {
std::vector<WorkQueueOptions> group_options;
// for execute host Kernel
group_options.emplace_back(/*name*/ "HostTasks",
/*num_threads*/ host_num_threads,
/*allow_spinning*/ true,
/*always_spinning*/ false,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
// for launch device Kernel
group_options.emplace_back(/*name*/ "DeviceKernelLaunch",
/*num_threads*/ device_num_threads,
/*allow_spinning*/ true,
/*always_spinning*/ true,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
// for prepare deps and others
group_options.emplace_back(/*name*/ "Prepare",
/*num_threads*/ 1,
/*allow_spinning*/ true,
/*always_spinning*/ false,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
return group_options;
}
AsyncWorkQueue::AsyncWorkQueue(size_t host_num_threads,
size_t device_num_threads,
EventsWaiter* waiter)
: host_num_thread_(host_num_threads) {
queue_group_ = CreateWorkQueueGroup(
ConstructWorkQueueOptions(host_num_threads, device_num_threads, waiter));
}
void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type, void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
std::function<void()> fn) { std::function<void()> fn) {
VLOG(4) << "Add task: " << static_cast<size_t>(op_func_type) << " "; VLOG(4) << "Add task: " << static_cast<size_t>(op_func_type) << " ";
...@@ -60,36 +101,42 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type, ...@@ -60,36 +101,42 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
} }
} }
using VariableIdMap = std::map<std::string, std::vector<int>>; std::future<std::unique_ptr<AtomicVectorSizeT>>
AsyncWorkQueue::PrepareAtomicDeps(const std::vector<size_t>& dependecy_count) {
VLOG(4) << "PrepareAtomicDeps";
return queue_group_->AddAwaitableTask(
kPrepareWorkQueueIdx, interpreter::PrepareAtomicDeps, dependecy_count);
}
void AsyncWorkQueue::PrepareAtomicDeps( std::future<std::unique_ptr<AtomicVectorSizeT>>
AsyncWorkQueue::PrepareAtomicVarRef(
const std::vector<VariableMetaInfo>& vec_meta_info) {
VLOG(4) << "PrepareAtomicVarRef";
return queue_group_->AddAwaitableTask(
kPrepareWorkQueueIdx, interpreter::PrepareAtomicVarRef, vec_meta_info);
}
std::unique_ptr<AtomicVectorSizeT> PrepareAtomicDeps(
const std::vector<size_t>& dependecy_count) { const std::vector<size_t>& dependecy_count) {
VLOG(4) << "PrepareAtomicDeps"; VLOG(4) << "PrepareAtomicDeps";
atomic_deps_ =
queue_group_->AddAwaitableTask(kPrepareWorkQueueIdx, [&dependecy_count] { auto op_deps = std::make_unique<AtomicVectorSizeT>(dependecy_count.size());
auto op_deps = std::make_unique<std::vector<std::atomic<size_t>>>(
dependecy_count.size());
for (size_t i = 0; i < dependecy_count.size(); ++i) { for (size_t i = 0; i < dependecy_count.size(); ++i) {
(*op_deps)[i] = dependecy_count[i]; (*op_deps)[i] = dependecy_count[i];
} }
VLOG(4) << "AtomicDeps:" << op_deps.get() << " " << op_deps->size(); VLOG(4) << "AtomicDeps:" << op_deps.get() << " " << op_deps->size();
return op_deps; return op_deps;
});
} }
void AsyncWorkQueue::PrepareAtomicVarRef( std::unique_ptr<AtomicVectorSizeT> PrepareAtomicVarRef(
const std::vector<VariableMetaInfo>& vec_meta_info) { const std::vector<VariableMetaInfo>& vec_meta_info) {
VLOG(4) << "PrepareAtomicVarRef"; VLOG(4) << "PrepareAtomicVarRef";
atomic_var_ref_ = auto var_ref = std::make_unique<AtomicVectorSizeT>(vec_meta_info.size());
queue_group_->AddAwaitableTask(kPrepareWorkQueueIdx, [&vec_meta_info] {
auto var_ref = std::make_unique<std::vector<std::atomic<size_t>>>(
vec_meta_info.size());
for (size_t i = 0; i < vec_meta_info.size(); ++i) { for (size_t i = 0; i < vec_meta_info.size(); ++i) {
(*var_ref)[i] = vec_meta_info[i].var_ref_count_; (*var_ref)[i] = vec_meta_info[i].var_ref_count_;
} }
VLOG(4) << "AtomicVarRef:" << var_ref.get() << " " << var_ref->size(); VLOG(4) << "AtomicVarRef:" << var_ref.get() << " " << var_ref->size();
return var_ref; return var_ref;
});
} }
bool var_can_be_deleted(const std::string& name, const BlockDesc& block) { bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
...@@ -160,7 +207,8 @@ get_unused_vars(const BlockDesc& block, ...@@ -160,7 +207,8 @@ get_unused_vars(const BlockDesc& block,
} }
void build_variable_scope(const framework::BlockDesc& block, void build_variable_scope(const framework::BlockDesc& block,
VariableScope* var_scope, bool use_local_scope) { VariableScope* var_scope,
bool use_local_scope) {
VLOG(3) << "Creating Variables"; VLOG(3) << "Creating Variables";
auto inner_scope = var_scope->GetMutableScope(); auto inner_scope = var_scope->GetMutableScope();
...@@ -229,7 +277,8 @@ void create_all_ops(const framework::BlockDesc& block, ...@@ -229,7 +277,8 @@ void create_all_ops(const framework::BlockDesc& block,
} }
std::tuple<VariableValueMap, VariableIdMap> build_variable_map( std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
const VariableNameMap& var_name_map, VariableScope* var_scope, const VariableNameMap& var_name_map,
VariableScope* var_scope,
bool enforce_exist = true) { bool enforce_exist = true) {
VariableValueMap name2var; VariableValueMap name2var;
VariableIdMap name2id; VariableIdMap name2id;
...@@ -293,7 +342,8 @@ void apply_device_guard(const OperatorBase* op_base, ...@@ -293,7 +342,8 @@ void apply_device_guard(const OperatorBase* op_base,
void deal_operator_base(const platform::Place& place, void deal_operator_base(const platform::Place& place,
const VariableScope* var_scope, const VariableScope* var_scope,
std::shared_ptr<OperatorBase> op_base, std::shared_ptr<OperatorBase> op_base,
OpFuncNode* op_func_node, Scope* local_scope) { OpFuncNode* op_func_node,
Scope* local_scope) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
// input, output is prepared. set the other attributes. // input, output is prepared. set the other attributes.
...@@ -325,7 +375,8 @@ void build_op_func_list(const platform::Place& place, ...@@ -325,7 +375,8 @@ void build_op_func_list(const platform::Place& place,
const framework::BlockDesc& block, const framework::BlockDesc& block,
const std::set<std::string>& skip_gc_vars, const std::set<std::string>& skip_gc_vars,
std::vector<OpFuncNode>* vec_func_list, std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope, bool use_local_scope) { VariableScope* var_scope,
bool use_local_scope) {
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope() Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope(); : var_scope->GetMutableScope();
auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
...@@ -428,9 +479,14 @@ void build_op_func_list(const platform::Place& place, ...@@ -428,9 +479,14 @@ void build_op_func_list(const platform::Place& place,
// NOTE(zhiqiu): op_func_node->operator_base_ maybe changed in // NOTE(zhiqiu): op_func_node->operator_base_ maybe changed in
// ApplyDataTransform // ApplyDataTransform
ApplyDataTransform(expected_kernel_key, place, &ins_map_temp, ApplyDataTransform(expected_kernel_key,
&outs_map_temp, var_scope, &op_func_node, place,
vec_func_list, use_local_scope); &ins_map_temp,
&outs_map_temp,
var_scope,
&op_func_node,
vec_func_list,
use_local_scope);
op_with_kernel = const_cast<framework::OperatorWithKernel*>( op_with_kernel = const_cast<framework::OperatorWithKernel*>(
static_cast<const framework::OperatorWithKernel*>( static_cast<const framework::OperatorWithKernel*>(
op_func_node.operator_base_.get())); op_func_node.operator_base_.get()));
...@@ -463,8 +519,8 @@ void build_op_func_list(const platform::Place& place, ...@@ -463,8 +519,8 @@ void build_op_func_list(const platform::Place& place,
op_with_kernel->Info().infer_shape_(&infer_shape_ctx); op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
} }
auto exec_ctx = ExecutionContext(*op_with_kernel, *runtime_scope, auto exec_ctx = ExecutionContext(
*dev_ctx, runtime_context); *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
auto run_phi_kernel = false; auto run_phi_kernel = false;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel( if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
...@@ -498,24 +554,26 @@ void build_op_func_list(const platform::Place& place, ...@@ -498,24 +554,26 @@ void build_op_func_list(const platform::Place& place,
<< " : expected_kernel_key : " << expected_kernel_key; << " : expected_kernel_key : " << expected_kernel_key;
if (run_phi_kernel) { if (run_phi_kernel) {
phi::KernelContext pt_kernel_context; phi::KernelContext pt_kernel_context;
op_with_kernel->BuildPhiKernelContext(runtime_context, dev_ctx, op_with_kernel->BuildPhiKernelContext(
&pt_kernel_context); runtime_context, dev_ctx, &pt_kernel_context);
op_func_node.pt_kernel_ = op_with_kernel->PhiKernel(); op_func_node.pt_kernel_ = op_with_kernel->PhiKernel();
(*op_func_node.pt_kernel_)(&pt_kernel_context); (*op_func_node.pt_kernel_)(&pt_kernel_context);
} else { } else {
auto kernels_iter = all_op_kernels.find(op->Type()); auto kernels_iter = all_op_kernels.find(op->Type());
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(), kernels_iter,
all_op_kernels.end(),
platform::errors::Unavailable( platform::errors::Unavailable(
"There are no kernels which are registered in the %s operator.", "There are no kernels which are registered in the %s operator.",
op->Type())); op->Type()));
OpKernelMap& kernels = kernels_iter->second; OpKernelMap& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(kernel_iter,
kernel_iter, kernels.end(), kernels.end(),
platform::errors::NotFound( platform::errors::NotFound(
"Operator (%s) does not have kernel for %s.", op->Type(), "Operator (%s) does not have kernel for %s.",
op->Type(),
KernelTypeToString(expected_kernel_key))); KernelTypeToString(expected_kernel_key)));
// TODO(zhiqiu): add fallback logic // TODO(zhiqiu): add fallback logic
op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
...@@ -525,9 +583,13 @@ void build_op_func_list(const platform::Place& place, ...@@ -525,9 +583,13 @@ void build_op_func_list(const platform::Place& place,
// post-process grad_op.outputs if need cast complex grad into real grad. // post-process grad_op.outputs if need cast complex grad into real grad.
// NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it. // NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
if (framework::IsComplexType(expected_kernel_key.data_type_)) { if (framework::IsComplexType(expected_kernel_key.data_type_)) {
interpreter::HandleComplexGradToRealGrad( interpreter::HandleComplexGradToRealGrad(op_func_node,
op_func_node, place, outputs_names, &runtime_context.outputs, place,
var_scope, vec_func_list, local_scope); outputs_names,
&runtime_context.outputs,
var_scope,
vec_func_list,
local_scope);
} }
if (!op_func_node.inplace_back_map.empty()) { if (!op_func_node.inplace_back_map.empty()) {
auto& m = op_func_node.inplace_back_map; auto& m = op_func_node.inplace_back_map;
...@@ -583,7 +645,8 @@ void build_op_func_list(const platform::Place& place, ...@@ -583,7 +645,8 @@ void build_op_func_list(const platform::Place& place,
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Type %s of variable %s is not supported eager deletion.", "Type %s of variable %s is not supported eager deletion.",
framework::ToTypeName(var->Type()), var_name)); framework::ToTypeName(var->Type()),
var_name));
} }
} }
delete garbages; // free mem delete garbages; // free mem
...@@ -612,8 +675,8 @@ void add_fetch(const std::vector<std::string>& fetch_names, ...@@ -612,8 +675,8 @@ void add_fetch(const std::vector<std::string>& fetch_names,
std::vector<size_t> merge_vector(const std::vector<size_t>& first, std::vector<size_t> merge_vector(const std::vector<size_t>& first,
const std::vector<size_t>& second) { const std::vector<size_t>& second) {
std::vector<size_t> out(first.size() + second.size()); std::vector<size_t> out(first.size() + second.size());
std::merge(first.begin(), first.end(), second.begin(), second.end(), std::merge(
out.begin()); first.begin(), first.end(), second.begin(), second.end(), out.begin());
std::vector<size_t>::iterator it; std::vector<size_t>::iterator it;
it = std::unique(out.begin(), out.end()); it = std::unique(out.begin(), out.end());
...@@ -625,7 +688,8 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first, ...@@ -625,7 +688,8 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first,
void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences, void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
std::map<int, std::list<int>>* var2min_rw_op, std::map<int, std::list<int>>* var2min_rw_op,
int cur_op, int rw_var) { int cur_op,
int rw_var) {
// rw_var is inputs or outputs of cur_op // rw_var is inputs or outputs of cur_op
// this function update the var2min_rw_op set . // this function update the var2min_rw_op set .
if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) { if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) {
...@@ -637,7 +701,8 @@ void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences, ...@@ -637,7 +701,8 @@ void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
var2min_rw_op->at(rw_var).push_back(cur_op); var2min_rw_op->at(rw_var).push_back(cur_op);
} }
void AddDownstreamOp(int prior_op_idx, int posterior_op_idx, void AddDownstreamOp(int prior_op_idx,
int posterior_op_idx,
std::map<int, std::list<int>>* op_downstream_map) { std::map<int, std::list<int>>* op_downstream_map) {
if (op_downstream_map->find(prior_op_idx) == op_downstream_map->end()) { if (op_downstream_map->find(prior_op_idx) == op_downstream_map->end()) {
op_downstream_map->emplace(std::make_pair(prior_op_idx, std::list<int>())); op_downstream_map->emplace(std::make_pair(prior_op_idx, std::list<int>()));
...@@ -645,7 +710,8 @@ void AddDownstreamOp(int prior_op_idx, int posterior_op_idx, ...@@ -645,7 +710,8 @@ void AddDownstreamOp(int prior_op_idx, int posterior_op_idx,
op_downstream_map->at(prior_op_idx).push_back(posterior_op_idx); op_downstream_map->at(prior_op_idx).push_back(posterior_op_idx);
} }
void AddDownstreamOp(int prior_op_idx, int posterior_op_idx, void AddDownstreamOp(int prior_op_idx,
int posterior_op_idx,
std::map<int, std::list<int>>* op_downstream_map, std::map<int, std::list<int>>* op_downstream_map,
const std::vector<std::vector<bool>>& op_happens_before) { const std::vector<std::vector<bool>>& op_happens_before) {
if (op_downstream_map->find(prior_op_idx) != op_downstream_map->end()) { if (op_downstream_map->find(prior_op_idx) != op_downstream_map->end()) {
...@@ -675,7 +741,8 @@ const std::string StringizeDownstreamMap( ...@@ -675,7 +741,8 @@ const std::string StringizeDownstreamMap(
std::ostringstream oss; std::ostringstream oss;
for (auto pair : downstream_map) { for (auto pair : downstream_map) {
oss << pair.first << " -> "; oss << pair.first << " -> ";
std::copy(pair.second.begin(), pair.second.end(), std::copy(pair.second.begin(),
pair.second.end(),
std::ostream_iterator<int>(oss, " ")); std::ostream_iterator<int>(oss, " "));
oss << std::endl; oss << std::endl;
} }
...@@ -717,8 +784,8 @@ void ShrinkDownstreamMap(std::map<int, std::list<int>>* downstream_map, ...@@ -717,8 +784,8 @@ void ShrinkDownstreamMap(std::map<int, std::list<int>>* downstream_map,
op_happens_before->resize(op_num); op_happens_before->resize(op_num);
for (size_t i = 0; i < op_num; ++i) { for (size_t i = 0; i < op_num; ++i) {
(*op_happens_before)[i].resize(op_num); (*op_happens_before)[i].resize(op_num);
std::fill((*op_happens_before)[i].begin(), (*op_happens_before)[i].end(), std::fill(
false); (*op_happens_before)[i].begin(), (*op_happens_before)[i].end(), false);
} }
// bfs to get all next ops // bfs to get all next ops
...@@ -735,11 +802,15 @@ void ShrinkDownstreamMap(std::map<int, std::list<int>>* downstream_map, ...@@ -735,11 +802,15 @@ void ShrinkDownstreamMap(std::map<int, std::list<int>>* downstream_map,
} }
for (auto next : downstream_map->at(op)) { for (auto next : downstream_map->at(op)) {
if (!visited[next]) { if (!visited[next]) {
PADDLE_ENFORCE_EQ((*op_happens_before)[next][op_idx], false, PADDLE_ENFORCE_EQ((*op_happens_before)[next][op_idx],
false,
paddle::platform::errors::AlreadyExists( paddle::platform::errors::AlreadyExists(
"There exists circle in graph, expected " "There exists circle in graph, expected "
"%d->%d, but already got %d->%d", "%d->%d, but already got %d->%d",
op_idx, next, next, op_idx)); op_idx,
next,
next,
op_idx));
(*op_happens_before)[op_idx][next] = true; (*op_happens_before)[op_idx][next] = true;
VLOG(8) << "happens before: " << op_idx << " " << next; VLOG(8) << "happens before: " << op_idx << " " << next;
q.push(next); q.push(next);
...@@ -892,8 +963,8 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -892,8 +963,8 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (size_t op_idx = 0; op_idx < op_num; ++op_idx) { for (size_t op_idx = 0; op_idx < op_num; ++op_idx) {
if (random_op_set.count(vec_instruction[op_idx].OpBase()->Type())) { if (random_op_set.count(vec_instruction[op_idx].OpBase()->Type())) {
if (dependence_op_idx != -1) { if (dependence_op_idx != -1) {
AddDownstreamOp(dependence_op_idx, op_idx, &op_downstream_map, AddDownstreamOp(
*op_happens_before); dependence_op_idx, op_idx, &op_downstream_map, *op_happens_before);
} }
dependence_op_idx = op_idx; dependence_op_idx = op_idx;
} }
...@@ -919,8 +990,8 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -919,8 +990,8 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (size_t op_idx = 0; op_idx < op_num; ++op_idx) { for (size_t op_idx = 0; op_idx < op_num; ++op_idx) {
if (is_comm_op(vec_instruction[op_idx].OpBase()->Type())) { if (is_comm_op(vec_instruction[op_idx].OpBase()->Type())) {
if (dependence_op_idx != -1) { if (dependence_op_idx != -1) {
AddDownstreamOp(dependence_op_idx, op_idx, &op_downstream_map, AddDownstreamOp(
*op_happens_before); dependence_op_idx, op_idx, &op_downstream_map, *op_happens_before);
VLOG(4) << "Add depend from " VLOG(4) << "Add depend from "
<< vec_instruction[dependence_op_idx].OpBase()->Type() << " to " << vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
<< vec_instruction[op_idx].OpBase()->Type(); << vec_instruction[op_idx].OpBase()->Type();
...@@ -948,8 +1019,8 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -948,8 +1019,8 @@ std::map<int, std::list<int>> build_op_downstream_map(
VLOG(4) << "Add depend from " VLOG(4) << "Add depend from "
<< vec_instruction[dependence_op_idx].OpBase()->Type() << " to " << vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
<< vec_instruction[op_idx].OpBase()->Type(); << vec_instruction[op_idx].OpBase()->Type();
AddDownstreamOp(dependence_op_idx, op_idx, &op_downstream_map, AddDownstreamOp(
*op_happens_before); dependence_op_idx, op_idx, &op_downstream_map, *op_happens_before);
} }
} }
} }
...@@ -1002,10 +1073,13 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -1002,10 +1073,13 @@ std::map<int, std::list<int>> build_op_downstream_map(
// first_read_fused_out_op) // first_read_fused_out_op)
// add depend: them->first_read_fused_out_op // add depend: them->first_read_fused_out_op
for (auto j = op_idx + 1; for (auto j = op_idx + 1;
j < static_cast<size_t>(first_read_fused_out_op); ++j) { j < static_cast<size_t>(first_read_fused_out_op);
++j) {
for (auto var_id : outputs) { for (auto var_id : outputs) {
if (is_write(vec_instruction[j], var_id)) { if (is_write(vec_instruction[j], var_id)) {
AddDownstreamOp(j, first_read_fused_out_op, &op_downstream_map, AddDownstreamOp(j,
first_read_fused_out_op,
&op_downstream_map,
*op_happens_before); *op_happens_before);
VLOG(4) << j << " -> " << first_read_fused_out_op; VLOG(4) << j << " -> " << first_read_fused_out_op;
VLOG(4) VLOG(4)
...@@ -1055,7 +1129,9 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -1055,7 +1129,9 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (size_t op_idx = 0; op_idx < op_num; ++op_idx) { for (size_t op_idx = 0; op_idx < op_num; ++op_idx) {
if (!IsCpuOp(vec_instruction[op_idx])) { if (!IsCpuOp(vec_instruction[op_idx])) {
if (dependence_op_idx != -1) { if (dependence_op_idx != -1) {
AddDownstreamOp(dependence_op_idx, op_idx, &op_downstream_map, AddDownstreamOp(dependence_op_idx,
op_idx,
&op_downstream_map,
*op_happens_before); *op_happens_before);
VLOG(4) << "Add depend from " VLOG(4) << "Add depend from "
<< vec_instruction[dependence_op_idx].OpBase()->Type() << "(" << vec_instruction[dependence_op_idx].OpBase()->Type() << "("
......
...@@ -44,49 +44,23 @@ ...@@ -44,49 +44,23 @@
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/init.h"
using AtomicVectorSizeT = std::vector<std::atomic<size_t>>;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace interpreter { namespace interpreter {
using AtomicVectorSizeT =
std::future<std::unique_ptr<std::vector<std::atomic<size_t>>>>;
class AsyncWorkQueue { class AsyncWorkQueue {
public: public:
AsyncWorkQueue(size_t host_num_threads, size_t deivce_num_threads, AsyncWorkQueue(size_t host_num_threads,
EventsWaiter* waiter) size_t deivce_num_threads,
: host_num_thread_(host_num_threads) { EventsWaiter* waiter);
std::vector<WorkQueueOptions> group_options;
// for execute host Kernel std::future<std::unique_ptr<AtomicVectorSizeT>> PrepareAtomicDeps(
group_options.emplace_back(/*name*/ "HostTasks", const std::vector<size_t>& dependecy_count);
/*num_threads*/ host_num_threads, std::future<std::unique_ptr<AtomicVectorSizeT>> PrepareAtomicVarRef(
/*allow_spinning*/ true, const std::vector<VariableMetaInfo>& vec_meta_info);
/*always_spinning*/ false,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
// for launch device Kernel
group_options.emplace_back(/*name*/ "DeviceKernelLaunch",
/*num_threads*/ deivce_num_threads,
/*allow_spinning*/ true,
/*always_spinning*/ true,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
// for prepare deps and others
group_options.emplace_back(/*name*/ "Prepare",
/*num_threads*/ 1,
/*allow_spinning*/ true,
/*always_spinning*/ false,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
queue_group_ = CreateWorkQueueGroup(group_options);
}
void PrepareAtomicDeps(const std::vector<size_t>& dependecy_count);
void PrepareAtomicVarRef(const std::vector<VariableMetaInfo>& vec_meta_info);
// void WaitEmpty() { queue_group_->WaitQueueGroupEmpty(); } // void WaitEmpty() { queue_group_->WaitQueueGroupEmpty(); }
@@ -94,20 +68,16 @@ class AsyncWorkQueue {
 
   void Cancel() { queue_group_->Cancel(); }
 
-  std::unique_ptr<std::vector<std::atomic<size_t>>> AtomicDeps() {
-    return atomic_deps_.get();
-  }
-  std::unique_ptr<std::vector<std::atomic<size_t>>> AtomicVarRef() {
-    return atomic_var_ref_.get();
-  }
-
  private:
   size_t host_num_thread_;
   std::unique_ptr<WorkQueueGroup> queue_group_;
-  AtomicVectorSizeT atomic_deps_;
-  AtomicVectorSizeT atomic_var_ref_;
 };
 
+std::unique_ptr<AtomicVectorSizeT> PrepareAtomicDeps(
+    const std::vector<size_t>& dependecy_count);
+std::unique_ptr<AtomicVectorSizeT> PrepareAtomicVarRef(
+    const std::vector<VariableMetaInfo>& vec_meta_info);
+
 void build_variable_scope(const framework::BlockDesc& block,
                           VariableScope* var_scope,
                           bool use_local_scope = true);
@@ -116,7 +86,8 @@ void build_op_func_list(const platform::Place& place,
                         const framework::BlockDesc& block,
                         const std::set<std::string>& skip_gc_vars,
                         std::vector<OpFuncNode>* vec_func_list,
-                        VariableScope* var_scope, bool use_local_scope = true);
+                        VariableScope* var_scope,
+                        bool use_local_scope = true);
 
 std::map<int, std::list<int>> build_op_downstream_map(
     const std::vector<Instruction>& vec_instruction,
......
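Taken together, the header changes above let one InterpreterCore hand its AsyncWorkQueue to a sibling core instead of every core building its own thread pools. A minimal usage sketch, not part of the diff: it mirrors the TestShareWorkQueue helper added in the test file below and reuses its CreateInterpreterCore/ShareWorkQueueFrom calls and its local names (place, prog, variable_scope, fetch_names, feed_names, feed_tensors).

    // Two cores built over the same program and scope multiplex one work queue.
    std::shared_ptr<InterpreterCore> core1 =
        CreateInterpreterCore(place, prog, &variable_scope, fetch_names);
    std::shared_ptr<InterpreterCore> core2 =
        CreateInterpreterCore(place, prog, &variable_scope, fetch_names);
    core2->ShareWorkQueueFrom(core1);  // core2 reuses core1's AsyncWorkQueue
    core1->Run(feed_names, feed_tensors);
    core2->Run(feed_names, feed_tensors);  // dispatches onto the shared queue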
@@ -12,15 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/fluid/framework/new_executor/standalone_executor.h"
+
 #include <gtest/gtest.h>
 
 #include <chrono>
 #include <iostream>
 #include <string>
 
-// #include "gperftools/profiler.h"
-
-#include "paddle/fluid/framework/new_executor/standalone_executor.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 USE_OP_ITSELF(fill_constant);
@@ -61,6 +60,7 @@ USE_OP_ITSELF(sgd);
 USE_OP(squared_l2_norm);
 USE_OP_ITSELF(memcpy_h2d);
 USE_OP_ITSELF(memcpy_d2h);
+USE_OP_ITSELF(fetch_v2);
 
 PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(uniform_random_raw, GPU, ALL_LAYOUT);
@@ -72,6 +72,7 @@ PD_DECLARE_KERNEL(concat_grad, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(add_raw, KPS, ALL_LAYOUT);
 PD_DECLARE_KERNEL(add, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
 PD_DECLARE_KERNEL(multiply_grad, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(divide, KPS, ALL_LAYOUT);
@@ -99,8 +100,6 @@ PD_DECLARE_KERNEL(cross_entropy_with_softmax, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(cross_entropy_with_softmax_grad, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(sqrt, GPU, ALL_LAYOUT);
 
-DECLARE_double(eager_delete_tensor_gb);
-
 namespace paddle {
 namespace framework {
@@ -115,15 +114,11 @@ ProgramDesc load_from_file(const std::string& file_name) {
   return program_desc;
 }
-TEST(StandaloneExecutor, run) {
-  FLAGS_eager_delete_tensor_gb = 0.1;
-
-  int64_t batch_size = 20;
-
-  auto place = platform::CUDAPlace(0);
-  auto test_prog = load_from_file("lm_startup_program");
-  auto main_prog = load_from_file("lm_main_program");
-
+ProgramDesc GetLmMainProgram() {
+  ProgramDesc main_prog = load_from_file("lm_main_program");
   auto& global_block = main_prog.Block(0);
+  int64_t batch_size = 20;
 
   auto& op1 = global_block.AllOps()[1];
   auto shape1 = BOOST_GET_CONST(std::vector<int64_t>, op1->GetAttr("shape"));
@@ -139,6 +134,13 @@ TEST(StandaloneExecutor, run) {
   auto shape3 = BOOST_GET_CONST(std::vector<int64_t>, op3->GetAttr("shape"));
   shape3[0] = batch_size;
   op3->SetAttr("shape", shape3);
+  return main_prog;
+}
+
+TEST(StandaloneExecutor, run) {
+  auto place = platform::CUDAPlace(0);
+  ProgramDesc test_prog = load_from_file("lm_startup_program");
+  ProgramDesc main_prog = GetLmMainProgram();
 
   Scope scope;
   StandaloneExecutor exec(place, test_prog, main_prog, &scope);
@@ -159,31 +161,10 @@ TEST(StandaloneExecutor, run) {
   std::cout << "time cost " << diff.count() << std::endl;
 }
-TEST(StandaloneExecutor, skip_gc_vars) {
-  FLAGS_eager_delete_tensor_gb = 0;
-
-  int64_t batch_size = 20;
-
+TEST(InterpreterCore, skip_gc_vars) {
   auto place = platform::CUDAPlace(0);
-  auto startup_prog = load_from_file("lm_startup_program");
-  auto main_prog = load_from_file("lm_main_program");
-
-  auto& global_block = main_prog.Block(0);
-
-  auto& op1 = global_block.AllOps()[1];
-  auto shape1 = BOOST_GET_CONST(std::vector<int64_t>, op1->GetAttr("shape"));
-  shape1[0] = batch_size * 20;
-  op1->SetAttr("shape", shape1);
-
-  auto& op2 = global_block.AllOps()[2];
-  auto shape2 = BOOST_GET_CONST(std::vector<int64_t>, op2->GetAttr("shape"));
-  shape2[0] = batch_size;
-  op2->SetAttr("shape", shape2);
-
-  auto& op3 = global_block.AllOps()[3];
-  auto shape3 = BOOST_GET_CONST(std::vector<int64_t>, op3->GetAttr("shape"));
-  shape3[0] = batch_size;
-  op3->SetAttr("shape", shape3);
+  ProgramDesc startup_prog = load_from_file("lm_startup_program");
+  ProgramDesc main_prog = GetLmMainProgram();
 
   Scope scope;
@@ -192,14 +173,18 @@ TEST(StandaloneExecutor, skip_gc_vars) {
       CreateInterpreterCore(place, startup_prog, &startup_scope);
   startup_core->Run({}, {});
 
-  std::set<std::string> skip_gc_vars = {"uniform_0.tmp_0", "transpose_0.tmp_0",
-                                        "embedding_0.tmp_0", "slice_0.tmp_0",
+  std::set<std::string> skip_gc_vars = {"uniform_0.tmp_0",
+                                        "transpose_0.tmp_0",
+                                        "embedding_0.tmp_0",
+                                        "slice_0.tmp_0",
                                         "split_1.tmp_2"};
-  std::set<std::string> gc_vars = {"uniform_1.tmp_0", "matmul_0.tmp_0",
-                                   "split_0.tmp_0", "elementwise_add_0.tmp_0",
+  std::set<std::string> gc_vars = {"uniform_1.tmp_0",
+                                   "matmul_0.tmp_0",
+                                   "split_0.tmp_0",
+                                   "elementwise_add_0.tmp_0",
                                    "tmp_0"};
-  auto check_gc_result = [](VariableScope& scope, std::set<std::string>& vars,
-                            bool is_skip_gc) {
+  auto check_gc_result =
+      [](VariableScope& scope, std::set<std::string>& vars, bool is_skip_gc) {
     for (const std::string& var_name : vars) {
       ASSERT_EQ(
           scope.FindVar(var_name)->GetMutable<LoDTensor>()->IsInitialized(),
@@ -220,5 +205,68 @@ TEST(StandaloneExecutor, skip_gc_vars) {
   check_gc_result(main_scope, gc_vars, false);
 }
void TestShareWorkQueue(const ProgramDesc& prog,
const std::vector<std::string>& feed_names,
const std::vector<LoDTensor>& feed_tensors,
const std::vector<std::string>& fetch_names,
const std::vector<float>& fetch_results) {
const platform::CPUPlace place = platform::CPUPlace();
Scope scope;
VariableScope variable_scope(&scope);
std::shared_ptr<InterpreterCore> core1 =
CreateInterpreterCore(place, prog, &variable_scope, fetch_names);
std::shared_ptr<InterpreterCore> core2 =
CreateInterpreterCore(place, prog, &variable_scope, fetch_names);
core2->ShareWorkQueueFrom(core1);
auto run_and_check = [&feed_names, &feed_tensors, &fetch_results](
std::shared_ptr<InterpreterCore> core) {
FetchList fetch_list = core->Run(feed_names, feed_tensors);
for (size_t i = 0; i < fetch_list.size(); ++i) {
const float* fetch_data =
BOOST_GET_CONST(LoDTensor, fetch_list[i]).data<float>();
ASSERT_FLOAT_EQ(*fetch_data, fetch_results.at(i));
}
};
run_and_check(core1);
run_and_check(core2);
run_and_check(core1);
run_and_check(core2);
}
TEST(InterpreterCore, workqueue_multiplexing) {
ProgramDesc program;
BlockDesc* main_block = program.MutableBlock(0);
VarDesc* var_a = main_block->Var("a");
VarDesc* var_b = main_block->Var("b");
VarDesc* var_c = main_block->Var("c");
var_a->SetType(proto::VarType::LOD_TENSOR);
var_b->SetType(proto::VarType::LOD_TENSOR);
var_c->SetType(proto::VarType::LOD_TENSOR);
OpDesc* add = main_block->AppendOp();
add->SetType("elementwise_add");
add->SetInput("X", {"a"});
add->SetInput("Y", {"b"});
add->SetOutput("Out", {"c"});
float data_a[] = {0, 1, 2, 3};
float data_b[] = {0.0, 0.1, 0.2, 0.3};
phi::DDim dims = phi::make_ddim({2, 2});
const platform::CPUPlace place = platform::CPUPlace();
LoDTensor tensor_a = LoDTensor();
LoDTensor tensor_b = LoDTensor();
std::copy_n(data_a, 4, tensor_a.mutable_data<float>(dims, place));
std::copy_n(data_b, 4, tensor_b.mutable_data<float>(dims, place));
TestShareWorkQueue(
program, {"a", "b"}, {tensor_a, tensor_b}, {"c"}, {0.0, 1.1, 2.2, 3.3});
}
}  // namespace framework
}  // namespace paddle
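A note on the expected values, added here rather than taken from the diff: the fetch_results passed to TestShareWorkQueue are the elementwise sums of data_a and data_b (0+0.0, 1+0.1, 2+0.2, 3+0.3 = {0.0, 1.1, 2.2, 3.3}), and the helper runs core1 and core2 alternately four times, so both cores are expected to fetch the same "c" while dispatching through the single shared work queue.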