From f051c768e59da9ff993891df7e2c7e20ecbd97da Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 9 Apr 2018 13:49:54 +0800 Subject: [PATCH] Simplify DataStructure in SSAGraph --- .../details/multi_devices_graph_builder.cc | 15 +++++----- paddle/fluid/framework/details/ssa_graph.h | 6 +++- .../framework/details/ssa_graph_builder.cc | 30 ++++++++++--------- .../details/threaded_ssa_graph_executor.cc | 4 +-- 4 files changed, 31 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 128a5344fbb..01f5da96316 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -147,15 +147,16 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( if (vars.empty()) { // This device has no data. continue. continue; } - auto *prev_grad = &vars[vars.size() - 1]; - op_handle->AddInput(prev_grad); + auto &prev_grad = vars[vars.size() - 1]; + op_handle->AddInput(prev_grad.get()); - auto &var = vars[vars.size()]; - var.place_ = p; - var.name_ = og; - var.version_ = vars.size() - 1; + vars.emplace_back(new VarHandle); + auto &var = vars.back(); + var->place_ = p; + var->name_ = og; + var->version_ = vars.size() - 1; - op_handle->AddOutput(&var); + op_handle->AddOutput(var.get()); } #else PADDLE_ENFORCE("Not implemented"); diff --git a/paddle/fluid/framework/details/ssa_graph.h b/paddle/fluid/framework/details/ssa_graph.h index ac3e2d86993..72684e7f97f 100644 --- a/paddle/fluid/framework/details/ssa_graph.h +++ b/paddle/fluid/framework/details/ssa_graph.h @@ -16,6 +16,8 @@ #include #include +#include + #include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/var_handle.h" @@ -24,7 +26,9 @@ namespace framework { namespace details { struct SSAGraph { - std::vector>> vars_; + std::vector< + std::unordered_map>>> + vars_; // aux variables to represent dependency. Useful to resolve data hazard. std::unordered_set> dep_vars_; std::vector> ops_; diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 0a4febd22f3..be5fb757758 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { auto it_old = name_pair.second.rbegin(); ++it_old; for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { - auto *write_op = it_new->second.generated_op_; - auto &read_ops = it_old->second.pending_ops_; + auto *write_op = (*it_new)->generated_op_; + auto &read_ops = (*it_old)->pending_ops_; for (auto *read_op : read_ops) { // Manually add a dependency var from read_op to write_op; @@ -54,14 +54,15 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( auto &var_holder = var_holders[each_var_name]; VarHandle *var = nullptr; if (var_holder.empty()) { + var_holder.emplace_back(new VarHandle); auto &init_var = var_holder[0]; - init_var.place_ = place; - init_var.name_ = each_var_name; - init_var.generated_op_ = nullptr; - init_var.version_ = 0; - var = &init_var; + init_var->place_ = place; + init_var->name_ = each_var_name; + init_var->generated_op_ = nullptr; + init_var->version_ = 0; + var = init_var.get(); } else { - var = &var_holder.rbegin()->second; + var = var_holder.rbegin()->get(); } return var; } @@ -72,11 +73,12 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, size_t place_offset) { auto &vars = graph->vars_[place_offset][each_var_name]; size_t version = vars.size(); - auto &var = vars[version]; - var.version_ = version; - var.name_ = each_var_name; - var.place_ = place; - op_handle->AddOutput(&var); + vars.emplace_back(new VarHandle()); + auto &var = vars.back(); + var->version_ = version; + var->name_ = each_var_name; + var->place_ = place; + op_handle->AddOutput(var.get()); } template @@ -84,7 +86,7 @@ void IterAllVar(const SSAGraph &graph, Callback callback) { for (auto &each : graph.vars_) { for (auto &pair1 : each) { for (auto &pair2 : pair1.second) { - callback(pair2.second); + callback(*pair2); } } } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 596e5731868..62af4c1d79d 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -69,7 +69,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( for (auto &var_map : graph_->vars_) { for (auto &name_pair : var_map) { for (auto &version_pair : name_pair.second) { - InsertPendingVar(version_pair.second); + InsertPendingVar(*version_pair); } } } @@ -95,7 +95,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( for (auto &var_map : graph_->vars_) { auto it = var_map.find(fetch_var_name); if (it != var_map.end()) { - fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second); + fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); } } } -- GitLab