提交 8744f9a0 编写于 作者: Q Qiao Longfei

fix parallel executor async mode

上级 e70b1727
......@@ -188,7 +188,7 @@ ParallelExecutor::ParallelExecutor(
const std::string &loss_var_name, Scope *scope,
const std::vector<Scope *> &local_scopes,
const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
std::vector<ir::Graph *> graphs)
ir::Graph *graph)
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;
member_->use_cuda_ = exec_strategy.use_cuda_;
......@@ -218,12 +218,18 @@ ParallelExecutor::ParallelExecutor(
}
}
std::vector<ir::Graph *> graphs;
if (build_strategy.async_mode_) {
PADDLE_ENFORCE(!member_->use_cuda_,
"gpu mode does not support async_mode_ now!");
graphs.push_back(graph);
for (int i = 1; i < places.size(); ++i) {
auto *tmp_graph = new ir::Graph(graph->OriginProgram());
async_graphs_.emplace_back(tmp_graph);
graphs.push_back(tmp_graph);
}
}
ir::Graph *graph = graphs[0];
std::unique_ptr<ir::Graph> temp_owned_graph(graph);
// FIXME(Yancey1989): parallel graph mode get better performance
......
......@@ -50,7 +50,7 @@ class ParallelExecutor {
const std::vector<Scope *> &local_scopes,
const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy,
std::vector<ir::Graph *> graphs);
ir::Graph *graph);
~ParallelExecutor();
......@@ -76,6 +76,7 @@ class ParallelExecutor {
const BuildStrategy &build_strategy) const;
ParallelExecutorPrivate *member_;
std::vector<std::unique_ptr<ir::Graph>> async_graphs_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
std::unique_ptr<ncclUniqueId> local_nccl_id_;
#endif
......
......@@ -1271,7 +1271,7 @@ All parameter, weight, gradient are variables in Paddle.
pe.def(py::init<const std::vector<platform::Place> &,
const std::unordered_set<std::string> &, const std::string &,
Scope *, std::vector<Scope *> &, const ExecutionStrategy &,
const BuildStrategy &, std::vector<ir::Graph *>>())
const BuildStrategy &, ir::Graph *>())
// NOTE: even we return a vec<Scope*>* to Python use reference policy.
// We still cannot get local_scope from this vector, since the element
// of vec<Scope*> will be freed by Python GC. We can only return Scope*
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册