Unverified commit 9268f392, authored by LiYuRio, committed by GitHub

Optimize gc in executor (#50301)

Parent 80dc81c5
@@ -14,6 +14,7 @@
 #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
 #include <algorithm>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/distributed/fleet_executor/global.h"
@@ -53,40 +54,40 @@ FleetExecutor::~FleetExecutor() {
   }
 }
 
-void FleetExecutor::Init(
-    const std::string& carrier_id,
-    const framework::ProgramDesc& program_desc,
-    framework::Scope* scope,
-    const platform::Place& place,
-    int64_t num_micro_batches,
-    const std::vector<TaskNode*>& task_nodes,
-    const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
-    const std::vector<std::string>& inference_root_scope_vars,
-    const std::vector<framework::Scope*>& micro_scope_list) {
-  PADDLE_ENFORCE_GT(task_nodes.size(),
-                    0,
-                    platform::errors::InvalidArgument(
-                        "Fleet executor is inited with empty task node"));
-  // TODO(fleet_exe devs): the unused_vars should be got from run time graph
-  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
-  for (const auto& desc : program_desc.Block(0).AllOps()) {
-    ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
-  }
-  auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
-  // NOTE: For inference, the vars in inference_root_scope_vars
-  // shouldn't be deleted during inf, for that they may be the result of the
-  // inf. If they are GCed, it will cause error during ZeroCopy the result.
+namespace {
+void GetSubBlockTask(const std::vector<TaskNode*>& tasks,
+                     TaskNode* cur_task,
+                     std::set<TaskNode*>* sub_block_task) {
+  auto& downstream = cur_task->downstream();
+  auto& id_to_dep_type = cur_task->id_to_dep_type();
+  for (auto& down : downstream) {
+    int64_t task_id = down.first;
+    if (id_to_dep_type.at(task_id) == DependType::NORMAL) {
+      for (const auto& task : tasks) {
+        if (task->task_id() == task_id) {
+          sub_block_task->emplace(task);
+          GetSubBlockTask(tasks, task, sub_block_task);
+        }
+      }
+    }
+  }
+}
+
+void PreventVarsDelete(
+    std::unordered_map<const framework::OperatorBase*,
+                       std::vector<std::string>>* unused_vars,
+    const std::vector<std::string>& vars_not_gc) {
   std::vector<const framework::OperatorBase*> changed_ops;
-  for (auto pair : unused_vars) {
+  for (const auto& pair : *unused_vars) {
     const framework::OperatorBase* op = pair.first;
-    std::vector<std::string> unused = pair.second;
-    for (auto name : inference_root_scope_vars) {
-      auto iter = std::find(unused.begin(), unused.end(), name);
-      if (iter != unused.end()) {
+    std::vector<std::string> cur_unused = pair.second;
+    for (auto name : vars_not_gc) {
+      auto iter = std::find(cur_unused.begin(), cur_unused.end(), name);
+      if (iter != cur_unused.end()) {
         VLOG(3) << "Removing var: [" << name
                 << "] from the unused vars list of op: [" << op->Type() << "]";
-        unused.erase(iter);
+        cur_unused.erase(iter);
         if (std::find(changed_ops.begin(), changed_ops.end(), op) ==
             changed_ops.end()) {
           // record the op whose unused vars have been updated
@@ -95,48 +96,118 @@ void FleetExecutor::Init(
       }
     }
     // update the unused vars list in the map
-    unused_vars[op] = unused;
+    unused_vars->at(op) = cur_unused;
   }
   for (auto op : changed_ops) {
-    auto iter = unused_vars.find(op);
+    const auto& iter = unused_vars->find(op);
     if (iter->second.empty()) {
       // remove those ops in the map that have empty unused vars list
       VLOG(3) << "Removing op: [" << op->Type() << "] from unused_vars map.";
-      unused_vars.erase(iter);
+      unused_vars->erase(iter);
     }
   }
-  runtime_graph_ = std::make_shared<RuntimeGraph>();
-  std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
-  for (auto task_node : task_nodes) {
-    task_node->SetUnusedVars(unused_vars);
-    if (task_node->type() == "Cond") {
-      std::vector<std::string> while_block_vars;
-      VLOG(3) << "Vars in while sub block:";
-      for (auto& var : program_desc.Block(1).AllVars()) {
-        VLOG(3) << var->Name();
-        while_block_vars.emplace_back(var->Name());
-      }
-      for (const auto& pair : unused_vars) {
-        if (pair.first->Type() == "while") {
-          for (const auto& var_name : pair.second) {
-            while_block_vars.emplace_back(var_name);
-          }
-        }
-      }
-      VLOG(3) << "Vars below will be removed after while:";
-      for (const auto& name : while_block_vars) {
-        VLOG(3) << name;
+}
+
+std::vector<std::string> GetUnusedVarsAfterWhile(
+    const framework::ProgramDesc& program_desc,
+    TaskNode* cond_task,
+    const std::vector<std::string> vars_not_gc) {
+  std::vector<std::string> while_block_vars;
+  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
+  for (const auto& desc : program_desc.Block(0).AllOps()) {
+    ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
+  }
+  auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
+  PreventVarsDelete(&unused_vars, vars_not_gc);
+  for (const auto& pair : unused_vars) {
+    if (pair.first->Type() == "while") {
+      for (const auto& var_name : pair.second) {
+        while_block_vars.emplace_back(var_name);
+      }
+    }
+  }
+  return while_block_vars;
+}
+}  // namespace
+
+void FleetExecutor::Init(
+    const std::string& carrier_id,
+    const framework::ProgramDesc& program_desc,
+    framework::Scope* scope,
+    const platform::Place& place,
+    int64_t num_micro_batches,
+    const std::vector<TaskNode*>& task_nodes,
+    const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
+    const std::vector<std::string>& inference_root_scope_vars,
+    const std::vector<framework::Scope*>& micro_scope_list) {
+  PADDLE_ENFORCE_GT(task_nodes.size(),
+                    0,
+                    platform::errors::InvalidArgument(
+                        "Fleet executor is inited with empty task node"));
+  // Set the unused var after running while op
+  std::set<TaskNode*> sub_block_tasks;
+  std::vector<std::string> while_block_vars;
+  for (const auto& task_node : task_nodes) {
+    if (task_node->type() == "Cond") {
+      GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
+      while_block_vars = GetUnusedVarsAfterWhile(
+          program_desc, task_node, inference_root_scope_vars);
+      VLOG(3) << "Vars will be gced after while op";
+      for (auto var : while_block_vars) {
+        VLOG(3) << var;
       }
       task_node->SetWhileBlockVars(while_block_vars);
     }
+  }
+  std::vector<framework::OperatorBase*> sub_block_ops;
+  for (const auto& task_node : sub_block_tasks) {
+    for (const auto& op : task_node->ops()) {
+      sub_block_ops.emplace_back(op);
+    }
+  }
+  // Analyse the unused vars in block 0. The operators in block 1
+  // should be passed in first for prevent vars been released but removed soon.
+  // Since the unused vars in block 1 need to analyse separately.
+  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
+  for (const auto& task_node : task_nodes) {
+    for (const auto& op : task_node->ops()) {
+      ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
+    }
+  }
+  auto global_unused_vars =
+      framework::GetUnusedVars(program_desc.Block(0), ops, {});
+  // Analyse the unused vars in block 1.
+  std::unordered_map<const framework::OperatorBase*, std::vector<std::string>>
+      sub_unused_vars;
+  if (program_desc.Size() > 1) {
+    sub_unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {});
+    PreventVarsDelete(&sub_unused_vars, while_block_vars);
+  }
+  for (auto& unique_op : ops) {
+    unique_op.release();
+  }
+  // NOTE: For inference, the vars in inference_root_scope_vars
+  // shouldn't be deleted during inf, for that they may be the result of the
+  // inf. If they are GCed, it will cause error during ZeroCopy the result.
+  PreventVarsDelete(&global_unused_vars, inference_root_scope_vars);
+  runtime_graph_ = std::make_shared<RuntimeGraph>();
+  std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
+  for (auto task_node : task_nodes) {
+    if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
+      task_node->SetUnusedVars(global_unused_vars);
+    } else {
+      task_node->SetUnusedVars(sub_unused_vars);
+    }
     int64_t interceptor_id = task_node->task_id();
     interceptor_id_to_task.emplace(interceptor_id, task_node);
   }
   runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
   runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
+  for (auto& unique_op : ops) {
+    unique_op.release();
+  }
   VLOG(5) << runtime_graph_->DebugString();
   Carrier* carrier =
       GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
......
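For readers skimming the diff, the heart of the change is the new `PreventVarsDelete` helper: given the per-operator unused-variable lists produced by `framework::GetUnusedVars`, it strips out any variable that must survive garbage collection (inference outputs passed in via `inference_root_scope_vars`, or variables the while sub-block still needs), and drops operators whose list becomes empty. The standalone sketch below mirrors only that filtering logic with plain standard-library types; it is illustrative, not Paddle code — the `OpId` alias, the sample operator names, and the sample variable names are made up, and Paddle keys the real map by `const framework::OperatorBase*` instead of strings.

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative stand-in for an operator identity; the real code uses
// const framework::OperatorBase* as the map key.
using OpId = std::string;
using UnusedVarsMap = std::unordered_map<OpId, std::vector<std::string>>;

// Remove every name in vars_not_gc from each op's unused-vars list, then
// drop ops whose list became empty -- the same shape as PreventVarsDelete.
void PreventVarsDelete(UnusedVarsMap* unused_vars,
                       const std::vector<std::string>& vars_not_gc) {
  std::vector<OpId> changed_ops;
  for (auto& pair : *unused_vars) {
    auto& names = pair.second;
    for (const auto& keep : vars_not_gc) {
      auto iter = std::find(names.begin(), names.end(), keep);
      if (iter != names.end()) {
        names.erase(iter);  // this var must not be GCed after the op runs
        changed_ops.push_back(pair.first);
      }
    }
  }
  for (const auto& op : changed_ops) {
    auto iter = unused_vars->find(op);
    if (iter != unused_vars->end() && iter->second.empty()) {
      unused_vars->erase(iter);  // nothing left for this op to free
    }
  }
}

int main() {
  // Hypothetical sample data: which vars each op would normally free.
  UnusedVarsMap unused_vars = {
      {"matmul_0", {"tmp_0", "fetch_out"}},
      {"while_0", {"loop_var"}},
  };
  // Vars that must outlive GC, e.g. inference outputs or while-block vars.
  PreventVarsDelete(&unused_vars, {"fetch_out", "loop_var"});
  for (const auto& pair : unused_vars) {
    std::cout << pair.first << " still frees " << pair.second.size()
              << " var(s)\n";  // while_0 is removed entirely
  }
}
```

In the commit itself this filter is applied twice: once to the block-1 (`sub_unused_vars`) analysis with the while-block variables, and once to the block-0 (`global_unused_vars`) analysis with `inference_root_scope_vars`, so sub-block tasks and main-block tasks each get a GC list that never frees variables still needed elsewhere.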