提交 94c80347 编写于 作者: Y Yancey1989

update by comment

上级 af91444c
...@@ -34,7 +34,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( ...@@ -34,7 +34,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
? 1UL ? 1UL
: strategy_.num_threads_ / places_.size(); : strategy_.num_threads_ / places_.size();
VLOG(1) << "set num_threads: " << strategy_.num_threads_ VLOG(1) << "set num_threads: " << strategy_.num_threads_
<< " to schedule operators on each device."; << " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) { for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::ThreadedSSAGraphExecutor( executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i]))); strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i])));
...@@ -45,10 +45,10 @@ FeedFetchList ParallelSSAGraphExecutor::Run( ...@@ -45,10 +45,10 @@ FeedFetchList ParallelSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) { const std::vector<std::string> &fetch_tensors) {
std::vector<std::future<FeedFetchList>> run_futures; std::vector<std::future<FeedFetchList>> run_futures;
std::vector<FeedFetchList> fetch_datas; std::vector<FeedFetchList> fetch_data;
FeedFetchList ret; FeedFetchList ret;
fetch_datas.reserve(places_.size()); fetch_data.reserve(places_.size());
ret.reserve(fetch_tensors.size()); ret.reserve(fetch_tensors.size());
exception_holder_.Clear(); exception_holder_.Clear();
...@@ -65,7 +65,7 @@ FeedFetchList ParallelSSAGraphExecutor::Run( ...@@ -65,7 +65,7 @@ FeedFetchList ParallelSSAGraphExecutor::Run(
if (pool_) { if (pool_) {
run_futures.emplace_back(pool_->enqueue(std::move(call))); run_futures.emplace_back(pool_->enqueue(std::move(call)));
} else { } else {
fetch_datas.emplace_back(std::move(call())); fetch_data.emplace_back(std::move(call()));
} }
} }
...@@ -74,7 +74,7 @@ FeedFetchList ParallelSSAGraphExecutor::Run( ...@@ -74,7 +74,7 @@ FeedFetchList ParallelSSAGraphExecutor::Run(
if (exception_holder_.IsCaught()) { if (exception_holder_.IsCaught()) {
f.wait(); f.wait();
} else { } else {
fetch_datas.emplace_back(std::move(f.get())); fetch_data.emplace_back(std::move(f.get()));
} }
} }
} }
...@@ -86,7 +86,7 @@ FeedFetchList ParallelSSAGraphExecutor::Run( ...@@ -86,7 +86,7 @@ FeedFetchList ParallelSSAGraphExecutor::Run(
std::vector<const LoDTensor *> lodtensor_ptrs; std::vector<const LoDTensor *> lodtensor_ptrs;
lodtensor_ptrs.reserve(local_scopes_.size()); lodtensor_ptrs.reserve(local_scopes_.size());
for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) { for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) {
lodtensor_ptrs.push_back(&fetch_datas.at(scope_idx).at(fetch_idx)); lodtensor_ptrs.push_back(&fetch_data.at(scope_idx).at(fetch_idx));
} }
ret.emplace_back(); ret.emplace_back();
ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace()); ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
......
...@@ -469,8 +469,9 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( ...@@ -469,8 +469,9 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
bool ParallelExecutor::EnableParallelGraphExecution( bool ParallelExecutor::EnableParallelGraphExecution(
const ProgramDesc &main_program, const ExecutionStrategy &exec_strategy, const ProgramDesc &main_program, const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy) const { const BuildStrategy &build_strategy) const {
bool enable_parallel_graph = true; if (!FLAGS_enable_parallel_graph) return false;
bool enable_parallel_graph = true;
// TODO(Yancey1989): support sparse update in ParallelGraph mode. // TODO(Yancey1989): support sparse update in ParallelGraph mode.
for (auto &var_desc : main_program.Block(0).AllVars()) { for (auto &var_desc : main_program.Block(0).AllVars()) {
if (var_desc->GetType() == proto::VarType::SELECTED_ROWS) { if (var_desc->GetType() == proto::VarType::SELECTED_ROWS) {
...@@ -492,7 +493,7 @@ bool ParallelExecutor::EnableParallelGraphExecution( ...@@ -492,7 +493,7 @@ bool ParallelExecutor::EnableParallelGraphExecution(
if (build_strategy.enable_sequential_execution_ || if (build_strategy.enable_sequential_execution_ ||
exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental) exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental)
enable_parallel_graph = false; enable_parallel_graph = false;
return enable_parallel_graph && FLAGS_enable_parallel_graph; return enable_parallel_graph;
} }
ParallelExecutor::~ParallelExecutor() { ParallelExecutor::~ParallelExecutor() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册