Commit ea11a0a8 authored by Yu Yang

Use volatile

Parent 515e516e
@@ -97,6 +97,10 @@ struct ComputationOpHandle : public OpHandle {
   void Run() override {
     // Wait other op if necessary
+    if (platform::is_gpu_place(place_)) {
+      int dev_id = boost::get<platform::CUDAPlace>(place_).device;
+      cudaSetDevice(dev_id);
+    }
     auto *cur_ctx = dev_ctx_[place_];
     for (auto *in : inputs_) {
       if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
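Note: the four added lines pin the worker thread to the op's GPU before it synchronizes with upstream device contexts; the CUDA runtime tracks the current device per thread, so every later kernel launch or stream operation issued from Run() targets the right card. A minimal standalone sketch of the same pattern, using only the plain CUDA runtime API (BindThreadToDevice is a hypothetical helper; in the diff the id comes from boost::get<platform::CUDAPlace>(place_).device):

    #include <cstdio>
    #include <cuda_runtime.h>

    // Pin the calling thread to `device_id`: all subsequent CUDA runtime
    // calls from this thread (kernel launches, stream/event operations)
    // are issued against that device.
    void BindThreadToDevice(int device_id) {
      cudaError_t err = cudaSetDevice(device_id);
      if (err != cudaSuccess) {
        std::fprintf(stderr, "cudaSetDevice(%d) failed: %s\n", device_id,
                     cudaGetErrorString(err));
      }
    }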
@@ -637,22 +641,20 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
   auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size());
   // Version --> VarHandle
   member_->exception_.reset();
-  std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
+  std::unordered_map<VarHandleBase *, volatile bool> pending_vars;
   std::unordered_map<OpHandle *, size_t> pending_ops;

   for (auto &place_pair : member_->vars_) {
     for (auto &name_pair : place_pair.second) {
       for (auto &version_pair : name_pair.second) {
-        pending_vars[&version_pair.second].store(
-            version_pair.second.generated_op_ == nullptr,
-            std::memory_order_relaxed);
+        pending_vars[&version_pair.second] =
+            version_pair.second.generated_op_ == nullptr;
       }
     }
   }

   for (auto &var : member_->dep_vars_) {
-    pending_vars[var.get()].store(var->generated_op_ == nullptr,
-                                  std::memory_order_relaxed);
+    pending_vars[var.get()] = var->generated_op_ == nullptr;
   }

   std::vector<OpHandle *> to_run;
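Note: the ready-flag map switches from std::atomic<bool> to plain volatile bool, with a flag starting true only for variables that no op produces. In ISO C++, volatile keeps the compiler from caching the value across reads but, unlike the relaxed/acquire/release atomics it replaces here, gives no inter-thread ordering guarantee. A self-contained sketch of the initialization logic, with simplified stand-in types (the struct field mirrors the diff; InitPendingVars is my naming, not Paddle's):

    #include <unordered_map>
    #include <vector>

    struct OpHandle;  // opaque here
    struct VarHandleBase {
      OpHandle *generated_op_ = nullptr;  // producer op; nullptr for graph inputs
    };

    // A variable is ready from the start iff no op has to produce it.
    void InitPendingVars(
        const std::vector<VarHandleBase *> &vars,
        std::unordered_map<VarHandleBase *, volatile bool> *pending) {
      for (auto *v : vars) {
        (*pending)[v] = (v->generated_op_ == nullptr);
      }
    }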
@@ -704,7 +706,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
   while (!pending_ops.empty()) {
     VarHandleBase *ready_var = nullptr;
     for (auto &pair : pending_vars) {
-      if (pair.second.load(std::memory_order_acquire)) {
+      if (pair.second) {
         ready_var = pair.first;
       }
     }
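Note: with volatile flags the scheduler simply re-reads the map on every pass of this polling loop; the volatile qualifier is what stops the compiler from hoisting the pair.second read out of the loop. A compact sketch of that consumer side, reusing the stand-in types above (the real loop also rechecks pending_ops and rethrows exceptions recorded by workers):

    // Spin until some variable's ready flag is set, then return it.
    VarHandleBase *WaitForReadyVar(
        const std::unordered_map<VarHandleBase *, volatile bool> &pending_vars) {
      for (;;) {
        for (const auto &pair : pending_vars) {
          if (pair.second) {  // volatile load, re-evaluated on each scan
            return pair.first;
          }
        }
        // Nothing ready yet; keep polling.
      }
    }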
@@ -737,10 +739,10 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
 }

 void ParallelExecutor::RunOp(
-    std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
+    std::unordered_map<VarHandleBase *, volatile bool> &pending_vars,
     OpHandle *op) const {
-  std::vector<std::atomic<bool> *> *ready_buffer =
-      new std::vector<std::atomic<bool> *>();
+  std::vector<volatile bool *> *ready_buffer =
+      new std::vector<volatile bool *>();
   for (auto *var : op->outputs_) {
     ready_buffer->emplace_back(&pending_vars[var]);
   }
@@ -751,7 +753,7 @@ void ParallelExecutor::RunOp(
     op->Run();
     VLOG(10) << "Done " << this;
     for (auto *ready : *ready_buffer) {
-      ready->store(true, std::memory_order_release);
+      *ready = true;
     }
     delete ready_buffer;
   } catch (platform::EnforceNotMet ex) {
......
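Note: on the producer side, RunOp collects raw pointers to its outputs' flags before the op runs and flips them only after op->Run() returns, which is what the polling loop above observes. A sketch of that publish step (MarkOutputsReady is a hypothetical name for the loop in the diff):

    // After the op finishes, publish every output as ready.
    void MarkOutputsReady(const std::vector<volatile bool *> &ready_flags) {
      for (volatile bool *flag : ready_flags) {
        *flag = true;  // with std::atomic this was store(true, std::memory_order_release)
      }
    }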
@@ -60,9 +60,8 @@ class ParallelExecutor {
   void BuildNCCLCommunicator() const;

-  void RunOp(
-      std::unordered_map<VarHandleBase*, std::atomic<bool>>& pending_vars,
-      OpHandle* op) const;
+  void RunOp(std::unordered_map<VarHandleBase*, volatile bool>& pending_vars,
+             OpHandle* op) const;

   void PolishGraphToSupportDataHarzaeds() const;
 };
......