diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index f1b8a20e41cc223b5e68e66eaa8221c4aec01295..bbfaac7339d0becf4c83d97c8c1fecb31d75a02c 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -245,7 +245,7 @@ struct FetchOpHandle : public OpHandle { class ParallelExecutorPrivate { public: - explicit ParallelExecutorPrivate(size_t num_threads = 0) + explicit ParallelExecutorPrivate(size_t num_threads) : pool_(num_threads == 0 ? nullptr : new ThreadPool(num_threads)) {} std::vector places_; @@ -389,11 +389,11 @@ struct NCCLAllReduceOpHandle : public OpHandle { }; ParallelExecutor::ParallelExecutor( - const std::vector &places, + size_t num_threads, const std::vector &places, const std::unordered_set ¶ms, const ProgramDesc &startup_program, const ProgramDesc &main_program, const std::string &loss_var_name, Scope *scope) - : member_(new ParallelExecutorPrivate()) { + : member_(new ParallelExecutorPrivate(num_threads)) { member_->places_ = places; member_->global_scope_ = scope; // Step 1. RunStartupProgram and Bcast the params to devs. diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 2345bffcc765d41e974d3a2be7fb8346544f2ae8..c206e726a71d1c8729ee65213f83118d4cde7d1a 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -35,7 +35,8 @@ class VarHandleBase; class ParallelExecutor { public: - explicit ParallelExecutor(const std::vector& places, + explicit ParallelExecutor(size_t num_threads, + const std::vector& places, const std::unordered_set& params, const ProgramDesc& startup_program, const ProgramDesc& main_program,