提交 a87ce91c 编写于 作者: Y Yu Yang

Use mtx

上级 ea11a0a8
...@@ -641,7 +641,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -641,7 +641,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size()); auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size());
// Version --> VarHandle // Version --> VarHandle
member_->exception_.reset(); member_->exception_.reset();
std::unordered_map<VarHandleBase *, volatile bool> pending_vars; std::unordered_map<VarHandleBase *, GuardedBool> pending_vars;
std::unordered_map<OpHandle *, size_t> pending_ops; std::unordered_map<OpHandle *, size_t> pending_ops;
for (auto &place_pair : member_->vars_) { for (auto &place_pair : member_->vars_) {
...@@ -739,10 +739,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -739,10 +739,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
} }
void ParallelExecutor::RunOp( void ParallelExecutor::RunOp(
std::unordered_map<VarHandleBase *, volatile bool> &pending_vars, std::unordered_map<VarHandleBase *, GuardedBool> &pending_vars,
OpHandle *op) const { OpHandle *op) const {
std::vector<volatile bool *> *ready_buffer = std::vector<GuardedBool *> *ready_buffer = new std::vector<GuardedBool *>();
new std::vector<volatile bool *>();
for (auto *var : op->outputs_) { for (auto *var : op->outputs_) {
ready_buffer->emplace_back(&pending_vars[var]); ready_buffer->emplace_back(&pending_vars[var]);
} }
......
...@@ -32,6 +32,27 @@ class ParallelExecutorPrivate; ...@@ -32,6 +32,27 @@ class ParallelExecutorPrivate;
class VarHandle; class VarHandle;
class OpHandle; class OpHandle;
class VarHandleBase; class VarHandleBase;
struct GuardedBool {
public:
GuardedBool() {}
operator bool() const {
std::lock_guard<std::mutex> g(mtx_);
return value_;
}
GuardedBool& operator=(bool o) {
std::lock_guard<std::mutex> g(mtx_);
value_ = o;
return *this;
}
private:
mutable std::mutex mtx_;
bool value_;
};
class ParallelExecutor { class ParallelExecutor {
public: public:
explicit ParallelExecutor(const std::vector<platform::Place>& places, explicit ParallelExecutor(const std::vector<platform::Place>& places,
...@@ -60,7 +81,7 @@ class ParallelExecutor { ...@@ -60,7 +81,7 @@ class ParallelExecutor {
void BuildNCCLCommunicator() const; void BuildNCCLCommunicator() const;
void RunOp(std::unordered_map<VarHandleBase*, volatile bool>& pending_vars, void RunOp(std::unordered_map<VarHandleBase*, GuardedBool>& pending_vars,
OpHandle* op) const; OpHandle* op) const;
void PolishGraphToSupportDataHarzaeds() const; void PolishGraphToSupportDataHarzaeds() const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册