From ba227df9419bbb2f8b3ac5636674c176cced3f19 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 20 Mar 2018 18:41:57 +0800 Subject: [PATCH] Expose num_threads --- paddle/fluid/framework/parallel_executor.cc | 6 +++--- paddle/fluid/framework/parallel_executor.h | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index f1b8a20e4..bbfaac733 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -245,7 +245,7 @@ struct FetchOpHandle : public OpHandle { class ParallelExecutorPrivate { public: - explicit ParallelExecutorPrivate(size_t num_threads = 0) + explicit ParallelExecutorPrivate(size_t num_threads) : pool_(num_threads == 0 ? nullptr : new ThreadPool(num_threads)) {} std::vector places_; @@ -389,11 +389,11 @@ struct NCCLAllReduceOpHandle : public OpHandle { }; ParallelExecutor::ParallelExecutor( - const std::vector &places, + size_t num_threads, const std::vector &places, const std::unordered_set ¶ms, const ProgramDesc &startup_program, const ProgramDesc &main_program, const std::string &loss_var_name, Scope *scope) - : member_(new ParallelExecutorPrivate()) { + : member_(new ParallelExecutorPrivate(num_threads)) { member_->places_ = places; member_->global_scope_ = scope; // Step 1. RunStartupProgram and Bcast the params to devs. diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 2345bffcc..c206e726a 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -35,7 +35,8 @@ class VarHandleBase; class ParallelExecutor { public: - explicit ParallelExecutor(const std::vector& places, + explicit ParallelExecutor(size_t num_threads, + const std::vector& places, const std::unordered_set& params, const ProgramDesc& startup_program, const ProgramDesc& main_program, -- GitLab