提交 ba227df9 编写于 作者: Y Yu Yang

Expose num_threads

上级 1533bf12
......@@ -245,7 +245,7 @@ struct FetchOpHandle : public OpHandle {
class ParallelExecutorPrivate {
public:
explicit ParallelExecutorPrivate(size_t num_threads = 0)
explicit ParallelExecutorPrivate(size_t num_threads)
: pool_(num_threads == 0 ? nullptr : new ThreadPool(num_threads)) {}
std::vector<platform::Place> places_;
......@@ -389,11 +389,11 @@ struct NCCLAllReduceOpHandle : public OpHandle {
};
ParallelExecutor::ParallelExecutor(
const std::vector<platform::Place> &places,
size_t num_threads, const std::vector<platform::Place> &places,
const std::unordered_set<std::string> &params,
const ProgramDesc &startup_program, const ProgramDesc &main_program,
const std::string &loss_var_name, Scope *scope)
: member_(new ParallelExecutorPrivate()) {
: member_(new ParallelExecutorPrivate(num_threads)) {
member_->places_ = places;
member_->global_scope_ = scope;
// Step 1. RunStartupProgram and Bcast the params to devs.
......
......@@ -35,7 +35,8 @@ class VarHandleBase;
class ParallelExecutor {
public:
explicit ParallelExecutor(const std::vector<platform::Place>& places,
explicit ParallelExecutor(size_t num_threads,
const std::vector<platform::Place>& places,
const std::unordered_set<std::string>& params,
const ProgramDesc& startup_program,
const ProgramDesc& main_program,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册