未验证 提交 3e9d8548 编写于 作者: L LiYuRio 提交者: GitHub

Fix garbage collection of while-block variables and support an infinite buffer size (#50122)

上级 9f231147
...@@ -50,7 +50,9 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) { ...@@ -50,7 +50,9 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
auto max_ready_size = it->second.first; auto max_ready_size = it->second.first;
auto ready_size = it->second.second; auto ready_size = it->second.second;
ready_size += 1; ready_size += 1;
PADDLE_ENFORCE_LE(ready_size, if (max_ready_size != INFINITE_BUFFER_SIZE) {
PADDLE_ENFORCE_LE(
ready_size,
max_ready_size, max_ready_size,
platform::errors::OutOfRange( platform::errors::OutOfRange(
"upstream=%lld ready_size must <= max_ready_size, but " "upstream=%lld ready_size must <= max_ready_size, but "
...@@ -58,6 +60,7 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) { ...@@ -58,6 +60,7 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
up_id, up_id,
ready_size, ready_size,
max_ready_size)); max_ready_size));
}
it->second.second = ready_size; it->second.second = ready_size;
} }
...@@ -96,6 +99,9 @@ bool ComputeInterceptor::CanWriteOutput() { ...@@ -96,6 +99,9 @@ bool ComputeInterceptor::CanWriteOutput() {
for (auto& outs : out_buffs_) { for (auto& outs : out_buffs_) {
auto max_buffer_size = outs.second.first; auto max_buffer_size = outs.second.first;
auto used_size = outs.second.second; auto used_size = outs.second.second;
if (max_buffer_size == INFINITE_BUFFER_SIZE) {
continue;
}
// full, return false // full, return false
if (used_size == max_buffer_size) { if (used_size == max_buffer_size) {
VLOG(3) << "Interceptor " << GetInterceptorId() VLOG(3) << "Interceptor " << GetInterceptorId()
...@@ -112,6 +118,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() { ...@@ -112,6 +118,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
auto max_buff_size = outs.second.first; auto max_buff_size = outs.second.first;
auto used_size = outs.second.second; auto used_size = outs.second.second;
used_size += 1; used_size += 1;
if (max_buff_size != INFINITE_BUFFER_SIZE) {
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
used_size, used_size,
max_buff_size, max_buff_size,
...@@ -121,6 +128,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() { ...@@ -121,6 +128,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
down_id, down_id,
used_size, used_size,
max_buff_size)); max_buff_size));
}
outs.second.second = used_size; outs.second.second = used_size;
InterceptorMessage ready_msg; InterceptorMessage ready_msg;
......
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
const int64_t INFINITE_BUFFER_SIZE = -1;
class ComputeInterceptor : public Interceptor { class ComputeInterceptor : public Interceptor {
public: public:
ComputeInterceptor(int64_t interceptor_id, TaskNode* node); ComputeInterceptor(int64_t interceptor_id, TaskNode* node);
......
...@@ -111,21 +111,22 @@ void FleetExecutor::Init( ...@@ -111,21 +111,22 @@ void FleetExecutor::Init(
task_node->SetUnusedVars(unused_vars); task_node->SetUnusedVars(unused_vars);
if (task_node->type() == "Cond") { if (task_node->type() == "Cond") {
std::vector<std::string> while_block_vars; std::vector<std::string> while_block_vars;
std::vector<std::string> vars_in_parent; VLOG(3) << "Vars in while sub block:";
std::vector<std::string> vars_in_sub;
for (auto& var : program_desc.Block(0).AllVars()) {
vars_in_parent.emplace_back(var->Name());
}
for (auto& var : program_desc.Block(1).AllVars()) { for (auto& var : program_desc.Block(1).AllVars()) {
vars_in_sub.emplace_back(var->Name()); VLOG(3) << var->Name();
} while_block_vars.emplace_back(var->Name());
std::sort(vars_in_parent.begin(), vars_in_parent.end()); }
std::sort(vars_in_sub.begin(), vars_in_sub.end()); for (const auto& pair : unused_vars) {
std::set_difference(vars_in_sub.begin(), if (pair.first->Type() == "while") {
vars_in_sub.end(), for (const auto& var_name : pair.second) {
vars_in_parent.begin(), while_block_vars.emplace_back(var_name);
vars_in_parent.end(), }
std::back_inserter(while_block_vars)); }
}
VLOG(3) << "Vars below will be removed after while:";
for (const auto& name : while_block_vars) {
VLOG(3) << name;
}
task_node->SetWhileBlockVars(while_block_vars); task_node->SetWhileBlockVars(while_block_vars);
} }
int64_t interceptor_id = task_node->task_id(); int64_t interceptor_id = task_node->task_id();
......
...@@ -2534,8 +2534,9 @@ class Executor: ...@@ -2534,8 +2534,9 @@ class Executor:
place = core.Place() place = core.Place()
place.set_place(self.place) place.set_place(self.place)
# NOTE: the last argument is used to force create some vars in root scope, inference_root_scope_vars = (
# won't be used during train. fleet_opt["fetch_var"] if "fetch_var" in fleet_opt else []
)
self._fleet_executor.init( self._fleet_executor.init(
carrier_id, carrier_id,
program.desc, program.desc,
...@@ -2544,7 +2545,7 @@ class Executor: ...@@ -2544,7 +2545,7 @@ class Executor:
num_micro_batches, num_micro_batches,
tasks, tasks,
task_id_to_rank, task_id_to_rank,
[], inference_root_scope_vars,
micro_scope_list, micro_scope_list,
) )
......
...@@ -165,19 +165,24 @@ class TestFleetExecutor(unittest.TestCase): ...@@ -165,19 +165,24 @@ class TestFleetExecutor(unittest.TestCase):
lazy_initialize=True, lazy_initialize=True,
) )
infinite_buff_size = -1
task_a.add_downstream_task(task_b.task_id(), 2) task_a.add_downstream_task(task_b.task_id(), 2)
task_b.add_upstream_task(task_a.task_id(), 2) task_b.add_upstream_task(task_a.task_id(), 2)
task_b.add_downstream_task(task_c.task_id(), 100) task_b.add_downstream_task(task_c.task_id(), infinite_buff_size)
task_c.add_upstream_task(task_b.task_id(), 100) task_c.add_upstream_task(task_b.task_id(), infinite_buff_size)
task_c.add_downstream_task(task_d.task_id(), 2) task_c.add_downstream_task(task_d.task_id(), 2)
task_d.add_upstream_task(task_c.task_id(), 2) task_d.add_upstream_task(task_c.task_id(), 2)
task_d.add_downstream_task(task_b.task_id(), 100, core.DependType.LOOP) task_d.add_downstream_task(
task_b.add_upstream_task(task_d.task_id(), 100, core.DependType.LOOP) task_b.task_id(), infinite_buff_size, core.DependType.LOOP
)
task_b.add_upstream_task(
task_d.task_id(), infinite_buff_size, core.DependType.LOOP
)
task_b.add_downstream_task( task_b.add_downstream_task(
task_e.task_id(), 100, core.DependType.STOP_LOOP task_e.task_id(), infinite_buff_size, core.DependType.STOP_LOOP
) )
task_e.add_upstream_task( task_e.add_upstream_task(
task_b.task_id(), 100, core.DependType.STOP_LOOP task_b.task_id(), infinite_buff_size, core.DependType.STOP_LOOP
) )
main_program._pipeline_opt = { main_program._pipeline_opt = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册