提交 a7b0d5bd 作者: Yu Yang

Clean code

上级 e3144393
@@ -27,15 +27,16 @@ namespace framework {
 class ParallelExecutorPrivate {
  public:
   explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
-      : places_(places), fetch_dev_ctxs_(places) {}
+      : places_(places) {}
   std::vector<platform::Place> places_;
-  platform::DeviceContextPool fetch_dev_ctxs_;
   std::vector<Scope *> local_scopes_;
   Scope *global_scope_;
+  std::unique_ptr<details::SSAGraphExecutor> executor_;
+#ifdef PADDLE_WITH_CUDA
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
-  std::unique_ptr<details::SSAGraphExecutor> executor_;
+#endif
 };
 ParallelExecutor::ParallelExecutor(
@@ -54,8 +55,10 @@ ParallelExecutor::ParallelExecutor(
     member_->local_scopes_.push_back(&scope->NewScope());
   }
   // Bcast Parameters to all GPUs
-  BuildNCCLCommunicator();
+#ifdef PADDLE_WITH_CUDA
+  member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
+#endif
   if (platform::is_gpu_place(places[0]) &&
       member_->local_scopes_.size() != 1) { // Is CUDA
     BCastParamsToGPUs(startup_program);
@@ -123,12 +126,6 @@ void ParallelExecutor::BCastParamsToGPUs(
 #endif
 }
-void ParallelExecutor::BuildNCCLCommunicator() const {
-#ifdef PADDLE_WITH_CUDA
-  member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
-#endif
-}
 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
   auto fetch_data = member_->executor_->Run(fetch_tensors);
......
@@ -31,6 +31,8 @@ namespace framework {
 class ParallelExecutorPrivate;
 class ParallelExecutor {
+  DISABLE_COPY_AND_ASSIGN(ParallelExecutor);
  public:
   explicit ParallelExecutor(size_t num_threads,
                             const std::vector<platform::Place>& places,
@@ -46,8 +48,6 @@ class ParallelExecutor {
   ParallelExecutorPrivate* member_;
   void BCastParamsToGPUs(const ProgramDesc& startup_program) const;
-  void BuildNCCLCommunicator() const;
 };
 } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册