diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
index 9aedaa131400f3bfd6be24953050071e8970a557..a03ac900e9f66ea1d15aee318a01cf8b87d072d5 100644
--- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
+++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
@@ -50,14 +50,17 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
   auto max_ready_size = it->second.first;
   auto ready_size = it->second.second;
   ready_size += 1;
-  PADDLE_ENFORCE_LE(ready_size,
-                    max_ready_size,
-                    platform::errors::OutOfRange(
-                        "upstream=%lld ready_size must <= max_ready_size, but "
-                        "now ready_size=%lld, max_ready_size=%lld",
-                        up_id,
-                        ready_size,
-                        max_ready_size));
+  if (max_ready_size != INFINITE_BUFFER_SIZE) {
+    PADDLE_ENFORCE_LE(
+        ready_size,
+        max_ready_size,
+        platform::errors::OutOfRange(
+            "upstream=%lld ready_size must <= max_ready_size, but "
+            "now ready_size=%lld, max_ready_size=%lld",
+            up_id,
+            ready_size,
+            max_ready_size));
+  }
   it->second.second = ready_size;
 }
 
@@ -96,6 +99,9 @@ bool ComputeInterceptor::CanWriteOutput() {
   for (auto& outs : out_buffs_) {
     auto max_buffer_size = outs.second.first;
     auto used_size = outs.second.second;
+    if (max_buffer_size == INFINITE_BUFFER_SIZE) {
+      continue;
+    }
     // full, return false
     if (used_size == max_buffer_size) {
       VLOG(3) << "Interceptor " << GetInterceptorId()
@@ -112,15 +118,17 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
     auto max_buff_size = outs.second.first;
     auto used_size = outs.second.second;
     used_size += 1;
-    PADDLE_ENFORCE_LE(
-        used_size,
-        max_buff_size,
-        platform::errors::OutOfRange("downstream=%lld used buff size must <= "
-                                     "max_buff_size, but now used_size=%lld, "
-                                     "max_buff_size=%lld",
-                                     down_id,
-                                     used_size,
-                                     max_buff_size));
+    if (max_buff_size != INFINITE_BUFFER_SIZE) {
+      PADDLE_ENFORCE_LE(
+          used_size,
+          max_buff_size,
+          platform::errors::OutOfRange("downstream=%lld used buff size must <= "
+                                       "max_buff_size, but now used_size=%lld, "
+                                       "max_buff_size=%lld",
+                                       down_id,
+                                       used_size,
+                                       max_buff_size));
+    }
     outs.second.second = used_size;
 
     InterceptorMessage ready_msg;
diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h
index 9709cd4437f1019fea80cf04ecce5a38f74bb463..eade47fd8787e61999a8c627af316d221e5aba48 100644
--- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h
+++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h
@@ -22,6 +22,8 @@
 namespace paddle {
 namespace distributed {
 
+const int64_t INFINITE_BUFFER_SIZE = -1;
+
 class ComputeInterceptor : public Interceptor {
  public:
   ComputeInterceptor(int64_t interceptor_id, TaskNode* node);
diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
index 88363696ede257492b6f703c2a8ddaa97d5b5b15..ae3776d2c5beacbccc7d63f05aff7882a9b2440a 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
@@ -111,21 +111,22 @@ void FleetExecutor::Init(
     task_node->SetUnusedVars(unused_vars);
     if (task_node->type() == "Cond") {
       std::vector<std::string> while_block_vars;
-      std::vector<std::string> vars_in_parent;
-      std::vector<std::string> vars_in_sub;
-      for (auto& var : program_desc.Block(0).AllVars()) {
-        vars_in_parent.emplace_back(var->Name());
-      }
+      VLOG(3) << "Vars in while sub block:";
       for (auto& var : program_desc.Block(1).AllVars()) {
-        vars_in_sub.emplace_back(var->Name());
+        VLOG(3) << var->Name();
+        while_block_vars.emplace_back(var->Name());
+      }
+      for (const auto& pair : unused_vars) {
+        if (pair.first->Type() == "while") {
+          for (const auto& var_name : pair.second) {
+            while_block_vars.emplace_back(var_name);
+          }
+        }
+      }
+      VLOG(3) << "Vars below will be removed after while:";
+      for (const auto& name : while_block_vars) {
+        VLOG(3) << name;
       }
-      std::sort(vars_in_parent.begin(), vars_in_parent.end());
-      std::sort(vars_in_sub.begin(), vars_in_sub.end());
-      std::set_difference(vars_in_sub.begin(),
-                          vars_in_sub.end(),
-                          vars_in_parent.begin(),
-                          vars_in_parent.end(),
-                          std::back_inserter(while_block_vars));
       task_node->SetWhileBlockVars(while_block_vars);
     }
     int64_t interceptor_id = task_node->task_id();
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index da9d12802434f39326a7f276f320caecb9a05c86..6e094588e686a54d6b6ba6d4137035c2d2cb8b91 100755
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -2534,8 +2534,9 @@ class Executor:
 
         place = core.Place()
         place.set_place(self.place)
-        # NOTE: the last argument is used to force create some vars in root scope,
-        # won't be used during train.
+        inference_root_scope_vars = (
+            fleet_opt["fetch_var"] if "fetch_var" in fleet_opt else []
+        )
         self._fleet_executor.init(
             carrier_id,
             program.desc,
@@ -2544,7 +2545,7 @@ class Executor:
             num_micro_batches,
             tasks,
             task_id_to_rank,
-            [],
+            inference_root_scope_vars,
             micro_scope_list,
         )
 
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py b/python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py
index 1ca8c869a96bdfbc2847df0aa81101a28a9e3042..f6418cdee2ccecda52f94b74c388233b3b8d8032 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py
@@ -165,19 +165,24 @@ class TestFleetExecutor(unittest.TestCase):
             lazy_initialize=True,
         )
 
+        infinite_buff_size = -1
         task_a.add_downstream_task(task_b.task_id(), 2)
         task_b.add_upstream_task(task_a.task_id(), 2)
-        task_b.add_downstream_task(task_c.task_id(), 100)
-        task_c.add_upstream_task(task_b.task_id(), 100)
+        task_b.add_downstream_task(task_c.task_id(), infinite_buff_size)
+        task_c.add_upstream_task(task_b.task_id(), infinite_buff_size)
         task_c.add_downstream_task(task_d.task_id(), 2)
         task_d.add_upstream_task(task_c.task_id(), 2)
-        task_d.add_downstream_task(task_b.task_id(), 100, core.DependType.LOOP)
-        task_b.add_upstream_task(task_d.task_id(), 100, core.DependType.LOOP)
+        task_d.add_downstream_task(
+            task_b.task_id(), infinite_buff_size, core.DependType.LOOP
+        )
+        task_b.add_upstream_task(
+            task_d.task_id(), infinite_buff_size, core.DependType.LOOP
+        )
         task_b.add_downstream_task(
-            task_e.task_id(), 100, core.DependType.STOP_LOOP
+            task_e.task_id(), infinite_buff_size, core.DependType.STOP_LOOP
         )
         task_e.add_upstream_task(
-            task_b.task_id(), 100, core.DependType.STOP_LOOP
+            task_b.task_id(), infinite_buff_size, core.DependType.STOP_LOOP
        )
 
         main_program._pipeline_opt = {
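Reviewer note (illustration, not part of the patch): both knobs introduced above are driven from Python as plain values. A buffer size of -1 on a task edge is forwarded to the C++ side as INFINITE_BUFFER_SIZE and skips the buffer-full checks, and a "fetch_var" entry in fleet_opt is forwarded to fleet_executor.init as inference_root_scope_vars. Below is a minimal sketch, assuming the fleet_opt dictionary shape used by Executor and the add_downstream_task/add_upstream_task wiring shown in the test; the variable name "prediction" is hypothetical.

    # Unbounded edge between two task nodes (mirrors the cond-interceptor test):
    #   task_b.add_downstream_task(task_c.task_id(), -1)  # -1 == INFINITE_BUFFER_SIZE
    #   task_c.add_upstream_task(task_b.task_id(), -1)

    # Fetch targets to force-create in the root scope for inference:
    fleet_opt = {
        "fetch_var": ["prediction"],  # hypothetical fetch variable name
        # ... other fleet executor options elided
    }
    inference_root_scope_vars = (
        fleet_opt["fetch_var"] if "fetch_var" in fleet_opt else []
    )
    print(inference_root_scope_vars)  # ['prediction']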