You need to sign in or sign up before continuing.
提交 c7beac14 编写于 作者: Y Yu Yang

Add dummy var

上级 5fa535b7
......@@ -53,6 +53,10 @@ struct VarHandle : public VarHandleBase {
platform::Place place_;
};
struct DummyVarHandle : public VarHandleBase {
std::string DebugString() const override { return "dummy"; }
};
struct DependencyVarHandle : public VarHandleBase {
std::string DebugString() const override { return "Dependency Variable"; }
};
......@@ -643,6 +647,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
member_->exception_.reset();
std::unordered_map<VarHandleBase *, GuardedBool> pending_vars;
std::unordered_map<OpHandle *, size_t> pending_ops;
std::vector<DummyVarHandle> dummy_vars;
for (auto &place_pair : member_->vars_) {
for (auto &name_pair : place_pair.second) {
......@@ -696,17 +701,21 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
var->pending_ops_.emplace(op);
op->inputs_.emplace_back(var);
}
dummy_vars.emplace_back();
auto *var = &dummy_vars.back();
op->outputs_.emplace_back(var);
var->generated_op_ = op;
pending_vars[var] = false;
pending_ops.insert({op, op->inputs_.size()});
}
std::vector<std::future<void>> op_threads;
op_threads.reserve(pending_ops.size() + to_run.size());
for (auto *op : to_run) {
op_threads.emplace_back(RunOp(pending_vars, op));
RunOp(pending_vars, op);
}
while (!pending_ops.empty()) {
while (!pending_vars.empty()) {
VarHandleBase *ready_var = nullptr;
for (auto &pair : pending_vars) {
if (pair.second) {
......@@ -715,12 +724,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
}
if (ready_var == nullptr) {
// FIXME use conditional var instead of busy wait.
if (member_->exception_) {
throw * member_->exception_;
}
VLOG(3) << pending_vars.size();
continue;
}
pending_vars.erase(ready_var);
......@@ -734,20 +740,16 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
}
for (auto *op : to_run) {
pending_ops.erase(op);
op_threads.emplace_back(RunOp(pending_vars, op));
RunOp(pending_vars, op);
}
}
for (auto &t : op_threads) {
t.get(); // Join all workers
}
fetch_ops.clear();
*member_->global_scope_->Var(fetched_var_name)->GetMutable<LoDTensorArray>() =
fetched_data->tensors_;
}
std::future<void> ParallelExecutor::RunOp(
void ParallelExecutor::RunOp(
std::unordered_map<VarHandleBase *, GuardedBool> &pending_vars,
OpHandle *op) const {
std::vector<GuardedBool *> *ready_buffer = new std::vector<GuardedBool *>();
......@@ -768,7 +770,7 @@ std::future<void> ParallelExecutor::RunOp(
LOG(FATAL) << "Unknown exception catched";
}
};
return member_->pool_.enqueue(op_run);
member_->pool_.enqueue(op_run);
}
} // namespace framework
} // namespace paddle
......@@ -81,9 +81,8 @@ class ParallelExecutor {
void BuildNCCLCommunicator() const;
std::future<void> RunOp(
std::unordered_map<VarHandleBase*, GuardedBool>& pending_vars,
OpHandle* op) const;
void RunOp(std::unordered_map<VarHandleBase*, GuardedBool>& pending_vars,
OpHandle* op) const;
void PolishGraphToSupportDataHarzaeds() const;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册