Commit dab7f369 authored by Qiao Longfei

optimize code test=develop

Parent cf0511f2
@@ -21,12 +21,12 @@ namespace details {
 AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
-    std::unique_ptr<ir::Graph> &&graph)
+    ir::Graph *graph)
     : strategy_(std::move(strategy)),
       local_scopes_(std::move(local_scopes)),
       pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
       places_(std::move(places)),
-      graph_(std::move(graph)) {
+      graph_(graph) {
   VLOG(3) << "build AsyncSSAGraphExecutor";
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
@@ -38,7 +38,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
         << " to run the operators of the graph on each device.";
   for (size_t i = 0; i < places.size(); ++i) {
     executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
-        strategy_, {local_scopes_[i]}, {places_[i]}, graph_.get()));
+        strategy_, {local_scopes_[i]}, {places_[i]}, graph_));
   }
 }
......
@@ -29,7 +29,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
   AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
                         const std::vector<Scope *> &local_scopes,
                         const std::vector<platform::Place> &places,
-                        std::unique_ptr<ir::Graph> &&graph);
+                        ir::Graph *graph);
   ~AsyncSSAGraphExecutor() final = default;
   const ir::Graph &Graph() const override { return *graph_; }
@@ -40,7 +40,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
   std::vector<Scope *> local_scopes_;
   std::unique_ptr<::ThreadPool> pool_{nullptr};
   std::vector<platform::Place> places_;
-  std::unique_ptr<ir::Graph> graph_;
+  ir::Graph *graph_;
   std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
   ExceptionHolder exception_holder_;
......
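For context on the change in the two hunks above: the executor previously took the graph as `std::unique_ptr<ir::Graph> &&` and owned it; it now receives a raw `ir::Graph *` and merely borrows it, so the caller retains ownership and must keep the graph alive for the executor's lifetime. A minimal standalone sketch of the two ownership models (`Graph`, `OwningExecutor`, and `BorrowingExecutor` are hypothetical stand-ins, not Paddle types):

```cpp
#include <memory>
#include <utility>

struct Graph {};  // stand-in for ir::Graph

// Before: the constructor consumed the caller's unique_ptr, so the
// executor owned the graph and destroyed it on teardown.
struct OwningExecutor {
  explicit OwningExecutor(std::unique_ptr<Graph> &&graph)
      : graph_(std::move(graph)) {}
  std::unique_ptr<Graph> graph_;  // owning
};

// After: the constructor takes a raw pointer; the caller keeps
// ownership and must guarantee the graph outlives the executor.
struct BorrowingExecutor {
  explicit BorrowingExecutor(Graph *graph) : graph_(graph) {}
  Graph *graph_;  // non-owning
};

int main() {
  OwningExecutor owner(std::make_unique<Graph>());  // ownership moves in

  auto g = std::make_unique<Graph>();
  BorrowingExecutor borrower(g.get());  // `g` must stay valid past `borrower`
}
```

Borrowing also matches how the graph is used here: the constructor loop above hands the same `graph_` pointer to one `ThreadedSSAGraphExecutor` per place, which a single `unique_ptr` could not express.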
@@ -269,25 +269,26 @@ ParallelExecutor::ParallelExecutor(
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     VLOG(3) << "use local async mode";
-    temp_owned_graph =
-        build_strategy.Apply(std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name,
-                             {member_->local_scopes_[0]}, member_->nranks_,
-                             member_->use_cuda_, member_->nccl_ctxs_.get());
+    temp_owned_graph = build_strategy.Apply(
+        std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name,
+        {member_->local_scopes_[0]}, member_->nranks_, member_->use_cuda_,
+        member_->nccl_ctxs_.get());
   } else {
-    temp_owned_graph = build_strategy.Apply(std::move(temp_owned_graph), member_->places_, loss_var_name,
-                                            member_->local_scopes_, member_->nranks_,
-                                            member_->use_cuda_, member_->nccl_ctxs_.get());
+    temp_owned_graph = build_strategy.Apply(
+        std::move(temp_owned_graph), member_->places_, loss_var_name,
+        member_->local_scopes_, member_->nranks_, member_->use_cuda_,
+        member_->nccl_ctxs_.get());
   }
 #else
   if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     VLOG(3) << "use local async mode";
-    temp_owned_graph = build_strategy.Apply(std::move(temp_owned_graph), {member_->places_[0]},
-                                            loss_var_name, {member_->local_scopes_[0]},
-                                            member_->nranks_, member_->use_cuda_);
+    temp_owned_graph = build_strategy.Apply(
+        std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name,
+        {member_->local_scopes_[0]}, member_->nranks_, member_->use_cuda_);
   } else {
-    temp_owned_graph = build_strategy.Apply(std::move(temp_owned_graph), member_->places_, loss_var_name,
-                                            member_->local_scopes_, member_->nranks_,
-                                            member_->use_cuda_);
+    temp_owned_graph = build_strategy.Apply(
+        std::move(temp_owned_graph), member_->places_, loss_var_name,
+        member_->local_scopes_, member_->nranks_, member_->use_cuda_);
   }
 #endif
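The hunk above only reflows the call sites, but it shows the ownership-threading convention around `BuildStrategy::Apply`: the owned graph is moved into each call and the (possibly rewritten) owned graph is assigned back. A self-contained sketch of that round-trip pattern (`ApplyPass` is a hypothetical stand-in for `Apply`):

```cpp
#include <memory>
#include <utility>

struct Graph { int num_passes = 0; };

// Hypothetical stand-in for BuildStrategy::Apply: consumes the owned
// graph, transforms it, and returns ownership to the caller.
std::unique_ptr<Graph> ApplyPass(std::unique_ptr<Graph> graph) {
  graph->num_passes++;  // transform in place (a pass may also replace it)
  return graph;
}

int main() {
  auto temp_owned_graph = std::make_unique<Graph>();
  // Ownership round-trips through every call, mirroring the call sites
  // above: temp_owned_graph = Apply(std::move(temp_owned_graph), ...).
  temp_owned_graph = ApplyPass(std::move(temp_owned_graph));
  temp_owned_graph = ApplyPass(std::move(temp_owned_graph));
}
```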
@@ -333,8 +334,7 @@ ParallelExecutor::ParallelExecutor(
   if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     VLOG(3) << "use AsyncSSAGraphExecutor";
     member_->executor_.reset(new details::AsyncSSAGraphExecutor(
-        exec_strategy, member_->local_scopes_, member_->places_,
-        graph));
+        exec_strategy, member_->local_scopes_, member_->places_, graph));
   } else if (build_strategy.enable_parallel_graph_) {
     VLOG(3) << "use ParallelSSAGraphExecutor";
 #ifdef PADDLE_WITH_CUDA
......
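Because `AsyncSSAGraphExecutor` now only borrows the graph, the call site above implies a lifetime contract: whoever owns `graph` (the `ParallelExecutor` in this diff) must outlive the executor it creates. A minimal sketch of that contract, again with hypothetical stand-in names:

```cpp
#include <memory>

struct Graph {};

struct AsyncExecutor {  // stand-in for AsyncSSAGraphExecutor
  explicit AsyncExecutor(Graph *graph) : graph_(graph) {}
  Graph *graph_;  // borrowed; dangles if the owner is destroyed first
};

struct Owner {  // stand-in for ParallelExecutor
  std::unique_ptr<Graph> graph_{new Graph};
  std::unique_ptr<AsyncExecutor> executor_;
  Owner() { executor_.reset(new AsyncExecutor(graph_.get())); }
  // Members are destroyed in reverse declaration order, so executor_
  // goes away before graph_ and the borrowed pointer never dangles.
};

int main() { Owner owner; }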