提交 5fa535b7 编写于 作者: Y Yu Yang

Wait all thread done

上级 7bff02b2
......@@ -699,8 +699,11 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
pending_ops.insert({op, op->inputs_.size()});
}
std::vector<std::future<void>> op_threads;
op_threads.reserve(pending_ops.size() + to_run.size());
for (auto *op : to_run) {
RunOp(pending_vars, op);
op_threads.emplace_back(RunOp(pending_vars, op));
}
while (!pending_ops.empty()) {
......@@ -731,15 +734,20 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
}
for (auto *op : to_run) {
pending_ops.erase(op);
RunOp(pending_vars, op);
op_threads.emplace_back(RunOp(pending_vars, op));
}
}
for (auto &t : op_threads) {
t.get(); // Join all workers
}
fetch_ops.clear();
*member_->global_scope_->Var(fetched_var_name)->GetMutable<LoDTensorArray>() =
fetched_data->tensors_;
}
void ParallelExecutor::RunOp(
std::future<void> ParallelExecutor::RunOp(
std::unordered_map<VarHandleBase *, GuardedBool> &pending_vars,
OpHandle *op) const {
std::vector<GuardedBool *> *ready_buffer = new std::vector<GuardedBool *>();
......@@ -760,7 +768,7 @@ void ParallelExecutor::RunOp(
LOG(FATAL) << "Unknown exception catched";
}
};
member_->pool_.enqueue(op_run);
return member_->pool_.enqueue(op_run);
}
} // namespace framework
} // namespace paddle
......@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once
#include <future>
#include <unordered_set>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -81,8 +81,9 @@ class ParallelExecutor {
void BuildNCCLCommunicator() const;
void RunOp(std::unordered_map<VarHandleBase*, GuardedBool>& pending_vars,
OpHandle* op) const;
std::future<void> RunOp(
std::unordered_map<VarHandleBase*, GuardedBool>& pending_vars,
OpHandle* op) const;
void PolishGraphToSupportDataHarzaeds() const;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册