未验证 提交 9268f392 编写于 作者: L LiYuRio 提交者: GitHub

Optimize gc in executor (#50301)

上级 80dc81c5
......@@ -14,6 +14,7 @@
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include <algorithm>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/global.h"
......@@ -53,40 +54,40 @@ FleetExecutor::~FleetExecutor() {
}
}
void FleetExecutor::Init(
const std::string& carrier_id,
const framework::ProgramDesc& program_desc,
framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
const std::vector<std::string>& inference_root_scope_vars,
const std::vector<framework::Scope*>& micro_scope_list) {
PADDLE_ENFORCE_GT(task_nodes.size(),
0,
platform::errors::InvalidArgument(
"Fleet executor is inited with empty task node"));
// TODO(fleet_exe devs): the unused_vars should be got from run time graph
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (const auto& desc : program_desc.Block(0).AllOps()) {
ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
namespace {
// Collects into `*sub_block_task` every task from `tasks` that can be reached
// from `cur_task` by following downstream edges whose dependency type is
// DependType::NORMAL. Edges of any other dependency type are not followed.
//
// tasks:          candidate task nodes; a downstream id is resolved to a node
//                 by comparing it with each candidate's task_id().
// cur_task:       node whose downstream edges are walked. It is only added to
//                 the output if some NORMAL edge leads back to it.
// sub_block_task: output set; this function only ever inserts into it.
//
// NOTE(review): id_to_dep_type.at() presumably throws when a downstream id has
// no recorded dependency type - confirm against TaskNode.
void GetSubBlockTask(const std::vector<TaskNode*>& tasks,
                     TaskNode* cur_task,
                     std::set<TaskNode*>* sub_block_task) {
  const auto& downstream = cur_task->downstream();
  const auto& id_to_dep_type = cur_task->id_to_dep_type();
  for (const auto& down : downstream) {
    const int64_t task_id = down.first;
    if (id_to_dep_type.at(task_id) != DependType::NORMAL) {
      continue;
    }
    for (TaskNode* task : tasks) {
      if (task->task_id() != task_id) {
        continue;
      }
      // emplace().second is false when the task was already recorded. Its
      // downstream edges have then been (or are being) walked, so recursing
      // again would only repeat work and would never terminate if NORMAL
      // edges form a cycle.
      if (sub_block_task->emplace(task).second) {
        GetSubBlockTask(tasks, task, sub_block_task);
      }
    }
  }
}
// NOTE: For inference, the vars in inference_root_scope_vars must not be
// deleted while inference is running, because they may hold the inference
// results. If they are GCed, ZeroCopy-ing the results will fail with an error.
void PreventVarsDelete(
std::unordered_map<const framework::OperatorBase*,
std::vector<std::string>>* unused_vars,
const std::vector<std::string>& vars_not_gc) {
std::vector<const framework::OperatorBase*> changed_ops;
for (auto pair : unused_vars) {
for (const auto& pair : *unused_vars) {
const framework::OperatorBase* op = pair.first;
std::vector<std::string> unused = pair.second;
for (auto name : inference_root_scope_vars) {
auto iter = std::find(unused.begin(), unused.end(), name);
if (iter != unused.end()) {
std::vector<std::string> cur_unused = pair.second;
for (auto name : vars_not_gc) {
auto iter = std::find(cur_unused.begin(), cur_unused.end(), name);
if (iter != cur_unused.end()) {
VLOG(3) << "Removing var: [" << name
<< "] from the unused vars list of op: [" << op->Type() << "]";
unused.erase(iter);
cur_unused.erase(iter);
if (std::find(changed_ops.begin(), changed_ops.end(), op) ==
changed_ops.end()) {
// record the op whose unused vars have been updated
......@@ -95,48 +96,118 @@ void FleetExecutor::Init(
}
}
// update the unused vars list in the map
unused_vars[op] = unused;
unused_vars->at(op) = cur_unused;
}
for (auto op : changed_ops) {
auto iter = unused_vars.find(op);
const auto& iter = unused_vars->find(op);
if (iter->second.empty()) {
// remove those ops in the map that have empty unused vars list
VLOG(3) << "Removing op: [" << op->Type() << "] from unused_vars map.";
unused_vars.erase(iter);
unused_vars->erase(iter);
}
}
runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
for (auto task_node : task_nodes) {
task_node->SetUnusedVars(unused_vars);
if (task_node->type() == "Cond") {
std::vector<std::string> while_block_vars;
VLOG(3) << "Vars in while sub block:";
for (auto& var : program_desc.Block(1).AllVars()) {
VLOG(3) << var->Name();
while_block_vars.emplace_back(var->Name());
}
for (const auto& pair : unused_vars) {
if (pair.first->Type() == "while") {
for (const auto& var_name : pair.second) {
while_block_vars.emplace_back(var_name);
}
}
}
// Returns the variable names that framework::GetUnusedVars attaches to the
// ops of block 0 whose Type() is "while", after the names listed in
// `vars_not_gc` have been filtered out by PreventVarsDelete.
//
// program_desc: program whose block 0 is analysed; every op desc of block 0
//               is instantiated through framework::OpRegistry::CreateOp.
// cond_task:    not read by this function; kept so existing callers compile.
// vars_not_gc:  names that must never appear in the result.
//
// NOTE(review): presumably the returned names are the variables that can be
// garbage-collected once the while op has finished - confirm against
// framework::GetUnusedVars.
std::vector<std::string> GetUnusedVarsAfterWhile(
    const framework::ProgramDesc& program_desc,
    TaskNode* cond_task,
    const std::vector<std::string>& vars_not_gc) {
  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
  for (const auto& desc : program_desc.Block(0).AllOps()) {
    ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
  }
  auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
  PreventVarsDelete(&unused_vars, vars_not_gc);

  std::vector<std::string> while_block_vars;
  for (const auto& pair : unused_vars) {
    if (pair.first->Type() != "while") {
      continue;
    }
    while_block_vars.insert(
        while_block_vars.end(), pair.second.begin(), pair.second.end());
    VLOG(3) << "Vars below will be removed after while:";
    for (const auto& name : while_block_vars) {
      VLOG(3) << name;
    }
  }
  // The return has to sit after the loop: every map entry must be examined,
  // and an empty map must still produce a (empty) result.
  return while_block_vars;
}
} // namespace
void FleetExecutor::Init(
const std::string& carrier_id,
const framework::ProgramDesc& program_desc,
framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
const std::vector<std::string>& inference_root_scope_vars,
const std::vector<framework::Scope*>& micro_scope_list) {
PADDLE_ENFORCE_GT(task_nodes.size(),
0,
platform::errors::InvalidArgument(
"Fleet executor is inited with empty task node"));
// Set the unused var after running while op
std::set<TaskNode*> sub_block_tasks;
std::vector<std::string> while_block_vars;
for (const auto& task_node : task_nodes) {
if (task_node->type() == "Cond") {
GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
while_block_vars = GetUnusedVarsAfterWhile(
program_desc, task_node, inference_root_scope_vars);
VLOG(3) << "Vars will be gced after while op";
for (auto var : while_block_vars) {
VLOG(3) << var;
}
task_node->SetWhileBlockVars(while_block_vars);
}
}
std::vector<framework::OperatorBase*> sub_block_ops;
for (const auto& task_node : sub_block_tasks) {
for (const auto& op : task_node->ops()) {
sub_block_ops.emplace_back(op);
}
}
// Analyse the unused vars in block 0. The operators in block 1
// should be passed in first for prevent vars been released but removed soon.
// Since the unused vars in block 1 need to analyse separately.
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (const auto& task_node : task_nodes) {
for (const auto& op : task_node->ops()) {
ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
}
}
auto global_unused_vars =
framework::GetUnusedVars(program_desc.Block(0), ops, {});
// Analyse the unused vars in block 1.
std::unordered_map<const framework::OperatorBase*, std::vector<std::string>>
sub_unused_vars;
if (program_desc.Size() > 1) {
sub_unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {});
PreventVarsDelete(&sub_unused_vars, while_block_vars);
}
for (auto& unique_op : ops) {
unique_op.release();
}
// NOTE: For inference, the vars in inference_root_scope_vars must not be
// deleted while inference is running, because they may hold the inference
// results. If they are GCed, ZeroCopy-ing the results will fail with an error.
PreventVarsDelete(&global_unused_vars, inference_root_scope_vars);
runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
for (auto task_node : task_nodes) {
if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
task_node->SetUnusedVars(global_unused_vars);
} else {
task_node->SetUnusedVars(sub_unused_vars);
}
int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node);
}
runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
for (auto& unique_op : ops) {
unique_op.release();
}
VLOG(5) << runtime_graph_->DebugString();
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册