提交 c18c2f6a 编写于 作者: Y Yu Yang

Sync all computation streams at the end of run

上级 c372ce28
...@@ -482,7 +482,6 @@ void ParallelExecutor::ConstructDependencyGraph( ...@@ -482,7 +482,6 @@ void ParallelExecutor::ConstructDependencyGraph(
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : main_program.Block(0).AllOps()) { for (auto *op : main_program.Block(0).AllOps()) {
bool change_forward = false; bool change_forward = false;
if (!is_forwarding) { if (!is_forwarding) {
// FIXME(yy): Do not hard code like this // FIXME(yy): Do not hard code like this
if (op->OutputArgumentNames().size() == 1 && if (op->OutputArgumentNames().size() == 1 &&
...@@ -573,7 +572,7 @@ void ParallelExecutor::ConstructDependencyGraph( ...@@ -573,7 +572,7 @@ void ParallelExecutor::ConstructDependencyGraph(
Dependency graph has been constructed. However, there are still data Dependency graph has been constructed. However, there are still data
hazards that need to be handled. hazards that need to be handled.
*/ */
PolishGraphToSupportDataHarzaeds(); PolishGraphToSupportDataHazards();
} }
/** /**
...@@ -583,7 +582,7 @@ void ParallelExecutor::ConstructDependencyGraph( ...@@ -583,7 +582,7 @@ void ParallelExecutor::ConstructDependencyGraph(
* *
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/ */
void ParallelExecutor::PolishGraphToSupportDataHarzaeds() const { void ParallelExecutor::PolishGraphToSupportDataHazards() const {
for (auto &place_pair : member_->vars_) { for (auto &place_pair : member_->vars_) {
for (auto &name_pair : place_pair.second) { for (auto &name_pair : place_pair.second) {
if (name_pair.second.size() <= 1) { if (name_pair.second.size() <= 1) {
...@@ -813,6 +812,13 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -813,6 +812,13 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
fetch_ops.clear(); fetch_ops.clear();
*member_->global_scope_->Var(fetched_var_name)->GetMutable<LoDTensorArray>() = *member_->global_scope_->Var(fetched_var_name)->GetMutable<LoDTensorArray>() =
fetched_data->tensors_; fetched_data->tensors_;
// FIXME:
// It could be optimized by using multiple events in an operator.
// Manually sync computation during iter.
for (auto &p : member_->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
} }
void ParallelExecutor::RunOp( void ParallelExecutor::RunOp(
......
...@@ -65,7 +65,7 @@ class ParallelExecutor { ...@@ -65,7 +65,7 @@ class ParallelExecutor {
std::unordered_map<VarHandleBase*, std::atomic<bool>>& pending_vars, std::unordered_map<VarHandleBase*, std::atomic<bool>>& pending_vars,
OpHandle* op) const; OpHandle* op) const;
void PolishGraphToSupportDataHarzaeds() const; void PolishGraphToSupportDataHazards() const;
}; };
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册