未验证 提交 5d5cb256 编写于 作者: L LiYuRio 提交者: GitHub

fix gc bug and start interceptor (#50344)

上级 5cae5fdd
...@@ -111,7 +111,15 @@ void PreventVarsDelete( ...@@ -111,7 +111,15 @@ void PreventVarsDelete(
std::vector<std::string> GetUnusedVarsAfterWhile( std::vector<std::string> GetUnusedVarsAfterWhile(
const framework::ProgramDesc& program_desc, const framework::ProgramDesc& program_desc,
TaskNode* cond_task, TaskNode* cond_task,
const std::vector<std::string> vars_not_gc) { const std::vector<std::string>& vars_not_gc) {
// NOTE: Since while op won't appear in task node, in order to analyze
// the vars which should be free after calling while op, we rebuild the
// whole program and get the unused vars after calling while op.
// The vars in while block should not be free until the while op is finished.
// In a word, the vars need to be free after while op is:
// 1. Vars in parent block and being used in while block.
// 2. Local vars only defined in while block.
// The unused vars above will be free in cond interceptor.
std::vector<std::string> while_block_vars; std::vector<std::string> while_block_vars;
std::vector<std::unique_ptr<framework::OperatorBase>> ops; std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (const auto& desc : program_desc.Block(0).AllOps()) { for (const auto& desc : program_desc.Block(0).AllOps()) {
...@@ -124,6 +132,9 @@ std::vector<std::string> GetUnusedVarsAfterWhile( ...@@ -124,6 +132,9 @@ std::vector<std::string> GetUnusedVarsAfterWhile(
for (const auto& var_name : pair.second) { for (const auto& var_name : pair.second) {
while_block_vars.emplace_back(var_name); while_block_vars.emplace_back(var_name);
} }
for (auto& var : program_desc.Block(1).AllVars()) {
while_block_vars.emplace_back(var->Name());
}
} }
} }
return while_block_vars; return while_block_vars;
...@@ -178,13 +189,6 @@ void FleetExecutor::Init( ...@@ -178,13 +189,6 @@ void FleetExecutor::Init(
auto global_unused_vars = auto global_unused_vars =
framework::GetUnusedVars(program_desc.Block(0), ops, {}); framework::GetUnusedVars(program_desc.Block(0), ops, {});
// Analyse the unused vars in block 1.
std::unordered_map<const framework::OperatorBase*, std::vector<std::string>>
sub_unused_vars;
if (program_desc.Size() > 1) {
sub_unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {});
PreventVarsDelete(&sub_unused_vars, while_block_vars);
}
for (auto& unique_op : ops) { for (auto& unique_op : ops) {
unique_op.release(); unique_op.release();
} }
...@@ -199,8 +203,6 @@ void FleetExecutor::Init( ...@@ -199,8 +203,6 @@ void FleetExecutor::Init(
for (auto task_node : task_nodes) { for (auto task_node : task_nodes) {
if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) { if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
task_node->SetUnusedVars(global_unused_vars); task_node->SetUnusedVars(global_unused_vars);
} else {
task_node->SetUnusedVars(sub_unused_vars);
} }
int64_t interceptor_id = task_node->task_id(); int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node); interceptor_id_to_task.emplace(interceptor_id, task_node);
......
...@@ -68,14 +68,15 @@ void StartInterceptor::SendDataReadyToDownStream() { ...@@ -68,14 +68,15 @@ void StartInterceptor::SendDataReadyToDownStream() {
} }
if (finish_count_ == batch_size_) { if (finish_count_ == batch_size_) {
for (int64_t i = 0; i < batch_size_; ++i) { for (int64_t i = 0; i < batch_size_; ++i) {
int64_t scope_id = step_ % node_->max_run_times();
for (auto& outs : out_buffs_) { for (auto& outs : out_buffs_) {
auto down_id = outs.first; auto down_id = outs.first;
InterceptorMessage ready_msg; InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY); ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(step_); ready_msg.set_scope_idx(scope_id);
VLOG(3) << "StartInterceptor " << interceptor_id_ VLOG(3) << "StartInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id << " Send data_is_ready msg to " << down_id
<< " in scope: " << step_; << " in scope: " << scope_id;
Send(down_id, ready_msg); Send(down_id, ready_msg);
} }
step_++; step_++;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册