提交 c7beac14 编写于 作者: Y Yu Yang

Add dummy var

上级 5fa535b7
...@@ -53,6 +53,10 @@ struct VarHandle : public VarHandleBase { ...@@ -53,6 +53,10 @@ struct VarHandle : public VarHandleBase {
platform::Place place_; platform::Place place_;
}; };
// Placeholder variable attached as an output to ops that otherwise produce
// none, so the scheduler's pending_vars bookkeeping can observe the op's
// completion. Carries no data; DebugString labels it in graph dumps.
struct DummyVarHandle : public VarHandleBase {
std::string DebugString() const override { return "dummy"; }
};
struct DependencyVarHandle : public VarHandleBase { struct DependencyVarHandle : public VarHandleBase {
std::string DebugString() const override { return "Dependency Variable"; } std::string DebugString() const override { return "Dependency Variable"; }
}; };
...@@ -643,6 +647,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -643,6 +647,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
member_->exception_.reset(); member_->exception_.reset();
std::unordered_map<VarHandleBase *, GuardedBool> pending_vars; std::unordered_map<VarHandleBase *, GuardedBool> pending_vars;
std::unordered_map<OpHandle *, size_t> pending_ops; std::unordered_map<OpHandle *, size_t> pending_ops;
std::vector<DummyVarHandle> dummy_vars;
for (auto &place_pair : member_->vars_) { for (auto &place_pair : member_->vars_) {
for (auto &name_pair : place_pair.second) { for (auto &name_pair : place_pair.second) {
...@@ -696,17 +701,21 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -696,17 +701,21 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
var->pending_ops_.emplace(op); var->pending_ops_.emplace(op);
op->inputs_.emplace_back(var); op->inputs_.emplace_back(var);
} }
dummy_vars.emplace_back();
auto *var = &dummy_vars.back();
op->outputs_.emplace_back(var);
var->generated_op_ = op;
pending_vars[var] = false;
pending_ops.insert({op, op->inputs_.size()}); pending_ops.insert({op, op->inputs_.size()});
} }
std::vector<std::future<void>> op_threads;
op_threads.reserve(pending_ops.size() + to_run.size());
for (auto *op : to_run) { for (auto *op : to_run) {
op_threads.emplace_back(RunOp(pending_vars, op)); RunOp(pending_vars, op);
} }
while (!pending_ops.empty()) { while (!pending_vars.empty()) {
VarHandleBase *ready_var = nullptr; VarHandleBase *ready_var = nullptr;
for (auto &pair : pending_vars) { for (auto &pair : pending_vars) {
if (pair.second) { if (pair.second) {
...@@ -715,12 +724,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -715,12 +724,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
} }
if (ready_var == nullptr) { if (ready_var == nullptr) {
// FIXME use conditional var instead of busy wait. // FIXME use conditional var instead of busy wait.
if (member_->exception_) { if (member_->exception_) {
throw * member_->exception_; throw * member_->exception_;
} }
VLOG(3) << pending_vars.size();
continue; continue;
} }
pending_vars.erase(ready_var); pending_vars.erase(ready_var);
...@@ -734,20 +740,16 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -734,20 +740,16 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
} }
for (auto *op : to_run) { for (auto *op : to_run) {
pending_ops.erase(op); pending_ops.erase(op);
op_threads.emplace_back(RunOp(pending_vars, op)); RunOp(pending_vars, op);
} }
} }
for (auto &t : op_threads) {
t.get(); // Join all workers
}
fetch_ops.clear(); fetch_ops.clear();
*member_->global_scope_->Var(fetched_var_name)->GetMutable<LoDTensorArray>() = *member_->global_scope_->Var(fetched_var_name)->GetMutable<LoDTensorArray>() =
fetched_data->tensors_; fetched_data->tensors_;
} }
std::future<void> ParallelExecutor::RunOp( void ParallelExecutor::RunOp(
std::unordered_map<VarHandleBase *, GuardedBool> &pending_vars, std::unordered_map<VarHandleBase *, GuardedBool> &pending_vars,
OpHandle *op) const { OpHandle *op) const {
std::vector<GuardedBool *> *ready_buffer = new std::vector<GuardedBool *>(); std::vector<GuardedBool *> *ready_buffer = new std::vector<GuardedBool *>();
...@@ -768,7 +770,7 @@ std::future<void> ParallelExecutor::RunOp( ...@@ -768,7 +770,7 @@ std::future<void> ParallelExecutor::RunOp(
LOG(FATAL) << "Unknown exception catched"; LOG(FATAL) << "Unknown exception catched";
} }
}; };
return member_->pool_.enqueue(op_run); member_->pool_.enqueue(op_run);
} }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -81,8 +81,7 @@ class ParallelExecutor { ...@@ -81,8 +81,7 @@ class ParallelExecutor {
void BuildNCCLCommunicator() const; void BuildNCCLCommunicator() const;
std::future<void> RunOp( void RunOp(std::unordered_map<VarHandleBase*, GuardedBool>& pending_vars,
std::unordered_map<VarHandleBase*, GuardedBool>& pending_vars,
OpHandle* op) const; OpHandle* op) const;
void PolishGraphToSupportDataHarzaeds() const; void PolishGraphToSupportDataHarzaeds() const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册