提交 ea11a0a8 编写于 作者: Y Yu Yang

Use volitie

上级 515e516e
...@@ -97,6 +97,10 @@ struct ComputationOpHandle : public OpHandle { ...@@ -97,6 +97,10 @@ struct ComputationOpHandle : public OpHandle {
void Run() override { void Run() override {
// Wait other op if necessary // Wait other op if necessary
if (platform::is_gpu_place(place_)) {
int dev_id = boost::get<platform::CUDAPlace>(place_).device;
cudaSetDevice(dev_id);
}
auto *cur_ctx = dev_ctx_[place_]; auto *cur_ctx = dev_ctx_[place_];
for (auto *in : inputs_) { for (auto *in : inputs_) {
if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) { if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
...@@ -637,22 +641,20 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -637,22 +641,20 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size()); auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size());
// Version --> VarHandle // Version --> VarHandle
member_->exception_.reset(); member_->exception_.reset();
std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars; std::unordered_map<VarHandleBase *, volatile bool> pending_vars;
std::unordered_map<OpHandle *, size_t> pending_ops; std::unordered_map<OpHandle *, size_t> pending_ops;
for (auto &place_pair : member_->vars_) { for (auto &place_pair : member_->vars_) {
for (auto &name_pair : place_pair.second) { for (auto &name_pair : place_pair.second) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
pending_vars[&version_pair.second].store( pending_vars[&version_pair.second] =
version_pair.second.generated_op_ == nullptr, version_pair.second.generated_op_ == nullptr;
std::memory_order_relaxed);
} }
} }
} }
for (auto &var : member_->dep_vars_) { for (auto &var : member_->dep_vars_) {
pending_vars[var.get()].store(var->generated_op_ == nullptr, pending_vars[var.get()] = var->generated_op_ == nullptr;
std::memory_order_relaxed);
} }
std::vector<OpHandle *> to_run; std::vector<OpHandle *> to_run;
...@@ -704,7 +706,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -704,7 +706,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
while (!pending_ops.empty()) { while (!pending_ops.empty()) {
VarHandleBase *ready_var = nullptr; VarHandleBase *ready_var = nullptr;
for (auto &pair : pending_vars) { for (auto &pair : pending_vars) {
if (pair.second.load(std::memory_order_acquire)) { if (pair.second) {
ready_var = pair.first; ready_var = pair.first;
} }
} }
...@@ -737,10 +739,10 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -737,10 +739,10 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
} }
void ParallelExecutor::RunOp( void ParallelExecutor::RunOp(
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars, std::unordered_map<VarHandleBase *, volatile bool> &pending_vars,
OpHandle *op) const { OpHandle *op) const {
std::vector<std::atomic<bool> *> *ready_buffer = std::vector<volatile bool *> *ready_buffer =
new std::vector<std::atomic<bool> *>(); new std::vector<volatile bool *>();
for (auto *var : op->outputs_) { for (auto *var : op->outputs_) {
ready_buffer->emplace_back(&pending_vars[var]); ready_buffer->emplace_back(&pending_vars[var]);
} }
...@@ -751,7 +753,7 @@ void ParallelExecutor::RunOp( ...@@ -751,7 +753,7 @@ void ParallelExecutor::RunOp(
op->Run(); op->Run();
VLOG(10) << "Done " << this; VLOG(10) << "Done " << this;
for (auto *ready : *ready_buffer) { for (auto *ready : *ready_buffer) {
ready->store(true, std::memory_order_release); *ready = true;
} }
delete ready_buffer; delete ready_buffer;
} catch (platform::EnforceNotMet ex) { } catch (platform::EnforceNotMet ex) {
......
...@@ -60,9 +60,8 @@ class ParallelExecutor { ...@@ -60,9 +60,8 @@ class ParallelExecutor {
void BuildNCCLCommunicator() const; void BuildNCCLCommunicator() const;
void RunOp( void RunOp(std::unordered_map<VarHandleBase*, volatile bool>& pending_vars,
std::unordered_map<VarHandleBase*, std::atomic<bool>>& pending_vars, OpHandle* op) const;
OpHandle* op) const;
void PolishGraphToSupportDataHarzaeds() const; void PolishGraphToSupportDataHarzaeds() const;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册