diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 297ee92fc3c84c2feec9cb85bd8671ce8ad94ed0..3e805bd5b480241954960f92a72514723c3a8bb7 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -56,6 +56,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( fetches.resize(fetch_tensors.size()); std::unordered_map> fetched_vars; std::vector fetch_ops; + std::vector ready_fetch_ops; for (auto &fetch_var_name : fetch_tensors) { for (auto &var_map : graph_->Get(details::kGraphVars)) { @@ -70,8 +71,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( auto &var_name = fetch_tensors[i]; auto fetched_var_it = fetched_vars.find(var_name); PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(), - "Cannot find fetched variable.(Perhaps the main_program " - "is not set to ParallelExecutor)"); + "Cannot find fetched variable(%s).(Perhaps the main_program " + "is not set to ParallelExecutor)", + var_name); auto &vars = fetched_var_it->second; @@ -88,7 +90,11 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( op->AddInput(var); } - (*op_deps)[op] = static_cast(op->NotReadyInputSize()); + int dep = static_cast(op->NotReadyInputSize()); + (*op_deps)[op] = dep; + if (dep == 0) { + ready_fetch_ops.emplace_back(op); + } } size_t num_complete = 0; @@ -97,7 +103,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( for (auto op : bootstrap_ops_) { RunOpAsync(op_deps.get(), op, complete_q); } - + for (auto op : ready_fetch_ops) { + RunOpAsync(op_deps.get(), op, complete_q); + } while (num_complete != op_deps->size()) { size_t num_comp = complete_q->Pop(); if (num_comp == -1UL) { diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index 232d82a5da596a78d2999c4a4c4f7dda0c7cad7e..81e200c0dae4484a46afa16e69db68ff746484c6 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -13,9 +13,9 @@ // limitations under the License. #include "paddle/fluid/framework/details/fetch_op_handle.h" - #include #include +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace framework { @@ -44,6 +44,7 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const { } void FetchOpHandle::RunImpl() { + platform::RecordEvent record_event(Name()); WaitInputVarGenerated(platform::CPUPlace()); tensors_.resize(inputs_.size()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 5dde0b76b816bc5309b455d58deb8942300c6af5..67246a4dd448b0ce2f115d6438c5fdd6cc39ca6d 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -80,7 +80,6 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( } set.clear(); }; - auto run_all_op = [&](OpHandleBase *op) { RunOp(ready_vars, op); }; // Clean run context run_op_futures_.clear(); exception_holder_.Clear(); @@ -116,7 +115,7 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( auto &deps = pending_ops[op]; --deps; if (deps == 0) { - run_all_op(op); + ready_ops.insert(op); } } } diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py index bda8b666dcde22b0e4bacdb5db252267f4c7e34b..645b0188d5f45935ace074ba343de246af476b41 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py @@ -38,7 +38,15 @@ def Lenet(data, class_dim): class TestFetchAndFeed(unittest.TestCase): - def parallel_exe(self, use_cuda, run_parallel_exe, seed=1): + @classmethod + def setUpClass(cls): + os.environ['CPU_NUM'] = str(4) + + def parallel_exe(self, + use_cuda, + run_parallel_exe, + use_experimental_executor=False, + seed=1): main_program = fluid.Program() startup = fluid.Program() startup.random_seed = seed @@ -63,8 +71,12 @@ class TestFetchAndFeed(unittest.TestCase): build_strategy = fluid.BuildStrategy() build_strategy.enable_inplace = False build_strategy.memory_optimize = False + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.use_experimental_executor = use_experimental_executor train_cp = compiler.CompiledProgram(main_program).with_data_parallel( - loss_name=loss.name, build_strategy=build_strategy) + loss_name=loss.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) run_parallel_exe(train_cp, exe, use_cuda, data, label, loss) @@ -131,8 +143,7 @@ class TestFetchAndFeed(unittest.TestCase): if batch_id == 2: break - def test_fetch(self): - os.environ['CPU_NUM'] = str(4) + def test_fetch_with_threaded_executor(self): if core.is_compiled_with_cuda(): self.parallel_exe( use_cuda=True, @@ -140,8 +151,18 @@ class TestFetchAndFeed(unittest.TestCase): self.parallel_exe( use_cuda=False, run_parallel_exe=self.run_parallel_exe_with_fetch) + def test_fetch_with_fast_threaded_executor(self): + if core.is_compiled_with_cuda(): + self.parallel_exe( + use_cuda=True, + run_parallel_exe=self.run_parallel_exe_with_fetch, + use_experimental_executor=True) + self.parallel_exe( + use_cuda=False, + run_parallel_exe=self.run_parallel_exe_with_fetch, + use_experimental_executor=True) + def test_feed(self): - os.environ['CPU_NUM'] = str(4) if core.is_compiled_with_cuda(): self.parallel_exe( use_cuda=True, run_parallel_exe=self.run_parallel_exe_with_feed)