Unverified · Commit 3e9d8548 authored by LiYuRio, committed by GitHub

fix gc and infinite buffer size (#50122)

Parent 9f231147
@@ -50,14 +50,17 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
   auto max_ready_size = it->second.first;
   auto ready_size = it->second.second;
   ready_size += 1;
-  PADDLE_ENFORCE_LE(ready_size,
-                    max_ready_size,
-                    platform::errors::OutOfRange(
-                        "upstream=%lld ready_size must <= max_ready_size, but "
-                        "now ready_size=%lld, max_ready_size=%lld",
-                        up_id,
-                        ready_size,
-                        max_ready_size));
+  if (max_ready_size != INFINITE_BUFFER_SIZE) {
+    PADDLE_ENFORCE_LE(
+        ready_size,
+        max_ready_size,
+        platform::errors::OutOfRange(
+            "upstream=%lld ready_size must <= max_ready_size, but "
+            "now ready_size=%lld, max_ready_size=%lld",
+            up_id,
+            ready_size,
+            max_ready_size));
+  }
   it->second.second = ready_size;
 }
@@ -96,6 +99,9 @@ bool ComputeInterceptor::CanWriteOutput() {
   for (auto& outs : out_buffs_) {
     auto max_buffer_size = outs.second.first;
     auto used_size = outs.second.second;
+    if (max_buffer_size == INFINITE_BUFFER_SIZE) {
+      continue;
+    }
     // full, return false
     if (used_size == max_buffer_size) {
       VLOG(3) << "Interceptor " << GetInterceptorId()
@@ -112,15 +118,17 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
     auto max_buff_size = outs.second.first;
     auto used_size = outs.second.second;
     used_size += 1;
-    PADDLE_ENFORCE_LE(
-        used_size,
-        max_buff_size,
-        platform::errors::OutOfRange("downstream=%lld used buff size must <= "
-                                     "max_buff_size, but now used_size=%lld, "
-                                     "max_buff_size=%lld",
-                                     down_id,
-                                     used_size,
-                                     max_buff_size));
+    if (max_buff_size != INFINITE_BUFFER_SIZE) {
+      PADDLE_ENFORCE_LE(
+          used_size,
+          max_buff_size,
+          platform::errors::OutOfRange("downstream=%lld used buff size must <= "
+                                       "max_buff_size, but now used_size=%lld, "
+                                       "max_buff_size=%lld",
+                                       down_id,
+                                       used_size,
+                                       max_buff_size));
+    }
     outs.second.second = used_size;

     InterceptorMessage ready_msg;
......
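All three hunks above apply the same rule: when a buffer's configured capacity equals the new INFINITE_BUFFER_SIZE sentinel (-1), the bound simply is not checked. A minimal Python sketch of the pattern follows; the names and data shapes are illustrative stand-ins, not the real interceptor API.

INFINITE_BUFFER_SIZE = -1  # sentinel: no capacity limit on this buffer

def increase_ready(in_buffs, up_id):
    # in_buffs maps upstream id -> (max_ready_size, ready_size),
    # mirroring ComputeInterceptor's bookkeeping in spirit only
    max_ready_size, ready_size = in_buffs[up_id]
    ready_size += 1
    # skip the upper-bound check entirely for unbounded buffers
    if max_ready_size != INFINITE_BUFFER_SIZE:
        assert ready_size <= max_ready_size, (
            f"upstream={up_id} ready_size must <= max_ready_size, but now "
            f"ready_size={ready_size}, max_ready_size={max_ready_size}"
        )
    in_buffs[up_id] = (max_ready_size, ready_size)

A negative sentinel works here because real buffer capacities are always positive, so -1 can never collide with a legitimate size.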
@@ -22,6 +22,8 @@
 namespace paddle {
 namespace distributed {

+const int64_t INFINITE_BUFFER_SIZE = -1;
+
 class ComputeInterceptor : public Interceptor {
  public:
   ComputeInterceptor(int64_t interceptor_id, TaskNode* node);
......
@@ -111,21 +111,22 @@ void FleetExecutor::Init(
     task_node->SetUnusedVars(unused_vars);
     if (task_node->type() == "Cond") {
       std::vector<std::string> while_block_vars;
-      std::vector<std::string> vars_in_parent;
-      std::vector<std::string> vars_in_sub;
-      for (auto& var : program_desc.Block(0).AllVars()) {
-        vars_in_parent.emplace_back(var->Name());
-      }
+      VLOG(3) << "Vars in while sub block:";
       for (auto& var : program_desc.Block(1).AllVars()) {
-        vars_in_sub.emplace_back(var->Name());
+        VLOG(3) << var->Name();
+        while_block_vars.emplace_back(var->Name());
       }
-      std::sort(vars_in_parent.begin(), vars_in_parent.end());
-      std::sort(vars_in_sub.begin(), vars_in_sub.end());
-      std::set_difference(vars_in_sub.begin(),
-                          vars_in_sub.end(),
-                          vars_in_parent.begin(),
-                          vars_in_parent.end(),
-                          std::back_inserter(while_block_vars));
+      for (const auto& pair : unused_vars) {
+        if (pair.first->Type() == "while") {
+          for (const auto& var_name : pair.second) {
+            while_block_vars.emplace_back(var_name);
+          }
+        }
+      }
+      VLOG(3) << "Vars below will be removed after while:";
+      for (const auto& name : while_block_vars) {
+        VLOG(3) << name;
+      }
       task_node->SetWhileBlockVars(while_block_vars);
     }
     int64_t interceptor_id = task_node->task_id();
......
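This hunk is the GC half of the fix. Previously the vars to free after each while iteration were the set difference between the sub block's vars and the parent block's vars, which excluded any sub-block var whose name also appears in the parent block, even when the unused-var analysis knows it is dead after the while op. The new code frees every var in the sub block plus whatever the analysis attributes to the while op itself. A rough Python restatement of the two selection rules, using simple lists and pairs as stand-ins for the C++ containers:

def while_block_vars_old(parent_vars, sub_vars):
    # old rule: free only vars that live in the sub block but not the parent
    return sorted(set(sub_vars) - set(parent_vars))

def while_block_vars_new(sub_vars, unused_vars):
    # new rule: free every sub-block var, plus vars the unused-var pass
    # marks as dead after the while op; unused_vars is a list of
    # (op_type, var_names) pairs standing in for the OpDesc* -> names map
    gc_vars = list(sub_vars)
    for op_type, var_names in unused_vars:
        if op_type == "while":
            gc_vars.extend(var_names)
    return gc_vars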
@@ -2534,8 +2534,9 @@ class Executor:
             place = core.Place()
             place.set_place(self.place)

-            # NOTE: the last argument is used to force create some vars in root scope,
-            # won't be used during train.
+            inference_root_scope_vars = (
+                fleet_opt["fetch_var"] if "fetch_var" in fleet_opt else []
+            )
             self._fleet_executor.init(
                 carrier_id,
                 program.desc,
@@ -2544,7 +2545,7 @@ class Executor:
                 num_micro_batches,
                 tasks,
                 task_id_to_rank,
-                [],
+                inference_root_scope_vars,
                 micro_scope_list,
             )
......
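On the Python side, the executor used to pass an empty list for the vars to force-create in the root scope; it now forwards fleet_opt["fetch_var"] when the caller supplies it, presumably so inference fetch targets exist in the root scope rather than in a garbage-collected micro scope. A small example of the fallback expression; the dict contents here are made up:

fleet_opt = {"fetch_var": ["pred", "loss"]}  # hypothetical config
inference_root_scope_vars = (
    fleet_opt["fetch_var"] if "fetch_var" in fleet_opt else []
)
print(inference_root_scope_vars)  # ['pred', 'loss'], forwarded to init()
# without a "fetch_var" key the old behavior (an empty list) is preserved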
@@ -165,19 +165,24 @@ class TestFleetExecutor(unittest.TestCase):
             lazy_initialize=True,
         )
+        infinite_buff_size = -1
         task_a.add_downstream_task(task_b.task_id(), 2)
         task_b.add_upstream_task(task_a.task_id(), 2)
-        task_b.add_downstream_task(task_c.task_id(), 100)
-        task_c.add_upstream_task(task_b.task_id(), 100)
+        task_b.add_downstream_task(task_c.task_id(), infinite_buff_size)
+        task_c.add_upstream_task(task_b.task_id(), infinite_buff_size)
         task_c.add_downstream_task(task_d.task_id(), 2)
         task_d.add_upstream_task(task_c.task_id(), 2)
-        task_d.add_downstream_task(task_b.task_id(), 100, core.DependType.LOOP)
-        task_b.add_upstream_task(task_d.task_id(), 100, core.DependType.LOOP)
+        task_d.add_downstream_task(
+            task_b.task_id(), infinite_buff_size, core.DependType.LOOP
+        )
+        task_b.add_upstream_task(
+            task_d.task_id(), infinite_buff_size, core.DependType.LOOP
+        )
         task_b.add_downstream_task(
-            task_e.task_id(), 100, core.DependType.STOP_LOOP
+            task_e.task_id(), infinite_buff_size, core.DependType.STOP_LOOP
         )
         task_e.add_upstream_task(
-            task_b.task_id(), 100, core.DependType.STOP_LOOP
+            task_b.task_id(), infinite_buff_size, core.DependType.STOP_LOOP
        )
         main_program._pipeline_opt = {
......
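For orientation, the dependency graph the updated test wires up, with per-edge buffer sizes (-1 meaning unbounded, previously 100):

a --(2)--> b --(-1)--> c --(2)--> d
d --(-1, LOOP)--> b
b --(-1, STOP_LOOP)--> e

The former magic capacity of 100 on the loop and stop-loop edges is replaced by the sentinel, exercising the unbounded-buffer path end to end.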