提交 a87ce91c 编写于 作者: Y Yu Yang

Use mtx

上级 ea11a0a8
......@@ -641,7 +641,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size());
// Version --> VarHandle
member_->exception_.reset();
std::unordered_map<VarHandleBase *, volatile bool> pending_vars;
std::unordered_map<VarHandleBase *, GuardedBool> pending_vars;
std::unordered_map<OpHandle *, size_t> pending_ops;
for (auto &place_pair : member_->vars_) {
......@@ -739,10 +739,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
}
void ParallelExecutor::RunOp(
std::unordered_map<VarHandleBase *, volatile bool> &pending_vars,
std::unordered_map<VarHandleBase *, GuardedBool> &pending_vars,
OpHandle *op) const {
std::vector<volatile bool *> *ready_buffer =
new std::vector<volatile bool *>();
std::vector<GuardedBool *> *ready_buffer = new std::vector<GuardedBool *>();
for (auto *var : op->outputs_) {
ready_buffer->emplace_back(&pending_vars[var]);
}
......
......@@ -32,6 +32,27 @@ class ParallelExecutorPrivate;
class VarHandle;
class OpHandle;
class VarHandleBase;
struct GuardedBool {
public:
GuardedBool() {}
operator bool() const {
std::lock_guard<std::mutex> g(mtx_);
return value_;
}
GuardedBool& operator=(bool o) {
std::lock_guard<std::mutex> g(mtx_);
value_ = o;
return *this;
}
private:
mutable std::mutex mtx_;
bool value_;
};
class ParallelExecutor {
public:
explicit ParallelExecutor(const std::vector<platform::Place>& places,
......@@ -60,7 +81,7 @@ class ParallelExecutor {
void BuildNCCLCommunicator() const;
void RunOp(std::unordered_map<VarHandleBase*, volatile bool>& pending_vars,
void RunOp(std::unordered_map<VarHandleBase*, GuardedBool>& pending_vars,
OpHandle* op) const;
void PolishGraphToSupportDataHarzaeds() const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册