未验证 提交 9516108a 编写于 作者: A Aurelius84 提交者: GitHub

Modify Struct into Class to improve encapsulation and Polish code exception (#36797)

* Refactor InterpreterCore code

* make tuple
上级 a7d8837b
...@@ -22,13 +22,13 @@ void EventManager::WaitEvent(const Instruction& instruction, ...@@ -22,13 +22,13 @@ void EventManager::WaitEvent(const Instruction& instruction,
// If InterpreterCore is on CPUPlace, do nothing. // If InterpreterCore is on CPUPlace, do nothing.
if (platform::is_cpu_place(place)) return; if (platform::is_cpu_place(place)) return;
VLOG(3) << "Deal StreamWaitEventOrSync for " VLOG(3) << "Deal StreamWaitEventOrSync for " << instruction.OpBase()->Type();
<< instruction.kernel_func_.operator_base_->Type();
for (auto& event_iter : instruction.intput_events_) { for (auto& event_iter : instruction.InputEvents()) {
VLOG(3) << "wait var_id: " << event_iter.var_id_ VLOG(3) << "wait var_id: " << event_iter.var_id_
<< " 's event with waiter_type: " << event_iter.waiter_type_; << " 's event with waiter_type: " << event_iter.waiter_type_;
event_iter.event_->Wait(event_iter.waiter_type_, instruction.dev_ctx_); event_iter.event_->Wait(event_iter.waiter_type_,
&instruction.DeviceContext());
} }
} }
...@@ -37,9 +37,9 @@ void EventManager::RecordEvent(const Instruction& instruction, ...@@ -37,9 +37,9 @@ void EventManager::RecordEvent(const Instruction& instruction,
// If InterpreterCore is on CPUPlace, do nothing. // If InterpreterCore is on CPUPlace, do nothing.
if (platform::is_cpu_place(place)) return; if (platform::is_cpu_place(place)) return;
for (auto& event : instruction.output_events_) { for (auto& event : instruction.OutputEvents()) {
VLOG(3) << "Record event in out_var_id: " << event.var_id_; VLOG(3) << "Record event in out_var_id: " << event.var_id_;
event.event_->Record(instruction.dev_ctx_); event.event_->Record(&instruction.DeviceContext());
} }
} }
......
...@@ -79,11 +79,9 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -79,11 +79,9 @@ paddle::framework::FetchList InterpreterCore::Run(
const std::vector<framework::Tensor>& feed_tensors) { const std::vector<framework::Tensor>& feed_tensors) {
auto FeedInput = [&] { auto FeedInput = [&] {
for (size_t i = 0; i < feed_names_.size(); ++i) { for (size_t i = 0; i < feed_names_.size(); ++i) {
auto it = global_scope_->name2id.find(feed_names_[i]); auto* feed_var = global_scope_->Var(feed_names_[i]);
assert(it != global_scope_->name2id.end());
auto feed_tensor = global_scope_->var_list[it->second] auto feed_tensor = feed_var->GetMutable<framework::LoDTensor>();
->GetMutable<framework::LoDTensor>();
feed_tensor->ShareDataWith(feed_tensors[i]); feed_tensor->ShareDataWith(feed_tensors[i]);
} }
}; };
...@@ -93,7 +91,7 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -93,7 +91,7 @@ paddle::framework::FetchList InterpreterCore::Run(
global_scope_); global_scope_);
FeedInput(); FeedInput();
paddle::framework::interpretercore::build_op_func_list( paddle::framework::interpretercore::build_op_func_list(
place_, main_program_, &op_list_, &vec_func_list_, global_scope_); place_, main_program_, &vec_func_list_, global_scope_);
is_build_ = true; is_build_ = true;
// convert vec func_list to graph // convert vec func_list to graph
Convert(); Convert();
...@@ -103,42 +101,39 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -103,42 +101,39 @@ paddle::framework::FetchList InterpreterCore::Run(
} }
// return Fetch Tensors // return Fetch Tensors
return *(global_scope_->var_list[global_scope_->name2id["fetch_vars"]] auto* fetch_var = global_scope_->Var("fetch_vars");
->GetMutable<framework::FetchList>()); return *(fetch_var->GetMutable<framework::FetchList>());
} }
void InterpreterCore::Convert() { void InterpreterCore::Convert() {
input_var2op_info_.resize(global_scope_->var_list.size()); auto var_nums = global_scope_->VarSize();
input_var2op_info_.resize(var_nums);
vec_instruction_.reserve(vec_func_list_.size()); vec_meta_info_.resize(var_nums);
dependecy_count_.resize(vec_func_list_.size());
vec_meta_info_.resize(global_scope_->var_list.size());
for (size_t i = 0; i < vec_func_list_.size(); ++i) {
Instruction temp_inst;
auto* op_base = op_list_[i];
temp_inst.dev_ctx_ =
stream_analyzer_.ParseDeviceContext(vec_func_list_[i], *op_base);
temp_inst.kernel_func_.compute_func_ = vec_func_list_[i].kernel_func_;
temp_inst.kernel_func_.operator_base_ = op_base;
temp_inst.input_index_ = vec_func_list_[i].input_index;
temp_inst.output_index_ = vec_func_list_[i].output_index;
temp_inst.type_ = vec_func_list_[i].type_;
temp_inst.no_data_transform_index_ =
vec_func_list_[i].no_data_transform_index;
OpInOutInfo info; auto op_nums = vec_func_list_.size();
vec_instruction_.reserve(op_nums);
dependecy_count_.resize(op_nums);
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& op_func_node = vec_func_list_[op_idx];
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
vec_instruction_.emplace_back(op_idx, op_func_node, *dev_ctx_);
auto& instr = vec_instruction_.back();
OpInOutInfo info;
std::vector<size_t> gc_check_input_list; std::vector<size_t> gc_check_input_list;
for (auto& item : vec_func_list_[i].input_index) {
for (auto& item : op_func_node.input_index) {
for (auto id : item.second) { for (auto id : item.second) {
input_var2op_info_[id].push_back(i); input_var2op_info_.at(id).push_back(op_idx);
// var can be gc-ed // var can be gc-ed
if (!info.IsBuilt()) { if (!info.IsBuilt()) {
info.Build(op_list_[i]); info.Build(op_func_node.operator_base_);
} }
if (global_scope_->vec_meta_info_[id].vardesc_) { auto* var_desc = global_scope_->VarDesc(id);
if (info.IsInArgBufferNeeded( if (var_desc) {
global_scope_->vec_meta_info_[id].vardesc_->Name())) { if (info.IsInArgBufferNeeded(var_desc->Name())) {
gc_check_input_list.push_back(id); gc_check_input_list.push_back(id);
} }
} else { } else {
...@@ -150,22 +145,20 @@ void InterpreterCore::Convert() { ...@@ -150,22 +145,20 @@ void InterpreterCore::Convert() {
auto last = auto last =
std::unique(gc_check_input_list.begin(), gc_check_input_list.end()); std::unique(gc_check_input_list.begin(), gc_check_input_list.end());
gc_check_input_list.erase(last, gc_check_input_list.end()); gc_check_input_list.erase(last, gc_check_input_list.end());
for (auto var_id : gc_check_input_list) { for (auto var_id : gc_check_input_list) {
vec_meta_info_[var_id].var_ref_count_++; vec_meta_info_[var_id].var_ref_count_++;
instr.AddGCCheckVar(var_id);
} }
temp_inst.gc_check_var_list.swap(gc_check_input_list);
vec_instruction_.push_back(temp_inst);
} }
for (size_t i = 0; i < vec_instruction_.size(); ++i) { for (size_t i = 0; i < vec_instruction_.size(); ++i) {
// checkout ouput // checkout ouput
for (auto& item : vec_instruction_[i].output_index_) { for (auto& item : vec_instruction_[i].Outputs()) {
for (auto id : item.second) { for (auto id : item.second) {
if (input_var2op_info_[id].size() == 0) { if (input_var2op_info_.at(id).size() == 0) {
// output var not be used by any kernel // output var not be used by any kernel
vec_instruction_[i].gc_check_var_list.push_back(id); vec_instruction_[i].AddGCCheckVar(id);
vec_meta_info_[id].var_ref_count_++; vec_meta_info_[id].var_ref_count_++;
} }
} }
...@@ -174,7 +167,7 @@ void InterpreterCore::Convert() { ...@@ -174,7 +167,7 @@ void InterpreterCore::Convert() {
for (size_t i = 0; i < vec_instruction_.size(); ++i) { for (size_t i = 0; i < vec_instruction_.size(); ++i) {
std::vector<size_t> vec_temp; std::vector<size_t> vec_temp;
for (auto& item : vec_instruction_[i].output_index_) { for (auto& item : vec_instruction_[i].Outputs()) {
for (auto id : item.second) { for (auto id : item.second) {
vec_temp = vec_temp =
interpretercore::merge_vector(vec_temp, input_var2op_info_[id]); interpretercore::merge_vector(vec_temp, input_var2op_info_[id]);
...@@ -205,7 +198,7 @@ void InterpreterCore::Convert() { ...@@ -205,7 +198,7 @@ void InterpreterCore::Convert() {
BuildSkipShareLoDInfo(); BuildSkipShareLoDInfo();
for (size_t i = 0; i < vec_instruction_.size(); ++i) { for (size_t i = 0; i < vec_instruction_.size(); ++i) {
gc_event_.emplace_back(vec_instruction_[i].execution_ctx_.get()->GetPlace(), gc_event_.emplace_back(vec_instruction_[i].DeviceContext().GetPlace(),
platform::GenerateDeviceEventFlag()); platform::GenerateDeviceEventFlag());
} }
...@@ -215,15 +208,14 @@ void InterpreterCore::Convert() { ...@@ -215,15 +208,14 @@ void InterpreterCore::Convert() {
} }
bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) { bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) {
if (!global_scope_->vec_meta_info_[var_index].vardesc_) { if (!global_scope_->VarDesc(var_index)) {
return input_var2op_info_[var_index].size() == 1; return input_var2op_info_.at(var_index).size() == 1;
} else { } else {
int is_input_cnt = 0; int is_input_cnt = 0;
for (auto inst_id : input_var2op_info_[var_index]) { for (auto inst_id : input_var2op_info_.at(var_index)) {
OpInOutInfo info; OpInOutInfo info;
info.Build(vec_instruction_[inst_id].kernel_func_.operator_base_); info.Build(vec_instruction_.at(inst_id).OpBase());
if (info.IsInArgBufferNeeded( if (info.IsInArgBufferNeeded(global_scope_->VarDesc(var_index)->Name())) {
global_scope_->vec_meta_info_[var_index].vardesc_->Name())) {
is_input_cnt++; is_input_cnt++;
} }
} }
...@@ -233,35 +225,31 @@ bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) { ...@@ -233,35 +225,31 @@ bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) {
void InterpreterCore::BuildInplace() { void InterpreterCore::BuildInplace() {
for (size_t i = 0; i < vec_instruction_.size(); ++i) { for (size_t i = 0; i < vec_instruction_.size(); ++i) {
if (!vec_instruction_[i] auto& instr = vec_instruction_[i];
.kernel_func_.operator_base_->Info() auto* op_base = instr.OpBase();
.infer_inplace_) { if (!op_base->Info().infer_inplace_) {
continue; continue;
} }
auto in_to_outs = auto in_to_outs = op_base->Info().infer_inplace_(
vec_instruction_[i].kernel_func_.operator_base_->Info().infer_inplace_( platform::is_gpu_place(instr.DeviceContext().GetPlace()));
platform::is_gpu_place(vec_instruction_[i].dev_ctx_->GetPlace()));
auto& inputs = instr.Inputs();
auto& outputs = instr.Outputs();
for (auto& pair : in_to_outs) { for (auto& pair : in_to_outs) {
auto iter = vec_instruction_[i].input_index_.find(pair.first); auto iter = inputs.find(pair.first);
if (iter != vec_instruction_[i].input_index_.end()) { if (iter != inputs.end()) {
if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) { if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) {
auto iterout = vec_instruction_[i].output_index_.find(pair.second); auto iterout = outputs.find(pair.second);
if (iterout != vec_instruction_[i].output_index_.end()) { if (iterout != outputs.end()) {
auto invar = global_scope_->var_list[iter->second[0]]; auto invar = global_scope_->Var(iter->second[0]);
auto outvar = global_scope_->var_list[iterout->second[0]]; auto outvar = global_scope_->Var(iterout->second[0]);
if (invar && outvar) { if (invar && outvar) {
vec_instruction_[i].vec_inplace_in_to_out_.emplace_back(invar, instr.AddInplace(invar, outvar);
outvar); VLOG(3) << "inplace " << op_base->Type() << " "
VLOG(3) << "inplace " << global_scope_->VarDesc(iter->second[0])->Name()
<< vec_instruction_[i].kernel_func_.operator_base_->Type()
<< " "
<< global_scope_->vec_meta_info_[iter->second[0]]
.vardesc_->Name()
<< " -> " << " -> "
<< global_scope_->vec_meta_info_[iterout->second[0]] << global_scope_->VarDesc(iterout->second[0])->Name()
.vardesc_->Name()
<< std::endl; << std::endl;
} }
} }
...@@ -274,48 +262,35 @@ void InterpreterCore::BuildInplace() { ...@@ -274,48 +262,35 @@ void InterpreterCore::BuildInplace() {
void InterpreterCore::BuildAndCacheInstructionCtx( void InterpreterCore::BuildAndCacheInstructionCtx(
Instruction* instr_node, const VariableScope& var_scope, Instruction* instr_node, const VariableScope& var_scope,
const platform::Place& place) { const platform::Place& place) {
auto op_base = instr_node->kernel_func_.operator_base_;
VariableValueMap ins_map; VariableValueMap ins_map;
for (auto& var_name_item : instr_node->input_index_) { for (auto& var_name_item : instr_node->Inputs()) {
std::vector<Variable*> input_vars; std::vector<Variable*> input_vars;
input_vars.reserve(var_name_item.second.size()); input_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) { for (auto& id : var_name_item.second) {
input_vars.emplace_back(var_scope.var_list[id]); input_vars.emplace_back(var_scope.Var(id));
} }
ins_map.emplace(var_name_item.first, std::move(input_vars)); ins_map.emplace(var_name_item.first, std::move(input_vars));
} }
VariableValueMap outs_map; VariableValueMap outs_map;
for (auto& var_name_item : instr_node->output_index_) { for (auto& var_name_item : instr_node->Outputs()) {
std::vector<Variable*> out_vars; std::vector<Variable*> out_vars;
out_vars.reserve(var_name_item.second.size()); out_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) { for (auto& id : var_name_item.second) {
out_vars.emplace_back(var_scope.var_list[id]); out_vars.emplace_back(var_scope.Var(id));
} }
outs_map.emplace(var_name_item.first, std::move(out_vars)); outs_map.emplace(var_name_item.first, std::move(out_vars));
} }
// set runtime_ctx and infershape_ctx_
instr_node->runtime_ctx_.reset(new RuntimeContext({}, {})); instr_node->ResetContext(ins_map, outs_map);
instr_node->runtime_ctx_->inputs.swap(ins_map);
instr_node->runtime_ctx_->outputs.swap(outs_map);
instr_node->infershape_ctx_.reset(new InterpretercoreInferShapeContext(
*op_base, *instr_node->runtime_ctx_.get()));
auto* dev_ctx = instr_node->dev_ctx_;
Scope scope;
instr_node->execution_ctx_.reset(new ExecutionContext(
*op_base, scope, *dev_ctx, *instr_node->runtime_ctx_.get()));
} }
void InterpreterCore::BuildSkipShareLoDInfo() { void InterpreterCore::BuildSkipShareLoDInfo() {
for (size_t i = 0; i < vec_instruction_.size(); ++i) { for (size_t i = 0; i < vec_instruction_.size(); ++i) {
bool can_skip_lod = true; bool can_skip_lod = true;
for (auto& input : vec_instruction_[i].runtime_ctx_.get()->inputs) { for (auto& input : vec_instruction_[i].InnerRuntimeContext()->inputs) {
for (auto& var : input.second) { for (auto& var : input.second) {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
if (var->Get<LoDTensor>().lod().size() != 0) { if (var->Get<LoDTensor>().lod().size() != 0) {
...@@ -328,23 +303,21 @@ void InterpreterCore::BuildSkipShareLoDInfo() { ...@@ -328,23 +303,21 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
} }
} }
} }
vec_instruction_[i].infershape_ctx_.get()->SetSkipLoD(can_skip_lod); vec_instruction_[i].InnerInferShapeContext()->SetSkipLoD(can_skip_lod);
} }
} }
void InterpreterCore::RunInstruction(const Instruction& instr_node) { void InterpreterCore::RunInstruction(const Instruction& instr_node) {
VLOG(3) << "RunInstruction: " VLOG(3) << "RunInstruction: " << instr_node.OpBase()->Type();
<< instr_node.kernel_func_.operator_base_->Type();
{ {
platform::RecordEvent infershape_event("InferShape"); platform::RecordEvent infershape_event("InferShape");
static_cast<const framework::OperatorWithKernel*>( static_cast<const framework::OperatorWithKernel*>(instr_node.OpBase())
instr_node.kernel_func_.operator_base_) ->InferShape(instr_node.InnerInferShapeContext().get());
->InferShape(instr_node.infershape_ctx_.get());
} }
if (FLAGS_new_executor_use_inplace) { if (FLAGS_new_executor_use_inplace) {
for (auto& pair : instr_node.vec_inplace_in_to_out_) { for (auto& pair : instr_node.InplaceInfo()) {
const auto& in = paddle::framework::details::GetTensorFromVar(pair.first); const auto& in = paddle::framework::details::GetTensorFromVar(pair.first);
auto* out = auto* out =
paddle::framework::details::GetMutableTensorFromVar(pair.second); paddle::framework::details::GetMutableTensorFromVar(pair.second);
...@@ -355,7 +328,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -355,7 +328,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
} }
{ {
platform::RecordEvent compute_event("Compute"); platform::RecordEvent compute_event("Compute");
instr_node.kernel_func_.compute_func_(*instr_node.execution_ctx_.get()); instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
} }
} }
...@@ -369,7 +342,7 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -369,7 +342,7 @@ void InterpreterCore::ExecuteInstructionList(
for (size_t i = 0; i < dependecy_count_.size(); ++i) { for (size_t i = 0; i < dependecy_count_.size(); ++i) {
if (dependecy_count_[i] == 0) { if (dependecy_count_[i] == 0) {
async_work_queue_.AddTask(vec_instr[i].type_, async_work_queue_.AddTask(vec_instr.at(i).KernelType(),
[&, i] { RunInstructionAsync(i); }); [&, i] { RunInstructionAsync(i); });
} }
} }
...@@ -391,43 +364,43 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -391,43 +364,43 @@ void InterpreterCore::ExecuteInstructionList(
void InterpreterCore::RunNextInstructions( void InterpreterCore::RunNextInstructions(
const Instruction& instr, std::queue<size_t>* reserved_next_ops) { const Instruction& instr, std::queue<size_t>* reserved_next_ops) {
auto& next_instr = instr.next_instruction_; auto& next_instr = instr.NextInstructions();
auto& atomic_deps = async_work_queue_.AtomicDeps(); auto& atomic_deps = async_work_queue_.AtomicDeps();
auto IsReady = [&](size_t next_id) { auto IsReady = [&](size_t next_id) {
return atomic_deps[next_id]->fetch_sub(1, std::memory_order_relaxed) == 1; return atomic_deps[next_id]->fetch_sub(1, std::memory_order_relaxed) == 1;
}; };
if (instr.type_ == OpFuncType::kQueueAsync) { if (instr.KernelType() == OpFuncType::kQueueAsync) {
// move all sync_ops into other threads // move all sync_ops into other threads
for (auto next_id : next_instr.synchronize_run_) { for (auto next_id : next_instr.SyncRunIds()) {
if (IsReady(next_id)) { if (IsReady(next_id)) {
async_work_queue_.AddTask( async_work_queue_.AddTask(
vec_instruction_[next_id].type_, vec_instruction_[next_id].KernelType(),
[&, next_id] { RunInstructionAsync(next_id); }); [&, next_id] { RunInstructionAsync(next_id); });
} }
} }
// keep all async_ops running in current thread // keep all async_ops running in current thread
for (auto next_id : next_instr.direct_run_) { for (auto next_id : next_instr.DirectRunIds()) {
if (IsReady(next_id)) { if (IsReady(next_id)) {
reserved_next_ops->push(next_id); reserved_next_ops->push(next_id);
} }
} }
for (auto next_id : next_instr.event_wait_run_) { for (auto next_id : next_instr.EventRunIds()) {
if (IsReady(next_id)) { if (IsReady(next_id)) {
reserved_next_ops->push(next_id); reserved_next_ops->push(next_id);
} }
} }
} else { } else {
// move async_ops into async_thread // move async_ops into async_thread
for (auto next_id : next_instr.event_wait_run_) { for (auto next_id : next_instr.EventRunIds()) {
if (IsReady(next_id)) { if (IsReady(next_id)) {
async_work_queue_.AddTask( async_work_queue_.AddTask(
vec_instruction_[next_id].type_, vec_instruction_[next_id].KernelType(),
[&, next_id] { RunInstructionAsync(next_id); }); [&, next_id] { RunInstructionAsync(next_id); });
} }
} }
auto direct_run_ops = interpretercore::merge_vector( auto direct_run_ops = interpretercore::merge_vector(
next_instr.synchronize_run_, next_instr.direct_run_); next_instr.SyncRunIds(), next_instr.DirectRunIds());
size_t first_op = 0; size_t first_op = 0;
for (auto next_id : direct_run_ops) { for (auto next_id : direct_run_ops) {
if (IsReady(next_id)) { if (IsReady(next_id)) {
...@@ -438,7 +411,7 @@ void InterpreterCore::RunNextInstructions( ...@@ -438,7 +411,7 @@ void InterpreterCore::RunNextInstructions(
} }
// move rest ops into other threads // move rest ops into other threads
async_work_queue_.AddTask( async_work_queue_.AddTask(
vec_instruction_[next_id].type_, vec_instruction_[next_id].KernelType(),
[&, next_id] { RunInstructionAsync(next_id); }); [&, next_id] { RunInstructionAsync(next_id); });
} }
} }
...@@ -452,8 +425,8 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { ...@@ -452,8 +425,8 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
while (!ready_ops.empty()) { while (!ready_ops.empty()) {
instr_id = ready_ops.front(); instr_id = ready_ops.front();
ready_ops.pop(); ready_ops.pop();
auto& instr_node = vec_instruction_[instr_id]; auto& instr_node = vec_instruction_.at(instr_id);
auto* op = instr_node.kernel_func_.operator_base_; auto* op = instr_node.OpBase();
platform::RecordEvent instruction_event(op->Type()); platform::RecordEvent instruction_event(op->Type());
event_manager_.WaitEvent(instr_node, place_); event_manager_.WaitEvent(instr_node, place_);
...@@ -486,28 +459,27 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { ...@@ -486,28 +459,27 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
op_run_number_.fetch_add(1, std::memory_order_relaxed); op_run_number_.fetch_add(1, std::memory_order_relaxed);
// GC information // GC information
CheckGC(instr_id, instr_node.gc_check_var_list); CheckGC(instr_node);
RunNextInstructions(instr_node, &ready_ops); RunNextInstructions(instr_node, &ready_ops);
} }
} }
void InterpreterCore::CheckGC(size_t instr_id, void InterpreterCore::CheckGC(const Instruction& instr) {
const std::vector<size_t>& gc_check_list) { size_t instr_id = instr.Id();
auto& var_scope = *global_scope_; auto& var_scope = *global_scope_;
auto& atomic_var_ref = async_work_queue_.AtomicVarRef(); auto& atomic_var_ref = async_work_queue_.AtomicVarRef();
for (auto var_id : gc_check_list) { for (auto var_id : instr.GCCheckVars()) {
bool is_ready = bool is_ready =
atomic_var_ref[var_id]->fetch_sub(1, std::memory_order_relaxed) == 1; atomic_var_ref[var_id]->fetch_sub(1, std::memory_order_relaxed) == 1;
if (is_ready && var_scope.vec_meta_info_[var_id].vardesc_ && if (is_ready && var_scope.VarDesc(var_id) &&
!var_scope.vec_meta_info_[var_id].vardesc_->Persistable()) { !var_scope.VarDesc(var_id)->Persistable()) {
gc_.Add(var_scope.var_list[var_id], gc_event_[instr_id], gc_.Add(var_scope.Var(var_id), gc_event_.at(instr_id),
vec_instruction_[instr_id].dev_ctx_); &instr.DeviceContext());
} else if (is_ready && } else if (is_ready && var_scope.VarDesc(var_id) == nullptr) {
var_scope.vec_meta_info_[var_id].vardesc_ == nullptr) { gc_.Add(var_scope.Var(var_id), gc_event_.at(instr_id),
gc_.Add(var_scope.var_list[var_id], gc_event_[instr_id], &instr.DeviceContext());
vec_instruction_[instr_id].dev_ctx_);
} }
} }
} }
...@@ -516,11 +488,11 @@ void InterpreterCore::DryRunPrepare( ...@@ -516,11 +488,11 @@ void InterpreterCore::DryRunPrepare(
const std::vector<framework::Tensor>& feed_tensors) { const std::vector<framework::Tensor>& feed_tensors) {
auto FeedInput = [&] { auto FeedInput = [&] {
for (size_t i = 0; i < feed_names_.size(); ++i) { for (size_t i = 0; i < feed_names_.size(); ++i) {
auto it = global_scope_->name2id.find(feed_names_[i]); auto* feed_var = global_scope_->FindVar(feed_names_[i]);
assert(it != global_scope_->name2id.end()); PADDLE_ENFORCE_NOT_NULL(feed_var, platform::errors::NotFound(
"feed_var shall not be nullptr."));
auto feed_tensor = global_scope_->var_list[it->second] auto feed_tensor = feed_var->GetMutable<framework::LoDTensor>();
->GetMutable<framework::LoDTensor>();
feed_tensor->ShareDataWith(feed_tensors[i]); feed_tensor->ShareDataWith(feed_tensors[i]);
} }
}; };
...@@ -530,7 +502,7 @@ void InterpreterCore::DryRunPrepare( ...@@ -530,7 +502,7 @@ void InterpreterCore::DryRunPrepare(
global_scope_); global_scope_);
FeedInput(); FeedInput();
paddle::framework::interpretercore::build_op_func_list( paddle::framework::interpretercore::build_op_func_list(
place_, main_program_, &op_list_, &vec_func_list_, global_scope_); place_, main_program_, &vec_func_list_, global_scope_);
is_build_ = true; is_build_ = true;
// convert vec func_list to graph // convert vec func_list to graph
Convert(); Convert();
......
...@@ -67,7 +67,7 @@ class InterpreterCore { ...@@ -67,7 +67,7 @@ class InterpreterCore {
void DryRunPrepare(const std::vector<framework::Tensor>& feed_tensors); void DryRunPrepare(const std::vector<framework::Tensor>& feed_tensors);
void CheckGC(size_t instr_id, const std::vector<size_t>& gc_check_list); void CheckGC(const Instruction& instr);
void RunInstructionAsync(size_t instr_id); void RunInstructionAsync(size_t instr_id);
void RunNextInstructions(const Instruction& instr_id, void RunNextInstructions(const Instruction& instr_id,
...@@ -82,16 +82,15 @@ class InterpreterCore { ...@@ -82,16 +82,15 @@ class InterpreterCore {
ProgramDesc main_program_; ProgramDesc main_program_;
VariableScope* global_scope_; VariableScope* global_scope_;
std::vector<Instruction> vec_instruction_; std::vector<paddle::framework::OpFuncNode> vec_func_list_;
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
InstructionInfo instruction_info_; InstructionInfo instruction_info_;
std::vector<size_t> dependecy_count_; std::vector<size_t> dependecy_count_;
std::vector<std::vector<size_t>> input_var2op_info_; std::vector<std::vector<size_t>> input_var2op_info_;
std::vector<VariableMetaInfo> ref_coun_info_; std::vector<VariableMetaInfo> ref_coun_info_;
std::vector<VariableMetaInfo> vec_meta_info_; std::vector<VariableMetaInfo> vec_meta_info_;
std::vector<paddle::framework::OpFuncNode> vec_func_list_;
std::vector<paddle::framework::OperatorBase*> op_list_;
std::vector<std::string> feed_names_; std::vector<std::string> feed_names_;
InterpreterProfiler dry_run_profiler_; InterpreterProfiler dry_run_profiler_;
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace interpretercore { namespace interpretercore {
using VariableIdMap = std::map<std::string, std::vector<int>>;
AtomicVectorSizeT& AsyncWorkQueue::PrepareAtomicDeps( AtomicVectorSizeT& AsyncWorkQueue::PrepareAtomicDeps(
const std::vector<size_t>& dependecy_count) { const std::vector<size_t>& dependecy_count) {
...@@ -132,43 +133,29 @@ void build_variable_scope(const framework::ProgramDesc& pdesc, ...@@ -132,43 +133,29 @@ void build_variable_scope(const framework::ProgramDesc& pdesc,
VariableScope* var_scope) { VariableScope* var_scope) {
auto& global_block = pdesc.Block(0); auto& global_block = pdesc.Block(0);
for (auto& var : global_block.AllVars()) { for (auto& var_desc : global_block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) { auto var_name = var_desc->Name();
if (var_name == framework::kEmptyVarName) {
continue; continue;
} }
if (var_scope->name2id.find(var->Name()) == var_scope->name2id.end()) { if (nullptr == var_scope->FindVar(var_name)) {
var_scope->name2id[var->Name()] = var_scope->var_list.size(); var_scope->AddVar(var_desc->Name(), var_desc);
auto v = new Variable();
InitializeVariable(v, var->GetType());
var_scope->var_list.push_back(v);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = var;
var_scope->vec_meta_info_.push_back(info);
} else { } else {
auto var_id = var_scope->name2id[var->Name()]; auto* var_desc = var_scope->VarDesc(var_name);
if (nullptr == var_scope->vec_meta_info_[var_id].vardesc_) { if (nullptr == var_desc) {
VLOG(3) << "update var:" << var->Name() << " desc from nullptr into " VLOG(3) << "update var:" << var_name << " desc from nullptr into "
<< var; << var_desc;
var_scope->vec_meta_info_[var_id].vardesc_ = var; var_scope->VarMetaInfo(var_name).vardesc_ = var_desc;
} }
} }
} }
} }
void build_op_func_list(const platform::Place& place, std::vector<OperatorBase*> create_all_ops(const framework::BlockDesc& block) {
const framework::ProgramDesc& pdesc,
std::vector<OperatorBase*>* op_list,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope) {
auto& global_block = pdesc.Block(0);
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
std::vector<OperatorBase*> ops; std::vector<OperatorBase*> ops;
for (auto& op : global_block.AllOps()) { for (auto& op : block.AllOps()) {
VLOG(3) << "Build OpFuncNode from : " << op->Type(); VLOG(3) << "CreateOp from : " << op->Type();
auto& info = OpInfoMap::Instance().Get(op->Type()); auto& info = OpInfoMap::Instance().Get(op->Type());
...@@ -179,64 +166,96 @@ void build_op_func_list(const platform::Place& place, ...@@ -179,64 +166,96 @@ void build_op_func_list(const platform::Place& place,
if (info.Checker() != nullptr) { if (info.Checker() != nullptr) {
info.Checker()->Check(&op_attr_map); info.Checker()->Check(&op_attr_map);
} }
// step 1. Prepare VariableValueMap of input/output
auto op_base = auto op_base =
info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map); info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
ops.push_back(op_base); ops.push_back(op_base);
} }
return ops;
}
// Resolves every variable name in `var_name_map` against `var_scope`.
//
// Returns a tuple of two maps keyed by the same parameter names as
// `var_name_map`: the first holds the Variable* returned by the scope for each
// name, the second holds the corresponding variable ids. Within each entry the
// order of the names is preserved.
std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
    const VariableNameMap& var_name_map, VariableScope* var_scope) {
  VariableValueMap name2var;
  VariableIdMap name2id;
  for (auto& item : var_name_map) {
    std::vector<Variable*> vars;
    std::vector<int> ids;
    vars.reserve(item.second.size());
    ids.reserve(item.second.size());

    for (auto& var_name : item.second) {
      auto var_id = var_scope->VarId(var_name);
      vars.push_back(var_scope->Var(var_id));
      ids.push_back(var_id);
    }
    name2var[item.first] = std::move(vars);
    name2id[item.first] = std::move(ids);
  }
  // Move the local maps into the tuple; passing them as lvalues would make
  // make_tuple copy both maps.
  return std::make_tuple(std::move(name2var), std::move(name2id));
}
// Overrides the place of `expected_kernel_key` according to the op's
// "op_device" attribute. Ops that do not carry the attribute, or carry an
// empty one, are left untouched.
//   - "cpu", or any value while `place` is a CPU place -> CPUPlace
//   - a value containing "gpu" while `place` is a GPU place -> `place`
//   - anything else -> throws a Fatal error
void apply_device_guard(const OperatorBase* op_base,
                        const platform::Place& place,
                        OpKernelType* expected_kernel_key) {
  if (!op_base->HasAttr("op_device")) {
    return;
  }
  const auto& device_name = op_base->Attr<std::string>("op_device");
  if (device_name.empty()) {
    return;
  }

  const bool run_on_cpu =
      (device_name == "cpu") || platform::is_cpu_place(place);
  if (run_on_cpu) {
    VLOG(3) << "Switch into CPUPlace by device_guard.";
    expected_kernel_key->place_ = platform::CPUPlace();
    return;
  }

  const bool run_on_gpu = (device_name.find("gpu") != std::string::npos) &&
                          platform::is_gpu_place(place);
  if (run_on_gpu) {
    VLOG(3) << "Switch into " << place << " by device_guard.";
    expected_kernel_key->place_ = place;
    return;
  }

  PADDLE_THROW(
      platform::errors::Fatal("Unsupported current place %s", device_name));
}
void build_op_func_list(const platform::Place& place,
const framework::ProgramDesc& pdesc,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope) {
auto& global_block = pdesc.Block(0);
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
// Step 1: create all ops for global block.
auto ops = create_all_ops(global_block);
auto unused_var_map = get_unused_vars(global_block, ops); auto unused_var_map = get_unused_vars(global_block, ops);
size_t ops_index = 0; size_t ops_index = 0;
for (auto& op : global_block.AllOps()) { for (auto& op : global_block.AllOps()) {
VLOG(3) << op->Type(); VLOG(3) << "Build OpFuncNode from : " << op->Type();
// << op->Type() << endl;
auto op_base = ops[ops_index++]; auto op_base = ops[ops_index++];
auto inputs_names = op->Inputs(); auto inputs_names = op->Inputs();
auto outputs_names = op->Outputs(); auto outputs_names = op->Outputs();
VariableValueMap ins_map; VariableValueMap ins_map;
std::map<std::string, std::vector<int>> ins_name2id; VariableIdMap ins_name2id;
for (auto& var_name_item : inputs_names) { std::tie(ins_map, ins_name2id) =
std::vector<Variable*> input_vars; build_variable_map(inputs_names, var_scope);
std::vector<int> vec_ids;
input_vars.reserve(var_name_item.second.size());
for (auto& var_name : var_name_item.second) {
auto it = var_scope->name2id.find(var_name);
assert(it != var_scope->name2id.end());
input_vars.push_back(var_scope->var_list[it->second]);
vec_ids.push_back(it->second);
}
ins_map[var_name_item.first] = input_vars;
ins_name2id[var_name_item.first] = vec_ids;
}
VariableValueMap outs_map; VariableValueMap outs_map;
std::map<std::string, std::vector<int>> outs_name2id; VariableIdMap outs_name2id;
for (auto& var_name_item : outputs_names) { std::tie(outs_map, outs_name2id) =
std::vector<Variable*> output_vars; build_variable_map(outputs_names, var_scope);
std::vector<int> vec_ids;
output_vars.reserve(var_name_item.second.size());
for (auto& var_name : var_name_item.second) {
auto it = var_scope->name2id.find(var_name);
assert(it != var_scope->name2id.end());
output_vars.push_back(var_scope->var_list[it->second]);
vec_ids.push_back(it->second);
}
outs_map[var_name_item.first] = output_vars;
outs_name2id[var_name_item.first] = vec_ids;
}
// step 2: build OpFuncNode
OpFuncNode op_func_node; OpFuncNode op_func_node;
op_func_node.input_index = ins_name2id; op_func_node.input_index = ins_name2id;
op_func_node.output_index = outs_name2id; op_func_node.output_index = outs_name2id;
// step 2: construct RuntimeContext and analysis KernelType // construct RuntimeContext and analysis KernelType
RuntimeContext runtime_context({}, {}); RuntimeContext runtime_context({}, {});
runtime_context.inputs.swap(ins_map); runtime_context.inputs.swap(ins_map);
runtime_context.outputs.swap(outs_map); runtime_context.outputs.swap(outs_map);
InterpretercoreInferShapeContext infer_shape_ctx(*op_base, runtime_context); InterpretercoreInferShapeContext infer_shape_ctx(*op_base, runtime_context);
// TODO(Aurelius84): Handle control flow ops, which are NOT inherited
// from OperatorWithKernel.
static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape( static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape(
&infer_shape_ctx); &infer_shape_ctx);
auto kernels_iter = all_op_kernels.find(op->Type()); auto kernels_iter = all_op_kernels.find(op->Type());
...@@ -256,32 +275,18 @@ void build_op_func_list(const platform::Place& place, ...@@ -256,32 +275,18 @@ void build_op_func_list(const platform::Place& place,
->GetExpectedKernelType( ->GetExpectedKernelType(
ExecutionContext(*op_base, scope, *dev_ctx, runtime_context)); ExecutionContext(*op_base, scope, *dev_ctx, runtime_context));
// consider device_guard context // consider device_guard()
bool need_change_place = apply_device_guard(op_base, place, &expected_kernel_key);
(op_base->HasAttr("op_device") &&
(op_base->Attr<std::string>("op_device").length() > 0));
if (need_change_place) {
auto& op_device = op_base->Attr<std::string>("op_device");
if (op_device == "cpu" || platform::is_cpu_place(place)) {
VLOG(3) << "Switch into CPUPlace by device_guard.";
expected_kernel_key.place_ = platform::CPUPlace();
} else if (op_device.find("gpu") != std::string::npos &&
platform::is_gpu_place(place)) {
VLOG(3) << "Switch into " << place << " by device_guard.";
expected_kernel_key.place_ = place;
} else {
PADDLE_THROW(
platform::errors::Fatal("Unsupported current place %s", op_device));
}
}
VLOG(3) << "expected_kernel_key : " << expected_kernel_key; VLOG(3) << "expected_kernel_key : " << expected_kernel_key;
// step 3. Insert memcpy_op if needed // step 3. Insert memcpy_op if needed
VariableValueMap& ins_map_temp = runtime_context.inputs; VariableValueMap& ins_map_temp = runtime_context.inputs;
std::unordered_set<int> no_data_transform_index; std::unordered_set<int> no_data_transform_index;
for (auto& var_name_item : ins_map_temp) { for (auto& var_name_item : ins_map_temp) {
for (size_t i = 0; i < var_name_item.second.size(); ++i) { for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto var = var_name_item.second[i]; auto var = var_name_item.second[i];
auto& var_name = inputs_names[var_name_item.first].at(i);
auto tensor_in = static_cast<const Tensor*>(&(var->Get<LoDTensor>())); auto tensor_in = static_cast<const Tensor*>(&(var->Get<LoDTensor>()));
if (!tensor_in->IsInitialized()) { if (!tensor_in->IsInitialized()) {
continue; continue;
...@@ -293,32 +298,19 @@ void build_op_func_list(const platform::Place& place, ...@@ -293,32 +298,19 @@ void build_op_func_list(const platform::Place& place,
if (platform::is_same_place(kernel_type_for_var.place_, if (platform::is_same_place(kernel_type_for_var.place_,
expected_kernel_key.place_)) { expected_kernel_key.place_)) {
// record no need data transformer input var_id // record no need data transformer input var_id
auto& var_name = inputs_names[var_name_item.first][i];
VLOG(3) << op->Type() << " found no data_transform var: " << var_name VLOG(3) << op->Type() << " found no data_transform var: " << var_name
<< " with id: " << var_scope->name2id[var_name]; << " with id: " << var_name;
no_data_transform_index.emplace(var_scope->name2id[var_name]); no_data_transform_index.emplace(var_scope->VarId(var_name));
} else { } else {
if (op_base->Type() == "fetch_v2") { if (op_base->Type() == "fetch_v2") {
op_base->SetAttr("deepcopy", false); op_base->SetAttr("deepcopy", false);
} }
// need trans place
// 1. add var in scope
// 2. add copy op
std::string new_var_name = std::string new_var_name =
"temp_1" + std::to_string(var_scope->var_list.size() + 1); var_name + "_copy_" + std::to_string(var_scope->VarSize() + 1);
auto v = new Variable(); var_scope->AddVar(new_var_name, nullptr);
v->GetMutable<LoDTensor>();
var_scope->name2id[new_var_name] = var_scope->var_list.size();
var_scope->var_list.push_back(v);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = nullptr;
var_scope->vec_meta_info_.push_back(info);
VariableNameMap copy_in_map; VariableNameMap copy_in_map;
auto x_iter = inputs_names.find(var_name_item.first); copy_in_map["X"] = {var_name};
copy_in_map["X"] = {x_iter->second[i]};
VariableNameMap copy_out_map; VariableNameMap copy_out_map;
copy_out_map["Out"] = {new_var_name}; copy_out_map["Out"] = {new_var_name};
AttributeMap attr_map; AttributeMap attr_map;
...@@ -328,23 +320,23 @@ void build_op_func_list(const platform::Place& place, ...@@ -328,23 +320,23 @@ void build_op_func_list(const platform::Place& place,
: is_gpu_place(expected_kernel_key.place_) ? 1 : -1; : is_gpu_place(expected_kernel_key.place_) ? 1 : -1;
std::map<std::string, std::vector<int>> copy_ins_name2id; std::map<std::string, std::vector<int>> copy_ins_name2id;
copy_ins_name2id["X"] = ins_name2id[var_name_item.first]; copy_ins_name2id["X"] = ins_name2id.at(var_name_item.first);
std::map<std::string, std::vector<int>> copy_out_name2id; std::map<std::string, std::vector<int>> copy_out_name2id;
copy_out_name2id["Out"] = {var_scope->name2id[new_var_name]}; copy_out_name2id["Out"] = {var_scope->VarId(new_var_name)};
op_func_node.input_index[var_name_item.first][i] = op_func_node.input_index[var_name_item.first][i] =
var_scope->name2id[new_var_name]; var_scope->VarId(new_var_name);
VariableValueMap copy_ins_value_map; VariableValueMap copy_ins_value_map;
copy_ins_value_map["X"] = {var}; copy_ins_value_map["X"] = {var};
VariableValueMap copy_outs_value_map; VariableValueMap copy_outs_value_map;
copy_outs_value_map["Out"] = {v}; copy_outs_value_map["Out"] = {var_scope->Var(new_var_name)};
// memcpy_d2h, memcpy_h2d // memcpy_d2h, memcpy_h2d
auto memcpy_op_type = get_memcpy_type(kernel_type_for_var.place_, auto memcpy_op_type = get_memcpy_type(kernel_type_for_var.place_,
expected_kernel_key.place_); expected_kernel_key.place_);
VLOG(3) << string::Sprintf("Insert %s with %s(%s) -> %s(%s).", VLOG(3) << string::Sprintf("Insert %s with %s(%s) -> %s(%s).",
memcpy_op_type, x_iter->second[i], memcpy_op_type, var_name,
kernel_type_for_var.place_, new_var_name, kernel_type_for_var.place_, new_var_name,
expected_kernel_key.place_); expected_kernel_key.place_);
auto& copy_info = OpInfoMap::Instance().Get(memcpy_op_type); auto& copy_info = OpInfoMap::Instance().Get(memcpy_op_type);
...@@ -385,16 +377,16 @@ void build_op_func_list(const platform::Place& place, ...@@ -385,16 +377,16 @@ void build_op_func_list(const platform::Place& place,
// as kQueueSync and execute them in thread pool. // as kQueueSync and execute them in thread pool.
copy_op_func_node.type_ = OpFuncType::kQueueSync; copy_op_func_node.type_ = OpFuncType::kQueueSync;
copy_op_func_node.dev_ctx_ = dev_ctx; copy_op_func_node.dev_ctx_ = dev_ctx;
op_list->push_back(copy_op); copy_op_func_node.operator_base_ = copy_op;
vec_func_list->push_back(copy_op_func_node); vec_func_list->push_back(copy_op_func_node);
var_name_item.second[i] = v; var_name_item.second[i] = var_scope->Var(new_var_name);
} }
} }
} }
op_func_node.no_data_transform_index = std::move(no_data_transform_index); op_func_node.no_data_transform_index = std::move(no_data_transform_index);
// step 4. Run op kernel // step 4. Run op kernel
op_list->push_back(op_base); op_func_node.operator_base_ = op_base;
VLOG(3) << op_base->Type() VLOG(3) << op_base->Type()
<< " : expected_kernel_key : " << expected_kernel_key; << " : expected_kernel_key : " << expected_kernel_key;
...@@ -436,9 +428,7 @@ void build_op_func_list(const platform::Place& place, ...@@ -436,9 +428,7 @@ void build_op_func_list(const platform::Place& place,
new std::deque<std::shared_ptr<memory::Allocation>>(); new std::deque<std::shared_ptr<memory::Allocation>>();
for (auto& var_name : delete_vars) { for (auto& var_name : delete_vars) {
auto it = var_scope->name2id.find(var_name); auto* var = var_scope->FindVar(var_name);
assert(it != var_scope->name2id.end());
auto* var = var_scope->var_list[it->second];
if (var == nullptr) { if (var == nullptr) {
continue; continue;
} }
......
...@@ -101,7 +101,6 @@ void build_variable_scope(const framework::ProgramDesc& pdesc, ...@@ -101,7 +101,6 @@ void build_variable_scope(const framework::ProgramDesc& pdesc,
void build_op_func_list(const platform::Place& place, void build_op_func_list(const platform::Place& place,
const framework::ProgramDesc& pdesc, const framework::ProgramDesc& pdesc,
std::vector<OperatorBase*>* op_list,
std::vector<OpFuncNode>* vec_func_list, std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope); VariableScope* var_scope);
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/device_event_base.h" #include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/event.h" #include "paddle/fluid/platform/event.h"
...@@ -463,7 +464,6 @@ class InterpretercoreInferShapeContext : public InferShapeContext { ...@@ -463,7 +464,6 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
struct OpKernelFunc { struct OpKernelFunc {
OpKernelComputeFunc compute_func_; OpKernelComputeFunc compute_func_;
OperatorBase* operator_base_;
}; };
struct VariableMetaInfo { struct VariableMetaInfo {
...@@ -471,13 +471,108 @@ struct VariableMetaInfo { ...@@ -471,13 +471,108 @@ struct VariableMetaInfo {
paddle::framework::VarDesc* vardesc_; paddle::framework::VarDesc* vardesc_;
}; };
struct VariableScope { // TODO(Aurelius84): Consider inherit ScopeBase to unify interface.
class VariableScope {
public:
// Looks up a variable by name.
// Returns nullptr when `name` is not registered in this scope; otherwise
// returns the stored pointer after the id has passed CheckExist(int).
Variable* FindVar(const std::string& name) const {
  // A single map probe replaces the HasVar() + VarId() sequence, which
  // searched name2id three times for every successful lookup.
  const auto iter = name2id.find(name);
  if (iter == name2id.end()) {
    return nullptr;
  }
  CheckExist(iter->second);
  return var_list[iter->second];
}
// Reports whether `name` has an entry in the name -> id table.
bool HasVar(const std::string& name) const {
  return name2id.count(name) != 0;
}
// Translates a variable name into its integer id.
// CheckExist(name) rejects unknown names first, so the lookup below
// always finds an entry.
int VarId(const std::string& name) const {
  CheckExist(name);
  const auto entry = name2id.find(name);
  return entry->second;
}
// Returns the variable stored at index `id`; the access goes through
// std::vector::at, so it is range-checked.
Variable* Var(int id) const {
  return var_list.at(id);
}
// Returns the variable registered under `name`.
// The name is resolved through VarId(), which enforces that it exists.
Variable* Var(const std::string& name) const {
  const int var_id = VarId(name);
  return var_list.at(var_id);
}
// Number of variables currently stored in this scope.
size_t VarSize() const {
  return var_list.size();
}
void AddVar(const std::string& name, VarDesc* var_desc) { // NOLINT
name2id[name] = VarSize();
auto v = new Variable();
if (nullptr == var_desc) {
v->GetMutable<LoDTensor>();
} else {
InitializeVariable(v, var_desc->GetType());
}
var_list.push_back(v);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = var_desc;
vec_meta_info_.push_back(info);
}
void AddVar(const std::string& name, Variable& var) { // NOLINT
name2id[name] = VarSize();
var_list.push_back(&var);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = nullptr;
vec_meta_info_.push_back(info);
}
// Returns the VarDesc recorded for the variable called `name`.
// The name is resolved through VarId(), then forwarded to the id overload.
paddle::framework::VarDesc* VarDesc(const std::string& name) const {
  const int var_id = VarId(name);
  return VarDesc(var_id);
}
// Returns the VarDesc recorded in the meta info of variable `id`.
// The id is validated with CheckExist(int) before the meta table is read.
paddle::framework::VarDesc* VarDesc(int id) const {
  CheckExist(id);
  const VariableMetaInfo& meta = vec_meta_info_[id];
  return meta.vardesc_;
}
// Gives mutable access to the meta info of the variable called `name`.
// The name is resolved through VarId(), which enforces that it exists.
VariableMetaInfo& VarMetaInfo(const std::string& name) {
  const int var_id = VarId(name);
  return vec_meta_info_[var_id];
}
// Enforces that `id` is smaller than the number of stored variables;
// a violation is reported as a PreconditionNotMet error.
// NOTE(review): `id` is a signed int compared against a size_t and there is
// no explicit check for negative ids here -- confirm how the enforce macro
// treats mixed signedness.
void CheckExist(int id) const {
PADDLE_ENFORCE_LT(id, var_list.size(),
platform::errors::PreconditionNotMet(
"Required var_id < %d, but received var_id = %d.",
var_list.size(), id));
}
// Enforces that `name` is registered in this scope (HasVar(name) is true);
// an unknown name is reported as a NotFound error.
void CheckExist(const std::string& name) const {
PADDLE_ENFORCE_EQ(
HasVar(name), true,
platform::errors::NotFound("%s not in VariableScope.", name));
}
private:
std::vector<Variable*> var_list; std::vector<Variable*> var_list;
std::map<std::string, int> name2id; std::map<std::string, int> name2id;
std::vector<VariableMetaInfo> vec_meta_info_; std::vector<VariableMetaInfo> vec_meta_info_;
}; };
struct NextInstruction { class NextInstruction {
public:
// Appends `id` to the list of direct-run instruction ids.
void AddDirectRun(size_t id) {
  direct_run_.emplace_back(id);
}
// Appends `id` to the list of event-wait-run instruction ids.
void ADDEventRun(size_t id) {
  event_wait_run_.emplace_back(id);
}
// Appends `id` to the list of synchronize-run instruction ids.
void AddSyncRun(size_t id) {
  synchronize_run_.emplace_back(id);
}
// Read-only view of the ids added through AddDirectRun().
const std::vector<size_t>& DirectRunIds() const {
  return direct_run_;
}
// Read-only view of the ids added through ADDEventRun().
const std::vector<size_t>& EventRunIds() const {
  return event_wait_run_;
}
// Read-only view of the ids added through AddSyncRun().
const std::vector<size_t>& SyncRunIds() const {
  return synchronize_run_;
}
private:
std::vector<size_t> direct_run_; std::vector<size_t> direct_run_;
std::vector<size_t> event_wait_run_; std::vector<size_t> event_wait_run_;
std::vector<size_t> synchronize_run_; std::vector<size_t> synchronize_run_;
...@@ -503,49 +598,138 @@ enum class OpFuncType { ...@@ -503,49 +598,138 @@ enum class OpFuncType {
}; };
class RuntimeInferShapeContext; class RuntimeInferShapeContext;
struct Instruction { struct OpFuncNode {
OpKernelFunc kernel_func_; OperatorBase* operator_base_;
std::map<std::string, std::vector<int>> input_index;
std::map<std::string, std::vector<int>> output_index;
std::unordered_set<int> no_data_transform_index;
OpKernelComputeFunc kernel_func_;
platform::DeviceContext* dev_ctx_; // not owned
OpFuncType type_;
};
class Instruction {
public:
// Binds an instruction to its op node and device context.
//   id           - identifier reported by Id().
//   op_func_node - op description; stored by reference, not copied.
//   dev_ctx      - device context; stored by reference, not copied.
// Both referenced objects must outlive this Instruction.
Instruction(size_t id, const OpFuncNode& op_func_node,
            const platform::DeviceContext& dev_ctx)
    : id_(id), op_func_node_(op_func_node), dev_ctx_(dev_ctx) {
  // `id` is a size_t, so the former `id >= 0` enforcement could never fail;
  // it was dead code and has been removed.
}
// Identifier given to this instruction at construction.
size_t Id() const {
  return id_;
}
// Input slot name -> variable ids, read from the referenced OpFuncNode.
const std::map<std::string, std::vector<int>>& Inputs() const {
  const auto& in_ids = op_func_node_.input_index;
  return in_ids;
}
// Output slot name -> variable ids, read from the referenced OpFuncNode.
const std::map<std::string, std::vector<int>>& Outputs() const {
  const auto& out_ids = op_func_node_.output_index;
  return out_ids;
}
// The no_data_transform_index set of the referenced OpFuncNode.
const std::unordered_set<int>& NoDataTransformVars() const {
  const auto& skipped_ids = op_func_node_.no_data_transform_index;
  return skipped_ids;
}
// Returns a copy of the kernel compute function held by the OpFuncNode.
OpKernelComputeFunc KernelFunc() const {
  return op_func_node_.kernel_func_;
}
// Returns the OpFuncType recorded in the referenced OpFuncNode.
OpFuncType KernelType() const {
  return op_func_node_.type_;
}
// Returns the operator attached to the referenced OpFuncNode.
// A null operator_base_ is rejected with a PreconditionNotMet error, so a
// successful call never yields nullptr.
OperatorBase* OpBase() const {
auto* op_base = op_func_node_.operator_base_;
PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
"op_base shall not be nullptr."));
return op_base;
}
// Mutable access to the successor lists of this instruction.
NextInstruction& NextInstructions() {
  return next_instruction_;
}
// Read-only access to the successor lists of this instruction.
const NextInstruction& NextInstructions() const {
  return next_instruction_;
}
// Appends variable id `id` to this instruction's GC check list.
void AddGCCheckVar(size_t id) {
  gc_check_var_list_.emplace_back(id);
}
// Read-only view of the ids added through AddGCCheckVar().
const std::vector<size_t>& GCCheckVars() const {
  return gc_check_var_list_;
}
void ResetContext(const VariableValueMap& in_vars,
const VariableValueMap& out_vars) {
runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
infershape_ctx_.reset(
new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
// NOTE: Because execution_ctx_ is constructed by `scope&`, so we fake an
// empty here to avoid illegal local reference.
static framework::Scope scope_;
execution_ctx_.reset(
new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get()));
}
std::shared_ptr<RuntimeContext> InnerRuntimeContext() const {
return runtime_ctx_;
}
std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext()
const {
return infershape_ctx_;
}
std::shared_ptr<ExecutionContext> InnerExecutionContext() const {
return execution_ctx_;
}
// The device context this instruction was constructed with (not owned).
const platform::DeviceContext& DeviceContext() const {
  return dev_ctx_;
}
// Read-only view of the (in, out) variable pairs added through AddInplace().
const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const {
  const auto& pairs = vec_inplace_in_to_out_;
  return pairs;
}
// Records one (in, out) variable pair in the inplace list.
void AddInplace(Variable* in, Variable* out) {
  vec_inplace_in_to_out_.push_back(std::make_pair(in, out));
}
// Read-only view of the events added through AddInputEvent().
const std::vector<EventInter>& InputEvents() const {
  return intput_events_;
}
// Read-only view of the events added through AddOutputEvent().
const std::vector<EventInter>& OutputEvents() const {
  return output_events_;
}
// Appends an EventInter built from (var_id, event, waiter_type) to the
// input event list.
void AddInputEvent(size_t var_id,
                   std::shared_ptr<platform::DeviceEvent> event,
                   platform::DeviceType waiter_type) {
  auto& events = intput_events_;
  events.emplace_back(var_id, event, waiter_type);
}
// Appends an EventInter built from (var_id, event, waiter_type) to the
// output event list.
void AddOutputEvent(size_t var_id,
                    std::shared_ptr<platform::DeviceEvent> event,
                    platform::DeviceType waiter_type) {
  auto& events = output_events_;
  events.emplace_back(var_id, event, waiter_type);
}
private:
size_t id_;
const OpFuncNode& op_func_node_; // not owned
const platform::DeviceContext& dev_ctx_; // not owned
std::shared_ptr<RuntimeContext> runtime_ctx_; std::shared_ptr<RuntimeContext> runtime_ctx_;
std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_; std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_;
std::shared_ptr<ExecutionContext> execution_ctx_; std::shared_ptr<ExecutionContext> execution_ctx_;
std::map<std::string, std::vector<int>> input_index_;
std::map<std::string, std::vector<int>> output_index_;
std::unordered_set<int> no_data_transform_index_;
std::vector<size_t> gc_check_var_list; std::vector<size_t> gc_check_var_list_;
NextInstruction next_instruction_; NextInstruction next_instruction_;
std::vector<EventInter> intput_events_; std::vector<EventInter> intput_events_;
std::vector<EventInter> output_events_; std::vector<EventInter> output_events_;
platform::DeviceContext* dev_ctx_; // not owned
OpFuncType type_;
std::vector<std::pair<Variable*, Variable*>> vec_inplace_in_to_out_; std::vector<std::pair<Variable*, Variable*>> vec_inplace_in_to_out_;
}; };
struct OpFuncNode {
// int unsed;
std::map<std::string, std::vector<int>> input_index;
std::map<std::string, std::vector<int>> output_index;
std::unordered_set<int> no_data_transform_index;
OpKernelComputeFunc kernel_func_;
platform::DeviceContext* dev_ctx_; // not owned
OpFuncType type_;
};
namespace interpretercore { namespace interpretercore {
static constexpr char kMemcpyH2D[] = "memcpy_h2d"; static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h"; static constexpr char kMemcpyD2H[] = "memcpy_d2h";
static bool IsMemcpyH2D(const Instruction& instr) { static bool IsMemcpyH2D(const Instruction& instr) {
return instr.kernel_func_.operator_base_->Type() == kMemcpyH2D; return instr.OpBase()->Type() == kMemcpyH2D;
} }
static bool IsMemcpyD2H(const Instruction& instr) { static bool IsMemcpyD2H(const Instruction& instr) {
return instr.kernel_func_.operator_base_->Type() == kMemcpyD2H; return instr.OpBase()->Type() == kMemcpyD2H;
} }
} // namespace interpretercore } // namespace interpretercore
......
...@@ -33,23 +33,16 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, ...@@ -33,23 +33,16 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
auto name_list = outer_scope_->LocalVarNames(); auto name_list = outer_scope_->LocalVarNames();
for (auto name : name_list) { for (auto name : name_list) {
auto v = outer_scope_->Var(name); auto v = outer_scope_->Var(name);
if (global_scope_.name2id.find(name) == global_scope_.name2id.end()) { if (!global_scope_.HasVar(name)) {
global_scope_.name2id[name] = global_scope_.var_list.size(); global_scope_.AddVar(name, *v);
global_scope_.var_list.push_back(v);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = nullptr;
global_scope_.vec_meta_info_.push_back(info);
} }
} }
} }
// run startup program // run startup program
std::vector<paddle::framework::OpFuncNode> vec_func_list; std::vector<paddle::framework::OpFuncNode> vec_func_list;
std::vector<paddle::framework::OperatorBase*> op_list;
paddle::framework::interpretercore::build_op_func_list( paddle::framework::interpretercore::build_op_func_list(
place_, startup_prog, &op_list, &vec_func_list, &global_scope_); place_, startup_prog, &vec_func_list, &global_scope_);
} }
paddle::framework::FetchList StandaloneExecutor::Run( paddle::framework::FetchList StandaloneExecutor::Run(
...@@ -80,16 +73,8 @@ void StandaloneExecutor::BuildVariableOuterScope( ...@@ -80,16 +73,8 @@ void StandaloneExecutor::BuildVariableOuterScope(
continue; continue;
} }
if (var_scope->name2id.find(var->Name()) == var_scope->name2id.end()) { if (!var_scope->HasVar(var->Name())) {
var_scope->name2id[var->Name()] = var_scope->var_list.size(); var_scope->AddVar(var->Name(), var);
auto v = outer_scope->Var(var->Name());
InitializeVariable(v, var->GetType());
var_scope->var_list.push_back(v);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = var;
var_scope->vec_meta_info_.push_back(info);
} }
} }
} }
......
...@@ -31,15 +31,15 @@ namespace framework { ...@@ -31,15 +31,15 @@ namespace framework {
std::vector<size_t> StreamAnalyzer::ParseEventVarIds( std::vector<size_t> StreamAnalyzer::ParseEventVarIds(
const Instruction& cur_instr, const Instruction& next_instr) { const Instruction& cur_instr, const Instruction& next_instr) {
std::unordered_set<size_t> unique_var_ids; std::unordered_set<size_t> unique_var_ids;
for (auto& item : cur_instr.output_index_) { for (auto& item : cur_instr.Outputs()) {
unique_var_ids.insert(item.second.begin(), item.second.end()); unique_var_ids.insert(item.second.begin(), item.second.end());
} }
std::vector<size_t> new_event_var_ids; std::vector<size_t> new_event_var_ids;
for (auto& item : next_instr.input_index_) { for (auto& item : next_instr.Inputs()) {
for (auto var_id : item.second) { for (auto var_id : item.second) {
if (unique_var_ids.count(var_id) > 0 && if (unique_var_ids.count(var_id) > 0 &&
next_instr.no_data_transform_index_.count(var_id) == 0) { next_instr.NoDataTransformVars().count(var_id) == 0) {
new_event_var_ids.push_back(var_id); new_event_var_ids.push_back(var_id);
} }
} }
...@@ -57,8 +57,7 @@ void StreamAnalyzer::AssociateInputWithEvents( ...@@ -57,8 +57,7 @@ void StreamAnalyzer::AssociateInputWithEvents(
var_id2event_.emplace(var_id, std::move(device_event)); var_id2event_.emplace(var_id, std::move(device_event));
} }
// Add events for next_instr.inputs // Add events for next_instr.inputs
next_instr->intput_events_.emplace_back(var_id, var_id2event_.at(var_id), next_instr->AddInputEvent(var_id, var_id2event_.at(var_id), waiter_type);
waiter_type);
} }
} }
...@@ -66,13 +65,13 @@ void StreamAnalyzer::Schedule(const std::vector<size_t>& downstream_ops, ...@@ -66,13 +65,13 @@ void StreamAnalyzer::Schedule(const std::vector<size_t>& downstream_ops,
std::vector<Instruction>* instructions, std::vector<Instruction>* instructions,
size_t op_index) { size_t op_index) {
auto& cur_instr = instructions->at(op_index); auto& cur_instr = instructions->at(op_index);
auto& next_instruction = cur_instr.next_instruction_; auto& next_instruction = cur_instr.NextInstructions();
std::vector<size_t> event_var_ids; std::vector<size_t> event_var_ids;
for (auto next_op_id : downstream_ops) { for (auto next_op_id : downstream_ops) {
auto& next_instr = instructions->at(next_op_id); auto& next_instr = instructions->at(next_op_id);
if (IsDirectRun(cur_instr, next_instr)) { if (IsDirectRun(cur_instr, next_instr)) {
next_instruction.direct_run_.emplace_back(next_op_id); next_instruction.AddDirectRun(next_op_id);
} else { } else {
// Always insert events between different stream // Always insert events between different stream
auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr); auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr);
...@@ -83,24 +82,24 @@ void StreamAnalyzer::Schedule(const std::vector<size_t>& downstream_ops, ...@@ -83,24 +82,24 @@ void StreamAnalyzer::Schedule(const std::vector<size_t>& downstream_ops,
AssociateInputWithEvents(new_event_var_ids, &next_instr, waiter_type); AssociateInputWithEvents(new_event_var_ids, &next_instr, waiter_type);
if (waiter_type == platform::kCPU) { // GPU -> CPU if (waiter_type == platform::kCPU) { // GPU -> CPU
next_instruction.synchronize_run_.emplace_back(next_op_id); next_instruction.AddSyncRun(next_op_id);
} else { // GPU -> GPU(different stream) } else { // GPU -> GPU(different stream)
next_instruction.event_wait_run_.emplace_back(next_op_id); next_instruction.ADDEventRun(next_op_id);
} }
} }
} }
// Create events for these cross-stream vars // Create events for these cross-stream vars
VLOG(3) << cur_instr.kernel_func_.operator_base_->Type() VLOG(3) << cur_instr.OpBase()->Type()
<< " event_var_ids.size: " << event_var_ids.size(); << " event_var_ids.size: " << event_var_ids.size();
for (auto var_id : event_var_ids) { for (auto var_id : event_var_ids) {
cur_instr.output_events_.emplace_back(var_id, var_id2event_.at(var_id), cur_instr.AddOutputEvent(var_id, var_id2event_.at(var_id),
platform::kCUDA /*not used*/); platform::kCUDA /*not used*/);
} }
} }
platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
const OpFuncNode& op_func_node, const OperatorBase& op_base) { const OpFuncNode& op_func_node) {
auto& op_type = op_base.Type(); auto& op_type = op_func_node.operator_base_->Type();
auto* dev_ctx = op_func_node.dev_ctx_; auto* dev_ctx = op_func_node.dev_ctx_;
if (op_type == interpretercore::kMemcpyH2D) { if (op_type == interpretercore::kMemcpyH2D) {
VLOG(3) << "Get dev_ctx from d2h_context_pool_"; VLOG(3) << "Get dev_ctx from d2h_context_pool_";
...@@ -122,13 +121,13 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( ...@@ -122,13 +121,13 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
*/ */
bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr, bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
const Instruction& next_instr) { const Instruction& next_instr) {
return (cur_instr.dev_ctx_ == next_instr.dev_ctx_ || return (&cur_instr.DeviceContext() == &next_instr.DeviceContext() ||
interpretercore::IsMemcpyD2H(cur_instr) || interpretercore::IsMemcpyD2H(cur_instr) ||
interpretercore::IsMemcpyH2D(next_instr)); interpretercore::IsMemcpyH2D(next_instr));
} }
platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) { platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) {
if (instr.type_ == OpFuncType::kQueueSync) { if (instr.KernelType() == OpFuncType::kQueueSync) {
return platform::kCPU; return platform::kCPU;
} else { } else {
return platform::kCUDA; return platform::kCUDA;
......
...@@ -32,8 +32,7 @@ class StreamAnalyzer { ...@@ -32,8 +32,7 @@ class StreamAnalyzer {
void Schedule(const std::vector<size_t>& downstream_ops, void Schedule(const std::vector<size_t>& downstream_ops,
std::vector<Instruction>* instructions, size_t op_index); std::vector<Instruction>* instructions, size_t op_index);
platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node, platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node);
const OperatorBase& op_base);
private: private:
std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr, std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册