提交 1f53193a 编写于 作者: Y Yu Yang

Use atomic code

上级 c7beac14
...@@ -645,7 +645,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -645,7 +645,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size()); auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size());
// Version --> VarHandle // Version --> VarHandle
member_->exception_.reset(); member_->exception_.reset();
std::unordered_map<VarHandleBase *, GuardedBool> pending_vars; std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
std::unordered_map<OpHandle *, size_t> pending_ops; std::unordered_map<OpHandle *, size_t> pending_ops;
std::vector<DummyVarHandle> dummy_vars; std::vector<DummyVarHandle> dummy_vars;
...@@ -694,7 +694,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -694,7 +694,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
op->offset_ = i; op->offset_ = i;
op->local_scopes_ = &member_->local_scopes_; op->local_scopes_ = &member_->local_scopes_;
for (auto &p : member_->places_) { for (auto &p : member_->places_) {
op->dev_ctx_[p] = this->member_->GetNCCLCtx(p).ctx_.get(); op->dev_ctx_[p] = member_->GetNCCLCtx(p).ctx_.get();
} }
for (auto *var : vars) { for (auto *var : vars) {
...@@ -718,7 +718,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -718,7 +718,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
while (!pending_vars.empty()) { while (!pending_vars.empty()) {
VarHandleBase *ready_var = nullptr; VarHandleBase *ready_var = nullptr;
for (auto &pair : pending_vars) { for (auto &pair : pending_vars) {
if (pair.second) { if (pair.second.load(std::memory_order_consume)) {
ready_var = pair.first; ready_var = pair.first;
} }
} }
...@@ -750,9 +750,10 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -750,9 +750,10 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
} }
void ParallelExecutor::RunOp( void ParallelExecutor::RunOp(
std::unordered_map<VarHandleBase *, GuardedBool> &pending_vars, std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
OpHandle *op) const { OpHandle *op) const {
std::vector<GuardedBool *> *ready_buffer = new std::vector<GuardedBool *>(); std::vector<std::atomic<bool> *> *ready_buffer =
new std::vector<std::atomic<bool> *>();
for (auto *var : op->outputs_) { for (auto *var : op->outputs_) {
ready_buffer->emplace_back(&pending_vars[var]); ready_buffer->emplace_back(&pending_vars[var]);
} }
...@@ -761,7 +762,7 @@ void ParallelExecutor::RunOp( ...@@ -761,7 +762,7 @@ void ParallelExecutor::RunOp(
try { try {
op->Run(); op->Run();
for (auto *ready : *ready_buffer) { for (auto *ready : *ready_buffer) {
*ready = true; ready->store(true, std::memory_order_release);
} }
delete ready_buffer; delete ready_buffer;
} catch (platform::EnforceNotMet ex) { } catch (platform::EnforceNotMet ex) {
......
...@@ -33,26 +33,6 @@ class VarHandle; ...@@ -33,26 +33,6 @@ class VarHandle;
class OpHandle; class OpHandle;
class VarHandleBase; class VarHandleBase;
/// A bool whose every read and write is serialized by an internal mutex.
///
/// Conversion to bool and assignment from bool each take the lock for the
/// duration of the access, so one thread may assign the flag while another
/// reads it.
struct GuardedBool {
 public:
  /// Constructs the flag in the `false` state.
  GuardedBool() {}

  /// Returns the current value, read while holding the lock.
  operator bool() const {
    std::lock_guard<std::mutex> g(mtx_);
    return value_;
  }

  /// Stores `o` while holding the lock and returns `*this`.
  GuardedBool& operator=(bool o) {
    std::lock_guard<std::mutex> g(mtx_);
    value_ = o;
    return *this;
  }

 private:
  // Mutable so that the const conversion operator can still lock it.
  mutable std::mutex mtx_;
  // Explicitly initialized: the empty user-provided constructor would
  // otherwise leave this indeterminate, and reading it before the first
  // assignment would be undefined behaviour.
  bool value_{false};
};
class ParallelExecutor { class ParallelExecutor {
public: public:
explicit ParallelExecutor(const std::vector<platform::Place>& places, explicit ParallelExecutor(const std::vector<platform::Place>& places,
...@@ -81,8 +61,9 @@ class ParallelExecutor { ...@@ -81,8 +61,9 @@ class ParallelExecutor {
void BuildNCCLCommunicator() const; void BuildNCCLCommunicator() const;
void RunOp(std::unordered_map<VarHandleBase*, GuardedBool>& pending_vars, void RunOp(
OpHandle* op) const; std::unordered_map<VarHandleBase*, std::atomic<bool>>& pending_vars,
OpHandle* op) const;
void PolishGraphToSupportDataHarzaeds() const; void PolishGraphToSupportDataHarzaeds() const;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册