Commit f768fbf7, authored by Qiao Longfei

support multi graph

test=develop
Parent commit: ff01d705
...@@ -20,12 +20,12 @@ namespace details { ...@@ -20,12 +20,12 @@ namespace details {
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph) const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
: strategy_(std::move(strategy)), : strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)), local_scopes_(std::move(local_scopes)),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr), pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)), places_(std::move(places)),
graph_(graph) { graphs_(std::move(graphs)) {
VLOG(3) << "build AsyncSSAGraphExecutor"; VLOG(3) << "build AsyncSSAGraphExecutor";
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
...@@ -37,7 +37,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( ...@@ -37,7 +37,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
<< " to run the operators of the graph on each device."; << " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) { for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::ThreadedSSAGraphExecutor( executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, {local_scopes_[i]}, {places_[i]}, graph_)); strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
} }
} }
......
...@@ -29,9 +29,9 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor { ...@@ -29,9 +29,9 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
AsyncSSAGraphExecutor(const ExecutionStrategy &strategy, AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
ir::Graph *graph); std::vector<ir::Graph *> graphs);
~AsyncSSAGraphExecutor() final = default; ~AsyncSSAGraphExecutor() final = default;
const ir::Graph &Graph() const override { return *graph_; } const ir::Graph &Graph() const override { return *graphs_[0]; }
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override; FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
...@@ -40,7 +40,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor { ...@@ -40,7 +40,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr}; std::unique_ptr<::ThreadPool> pool_{nullptr};
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
ir::Graph *graph_; std::vector<ir::Graph *> graphs_;
std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_; std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
ExceptionHolder exception_holder_; ExceptionHolder exception_holder_;
......
...@@ -188,7 +188,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -188,7 +188,7 @@ ParallelExecutor::ParallelExecutor(
const std::string &loss_var_name, Scope *scope, const std::string &loss_var_name, Scope *scope,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy, const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
ir::Graph *graph) std::vector<ir::Graph *> graphs)
: member_(new ParallelExecutorPrivate(places)) { : member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope; member_->global_scope_ = scope;
member_->use_cuda_ = exec_strategy.use_cuda_; member_->use_cuda_ = exec_strategy.use_cuda_;
...@@ -222,6 +222,8 @@ ParallelExecutor::ParallelExecutor( ...@@ -222,6 +222,8 @@ ParallelExecutor::ParallelExecutor(
PADDLE_ENFORCE(!member_->use_cuda_, PADDLE_ENFORCE(!member_->use_cuda_,
"gpu mode does not support async_mode_ now!"); "gpu mode does not support async_mode_ now!");
} }
ir::Graph *graph = graphs[0];
std::unique_ptr<ir::Graph> temp_owned_graph(graph); std::unique_ptr<ir::Graph> temp_owned_graph(graph);
// FIXME(Yancey1989): parallel graph mode get better performance // FIXME(Yancey1989): parallel graph mode get better performance
...@@ -262,17 +264,26 @@ ParallelExecutor::ParallelExecutor( ...@@ -262,17 +264,26 @@ ParallelExecutor::ParallelExecutor(
if (member_->local_scopes_.size() != 1 && local_scopes.empty()) { if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
BCastParamsToDevices(bcast_vars); BCastParamsToDevices(bcast_vars);
} }
// Startup Program has been run. All local scopes has correct parameters. // Startup Program has been run. All local scopes has correct parameters.
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
std::vector<ir::Graph *> async_graphs(places.size());
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
VLOG(3) << "use local async mode"; VLOG(3) << "use local async mode";
temp_owned_graph = build_strategy.Apply( temp_owned_graph =
std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name, build_strategy.Apply(std::move(temp_owned_graph), {member_->places_[0]},
{member_->local_scopes_[0]}, member_->nranks_, member_->use_cuda_, loss_var_name, {member_->local_scopes_[0]}, 1,
member_->nccl_ctxs_.get()); member_->use_cuda_, member_->nccl_ctxs_.get());
for (int i = 1; i < member_->places_.size(); ++i) {
std::unique_ptr<ir::Graph> temp_graph(graphs[i]);
temp_graph =
build_strategy.Apply(std::move(temp_graph), {member_->places_[i]},
loss_var_name, {member_->local_scopes_[i]}, 1,
member_->use_cuda_, member_->nccl_ctxs_.get());
async_graphs[i] = temp_graph.release();
}
} else { } else {
temp_owned_graph = build_strategy.Apply( temp_owned_graph = build_strategy.Apply(
std::move(temp_owned_graph), member_->places_, loss_var_name, std::move(temp_owned_graph), member_->places_, loss_var_name,
...@@ -284,7 +295,14 @@ ParallelExecutor::ParallelExecutor( ...@@ -284,7 +295,14 @@ ParallelExecutor::ParallelExecutor(
VLOG(3) << "use local async mode"; VLOG(3) << "use local async mode";
temp_owned_graph = build_strategy.Apply( temp_owned_graph = build_strategy.Apply(
std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name, std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name,
{member_->local_scopes_[0]}, member_->nranks_, member_->use_cuda_); {member_->local_scopes_[0]}, 1, member_->use_cuda_);
for (int i = 1; i < member_->places_.size(); ++i) {
std::unique_ptr<ir::Graph> temp_graph(graphs[i]);
temp_graph = build_strategy.Apply(
std::move(temp_graph), {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, 1, member_->use_cuda_);
async_graphs[i] = temp_graph.release();
}
} else { } else {
temp_owned_graph = build_strategy.Apply( temp_owned_graph = build_strategy.Apply(
std::move(temp_owned_graph), member_->places_, loss_var_name, std::move(temp_owned_graph), member_->places_, loss_var_name,
...@@ -304,6 +322,8 @@ ParallelExecutor::ParallelExecutor( ...@@ -304,6 +322,8 @@ ParallelExecutor::ParallelExecutor(
graph = temp_owned_graph.release(); graph = temp_owned_graph.release();
} }
async_graphs[0] = graph;
// Step 3. Create vars in each scope. Passes may also create new vars. // Step 3. Create vars in each scope. Passes may also create new vars.
// skip control vars and empty vars // skip control vars and empty vars
std::vector<details::VariableInfo> var_infos; std::vector<details::VariableInfo> var_infos;
...@@ -334,7 +354,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -334,7 +354,7 @@ ParallelExecutor::ParallelExecutor(
if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
VLOG(3) << "use AsyncSSAGraphExecutor"; VLOG(3) << "use AsyncSSAGraphExecutor";
member_->executor_.reset(new details::AsyncSSAGraphExecutor( member_->executor_.reset(new details::AsyncSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, graph)); exec_strategy, member_->local_scopes_, member_->places_, async_graphs));
} else if (build_strategy.enable_parallel_graph_) { } else if (build_strategy.enable_parallel_graph_) {
VLOG(3) << "use ParallelSSAGraphExecutor"; VLOG(3) << "use ParallelSSAGraphExecutor";
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
...@@ -50,7 +50,7 @@ class ParallelExecutor { ...@@ -50,7 +50,7 @@ class ParallelExecutor {
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const ExecutionStrategy &exec_strategy, const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy, const BuildStrategy &build_strategy,
ir::Graph *graph); std::vector<ir::Graph *> graphs);
~ParallelExecutor(); ~ParallelExecutor();
......
...@@ -95,6 +95,7 @@ class BlockingQueue { ...@@ -95,6 +95,7 @@ class BlockingQueue {
void Close() { void Close() {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
VLOG(3) << "close queue";
closed_ = true; closed_ = true;
send_cv_.notify_all(); send_cv_.notify_all();
receive_cv_.notify_all(); receive_cv_.notify_all();
......
...@@ -35,7 +35,10 @@ class PyReader : public framework::FileReader { ...@@ -35,7 +35,10 @@ class PyReader : public framework::FileReader {
~PyReader() { queue_->Close(); } ~PyReader() { queue_->Close(); }
void Shutdown() override { queue_->Close(); } void Shutdown() override {
VLOG(3) << "PyReader shutdown!";
queue_->Close();
}
void Start() override { queue_->ReOpen(); } void Start() override { queue_->ReOpen(); }
......
...@@ -1230,7 +1230,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1230,7 +1230,7 @@ All parameter, weight, gradient are variables in Paddle.
pe.def(py::init<const std::vector<platform::Place> &, pe.def(py::init<const std::vector<platform::Place> &,
const std::unordered_set<std::string> &, const std::string &, const std::unordered_set<std::string> &, const std::string &,
Scope *, std::vector<Scope *> &, const ExecutionStrategy &, Scope *, std::vector<Scope *> &, const ExecutionStrategy &,
const BuildStrategy &, ir::Graph *>()) const BuildStrategy &, std::vector<ir::Graph *>>())
// NOTE: even we return a vec<Scope*>* to Python use reference policy. // NOTE: even we return a vec<Scope*>* to Python use reference policy.
// We still cannot get local_scope from this vector, since the element // We still cannot get local_scope from this vector, since the element
// of vec<Scope*> will be freed by Python GC. We can only return Scope* // of vec<Scope*> will be freed by Python GC. We can only return Scope*
......
...@@ -177,12 +177,17 @@ class ParallelExecutor(object): ...@@ -177,12 +177,17 @@ class ParallelExecutor(object):
# step7: init ParallelExecutor # step7: init ParallelExecutor
# ParallelExecutor API will be deprecated, don't support parallel graph. # ParallelExecutor API will be deprecated, don't support parallel graph.
self._graph = core.Graph(main.desc) self._graphs = []
if build_strategy.async_mode:
for _ in range(cpu_num):
self._graphs.append(core.Graph(main.desc))
else:
self._graphs.append(core.Graph(main.desc))
self.executor = core.ParallelExecutor( self.executor = core.ParallelExecutor(
places, persistable_vars, places, persistable_vars,
cpt.to_text(loss_name) if loss_name else six.u(''), scope, cpt.to_text(loss_name) if loss_name else six.u(''), scope,
local_scopes, exec_strategy, build_strategy, self._graph) local_scopes, exec_strategy, build_strategy, self._graphs)
self.scope = scope self.scope = scope
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to comment