Commit dab7f369 authored by Qiao Longfei

optimize code test=develop

Parent: cf0511f2
@@ -21,12 +21,12 @@ namespace details {
 AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
-    std::unique_ptr<ir::Graph> &&graph)
+    ir::Graph* graph)
     : strategy_(std::move(strategy)),
       local_scopes_(std::move(local_scopes)),
       pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
       places_(std::move(places)),
-      graph_(std::move(graph)) {
+      graph_(graph) {
   VLOG(3) << "build AsyncSSAGraphExecutor";
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
@@ -38,7 +38,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
          << " to run the operators of the graph on each device.";
   for (size_t i = 0; i < places.size(); ++i) {
     executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
-        strategy_, {local_scopes_[i]}, {places_[i]}, graph_.get()));
+        strategy_, {local_scopes_[i]}, {places_[i]}, graph_));
   }
 }
...
@@ -29,7 +29,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
   AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
                         const std::vector<Scope *> &local_scopes,
                         const std::vector<platform::Place> &places,
-                        std::unique_ptr<ir::Graph> &&graph);
+                        ir::Graph *graph);
   ~AsyncSSAGraphExecutor() final = default;
   const ir::Graph &Graph() const override { return *graph_; }
@@ -40,7 +40,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
   std::vector<Scope *> local_scopes_;
   std::unique_ptr<::ThreadPool> pool_{nullptr};
   std::vector<platform::Place> places_;
-  std::unique_ptr<ir::Graph> graph_;
+  ir::Graph *graph_;
  std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
  ExceptionHolder exception_holder_;
...
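Note: the hunks above change AsyncSSAGraphExecutor from owning the graph through a `std::unique_ptr<ir::Graph>&&` to borrowing it through a raw `ir::Graph*`, so ownership stays with the caller and several sub-executors can share one graph. Below is a minimal, self-contained sketch of that ownership change; `Graph`, `OwningExecutor`, and `BorrowingExecutor` are simplified stand-ins for illustration, not the real PaddlePaddle types.

```cpp
#include <iostream>
#include <memory>
#include <vector>

// Simplified stand-in for ir::Graph (illustration only).
struct Graph {
  int node_count = 0;
};

// Before this commit: the executor received a unique_ptr and took ownership,
// so the caller lost access to the graph afterwards.
class OwningExecutor {
 public:
  explicit OwningExecutor(std::unique_ptr<Graph> &&graph)
      : graph_(std::move(graph)) {}
  const Graph &graph() const { return *graph_; }

 private:
  std::unique_ptr<Graph> graph_;
};

// After this commit: the executor only borrows a raw Graph*; ownership
// remains with the caller, so many executors can point at the same graph.
class BorrowingExecutor {
 public:
  explicit BorrowingExecutor(Graph *graph) : graph_(graph) {}
  const Graph &graph() const { return *graph_; }

 private:
  Graph *graph_;  // non-owning
};

int main() {
  auto owned = std::make_unique<Graph>();
  owned->node_count = 42;

  // The caller keeps the unique_ptr; every executor sees the same graph.
  std::vector<BorrowingExecutor> executors;
  executors.emplace_back(owned.get());
  executors.emplace_back(owned.get());

  std::cout << executors[0].graph().node_count << "\n";  // prints 42
  return 0;
}
```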
@@ -269,25 +269,26 @@ ParallelExecutor::ParallelExecutor(
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     VLOG(3) << "use local async mode";
-    temp_owned_graph =
-        build_strategy.Apply(std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name,
-                             {member_->local_scopes_[0]}, member_->nranks_,
-                             member_->use_cuda_, member_->nccl_ctxs_.get());
+    temp_owned_graph = build_strategy.Apply(
+        std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name,
+        {member_->local_scopes_[0]}, member_->nranks_, member_->use_cuda_,
+        member_->nccl_ctxs_.get());
   } else {
-    temp_owned_graph = build_strategy.Apply(std::move(temp_owned_graph), member_->places_, loss_var_name,
-                                            member_->local_scopes_, member_->nranks_,
-                                            member_->use_cuda_, member_->nccl_ctxs_.get());
+    temp_owned_graph = build_strategy.Apply(
+        std::move(temp_owned_graph), member_->places_, loss_var_name,
+        member_->local_scopes_, member_->nranks_, member_->use_cuda_,
+        member_->nccl_ctxs_.get());
   }
 #else
   if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     VLOG(3) << "use local async mode";
-    temp_owned_graph = build_strategy.Apply(std::move(temp_owned_graph), {member_->places_[0]},
-                                            loss_var_name, {member_->local_scopes_[0]},
-                                            member_->nranks_, member_->use_cuda_);
+    temp_owned_graph = build_strategy.Apply(
+        std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name,
+        {member_->local_scopes_[0]}, member_->nranks_, member_->use_cuda_);
   } else {
-    temp_owned_graph = build_strategy.Apply(std::move(temp_owned_graph), member_->places_, loss_var_name,
-                                            member_->local_scopes_, member_->nranks_,
-                                            member_->use_cuda_);
+    temp_owned_graph = build_strategy.Apply(
+        std::move(temp_owned_graph), member_->places_, loss_var_name,
+        member_->local_scopes_, member_->nranks_, member_->use_cuda_);
   }
 #endif
@@ -333,8 +334,7 @@ ParallelExecutor::ParallelExecutor(
   if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     VLOG(3) << "use AsyncSSAGraphExecutor";
     member_->executor_.reset(new details::AsyncSSAGraphExecutor(
-        exec_strategy, member_->local_scopes_, member_->places_,
-        graph));
+        exec_strategy, member_->local_scopes_, member_->places_, graph));
   } else if (build_strategy.enable_parallel_graph_) {
     VLOG(3) << "use ParallelSSAGraphExecutor";
 #ifdef PADDLE_WITH_CUDA
...
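Note: in the hunks above, `temp_owned_graph` keeps ownership while `build_strategy.Apply` consumes and returns it, and only a raw `graph` pointer is finally handed to the executor. Below is a minimal sketch of that pass-through pattern; `Apply` here is a hypothetical single-argument stand-in, not the real BuildStrategy::Apply signature.

```cpp
#include <iostream>
#include <memory>
#include <string>

// Simplified stand-in for ir::Graph (illustration only).
struct Graph {
  std::string pass_log;
};

// Hypothetical Apply(): takes ownership of the graph, transforms it, and
// hands ownership back, mirroring the std::move(temp_owned_graph) pattern
// in the hunk above.
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph,
                             const std::string &pass_name) {
  graph->pass_log += pass_name + ";";
  return graph;
}

int main() {
  auto temp_owned_graph = std::make_unique<Graph>();

  // Ownership flows into Apply and back into temp_owned_graph.
  temp_owned_graph = Apply(std::move(temp_owned_graph), "async_mode_pass");

  // Once the passes have run, the executor receives only a raw, non-owning
  // pointer, as in the AsyncSSAGraphExecutor change above.
  Graph *graph = temp_owned_graph.get();
  std::cout << graph->pass_log << "\n";  // prints "async_mode_pass;"
  return 0;
}
```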