未验证 提交 1342e2ea 编写于 作者: C chengduo 提交者: GitHub

Fix the bug of the fast threaded executor (#16514)

* Fix the bug of the fast threaded executor. I
上级 d6582449
...@@ -56,6 +56,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -56,6 +56,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
fetches.resize(fetch_tensors.size()); fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<FetchOpHandle *> fetch_ops; std::vector<FetchOpHandle *> fetch_ops;
std::vector<OpHandleBase *> ready_fetch_ops;
for (auto &fetch_var_name : fetch_tensors) { for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
...@@ -70,8 +71,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -70,8 +71,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
auto &var_name = fetch_tensors[i]; auto &var_name = fetch_tensors[i];
auto fetched_var_it = fetched_vars.find(var_name); auto fetched_var_it = fetched_vars.find(var_name);
PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(), PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
"Cannot find fetched variable.(Perhaps the main_program " "Cannot find fetched variable(%s).(Perhaps the main_program "
"is not set to ParallelExecutor)"); "is not set to ParallelExecutor)",
var_name);
auto &vars = fetched_var_it->second; auto &vars = fetched_var_it->second;
...@@ -88,7 +90,11 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -88,7 +90,11 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
op->AddInput(var); op->AddInput(var);
} }
(*op_deps)[op] = static_cast<int>(op->NotReadyInputSize()); int dep = static_cast<int>(op->NotReadyInputSize());
(*op_deps)[op] = dep;
if (dep == 0) {
ready_fetch_ops.emplace_back(op);
}
} }
size_t num_complete = 0; size_t num_complete = 0;
...@@ -97,7 +103,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -97,7 +103,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
for (auto op : bootstrap_ops_) { for (auto op : bootstrap_ops_) {
RunOpAsync(op_deps.get(), op, complete_q); RunOpAsync(op_deps.get(), op, complete_q);
} }
for (auto op : ready_fetch_ops) {
RunOpAsync(op_deps.get(), op, complete_q);
}
while (num_complete != op_deps->size()) { while (num_complete != op_deps->size()) {
size_t num_comp = complete_q->Pop(); size_t num_comp = complete_q->Pop();
if (num_comp == -1UL) { if (num_comp == -1UL) {
......
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/fetch_op_handle.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -44,6 +44,7 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const { ...@@ -44,6 +44,7 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const {
} }
void FetchOpHandle::RunImpl() { void FetchOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated(platform::CPUPlace()); WaitInputVarGenerated(platform::CPUPlace());
tensors_.resize(inputs_.size()); tensors_.resize(inputs_.size());
......
...@@ -80,7 +80,6 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( ...@@ -80,7 +80,6 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
} }
set.clear(); set.clear();
}; };
auto run_all_op = [&](OpHandleBase *op) { RunOp(ready_vars, op); };
// Clean run context // Clean run context
run_op_futures_.clear(); run_op_futures_.clear();
exception_holder_.Clear(); exception_holder_.Clear();
...@@ -116,7 +115,7 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( ...@@ -116,7 +115,7 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
auto &deps = pending_ops[op]; auto &deps = pending_ops[op];
--deps; --deps;
if (deps == 0) { if (deps == 0) {
run_all_op(op); ready_ops.insert(op);
} }
} }
} }
......
...@@ -38,7 +38,15 @@ def Lenet(data, class_dim): ...@@ -38,7 +38,15 @@ def Lenet(data, class_dim):
class TestFetchAndFeed(unittest.TestCase): class TestFetchAndFeed(unittest.TestCase):
def parallel_exe(self, use_cuda, run_parallel_exe, seed=1): @classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
def parallel_exe(self,
use_cuda,
run_parallel_exe,
use_experimental_executor=False,
seed=1):
main_program = fluid.Program() main_program = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
startup.random_seed = seed startup.random_seed = seed
...@@ -63,8 +71,12 @@ class TestFetchAndFeed(unittest.TestCase): ...@@ -63,8 +71,12 @@ class TestFetchAndFeed(unittest.TestCase):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = False build_strategy.enable_inplace = False
build_strategy.memory_optimize = False build_strategy.memory_optimize = False
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = use_experimental_executor
train_cp = compiler.CompiledProgram(main_program).with_data_parallel( train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy) loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
run_parallel_exe(train_cp, exe, use_cuda, data, label, loss) run_parallel_exe(train_cp, exe, use_cuda, data, label, loss)
...@@ -131,8 +143,7 @@ class TestFetchAndFeed(unittest.TestCase): ...@@ -131,8 +143,7 @@ class TestFetchAndFeed(unittest.TestCase):
if batch_id == 2: if batch_id == 2:
break break
def test_fetch(self): def test_fetch_with_threaded_executor(self):
os.environ['CPU_NUM'] = str(4)
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
self.parallel_exe( self.parallel_exe(
use_cuda=True, use_cuda=True,
...@@ -140,8 +151,18 @@ class TestFetchAndFeed(unittest.TestCase): ...@@ -140,8 +151,18 @@ class TestFetchAndFeed(unittest.TestCase):
self.parallel_exe( self.parallel_exe(
use_cuda=False, run_parallel_exe=self.run_parallel_exe_with_fetch) use_cuda=False, run_parallel_exe=self.run_parallel_exe_with_fetch)
def test_fetch_with_fast_threaded_executor(self):
if core.is_compiled_with_cuda():
self.parallel_exe(
use_cuda=True,
run_parallel_exe=self.run_parallel_exe_with_fetch,
use_experimental_executor=True)
self.parallel_exe(
use_cuda=False,
run_parallel_exe=self.run_parallel_exe_with_fetch,
use_experimental_executor=True)
def test_feed(self): def test_feed(self):
os.environ['CPU_NUM'] = str(4)
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
self.parallel_exe( self.parallel_exe(
use_cuda=True, run_parallel_exe=self.run_parallel_exe_with_feed) use_cuda=True, run_parallel_exe=self.run_parallel_exe_with_feed)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册