未验证 提交 1342e2ea 编写于 作者: C chengduo 提交者: GitHub

Fix the bug of the fast threaded executor (#16514)

* Fix the bug of the fast threaded executor. I
上级 d6582449
......@@ -56,6 +56,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<FetchOpHandle *> fetch_ops;
std::vector<OpHandleBase *> ready_fetch_ops;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
......@@ -70,8 +71,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
auto &var_name = fetch_tensors[i];
auto fetched_var_it = fetched_vars.find(var_name);
PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
"Cannot find fetched variable.(Perhaps the main_program "
"is not set to ParallelExecutor)");
"Cannot find fetched variable(%s).(Perhaps the main_program "
"is not set to ParallelExecutor)",
var_name);
auto &vars = fetched_var_it->second;
......@@ -88,7 +90,11 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
op->AddInput(var);
}
(*op_deps)[op] = static_cast<int>(op->NotReadyInputSize());
int dep = static_cast<int>(op->NotReadyInputSize());
(*op_deps)[op] = dep;
if (dep == 0) {
ready_fetch_ops.emplace_back(op);
}
}
size_t num_complete = 0;
......@@ -97,7 +103,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
for (auto op : bootstrap_ops_) {
RunOpAsync(op_deps.get(), op, complete_q);
}
for (auto op : ready_fetch_ops) {
RunOpAsync(op_deps.get(), op, complete_q);
}
while (num_complete != op_deps->size()) {
size_t num_comp = complete_q->Pop();
if (num_comp == -1UL) {
......
......@@ -13,9 +13,9 @@
// limitations under the License.
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
......@@ -44,6 +44,7 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const {
}
void FetchOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated(platform::CPUPlace());
tensors_.resize(inputs_.size());
......
......@@ -80,7 +80,6 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
}
set.clear();
};
auto run_all_op = [&](OpHandleBase *op) { RunOp(ready_vars, op); };
// Clean run context
run_op_futures_.clear();
exception_holder_.Clear();
......@@ -116,7 +115,7 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
run_all_op(op);
ready_ops.insert(op);
}
}
}
......
......@@ -38,7 +38,15 @@ def Lenet(data, class_dim):
class TestFetchAndFeed(unittest.TestCase):
def parallel_exe(self, use_cuda, run_parallel_exe, seed=1):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
def parallel_exe(self,
use_cuda,
run_parallel_exe,
use_experimental_executor=False,
seed=1):
main_program = fluid.Program()
startup = fluid.Program()
startup.random_seed = seed
......@@ -63,8 +71,12 @@ class TestFetchAndFeed(unittest.TestCase):
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = use_experimental_executor
train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
run_parallel_exe(train_cp, exe, use_cuda, data, label, loss)
......@@ -131,8 +143,7 @@ class TestFetchAndFeed(unittest.TestCase):
if batch_id == 2:
break
def test_fetch(self):
os.environ['CPU_NUM'] = str(4)
def test_fetch_with_threaded_executor(self):
if core.is_compiled_with_cuda():
self.parallel_exe(
use_cuda=True,
......@@ -140,8 +151,18 @@ class TestFetchAndFeed(unittest.TestCase):
self.parallel_exe(
use_cuda=False, run_parallel_exe=self.run_parallel_exe_with_fetch)
def test_fetch_with_fast_threaded_executor(self):
if core.is_compiled_with_cuda():
self.parallel_exe(
use_cuda=True,
run_parallel_exe=self.run_parallel_exe_with_fetch,
use_experimental_executor=True)
self.parallel_exe(
use_cuda=False,
run_parallel_exe=self.run_parallel_exe_with_fetch,
use_experimental_executor=True)
def test_feed(self):
os.environ['CPU_NUM'] = str(4)
if core.is_compiled_with_cuda():
self.parallel_exe(
use_cuda=True, run_parallel_exe=self.run_parallel_exe_with_feed)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册