未验证 提交 161344bf 作者: Yu Yang 提交者: GitHub

Merge pull request #9774 from reyoung/feature/simplify_data_structures

Simplify DataStructure in SSAGraph
...@@ -59,7 +59,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -59,7 +59,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto graph = new SSAGraph(); auto graph = new SSAGraph();
SSAGraph &result = *graph; SSAGraph &result = *graph;
std::unordered_set<std::string> og_has_been_broadcast; std::unordered_set<std::string> og_has_been_broadcast;
result.vars_.resize(places_.size());
// We cannot invoke resize. It is a bug of GCC 4.8
result.vars_ = std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size());
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
...@@ -147,15 +151,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -147,15 +151,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
if (vars.empty()) { // This device has no data. continue. if (vars.empty()) { // This device has no data. continue.
continue; continue;
} }
auto *prev_grad = &vars[vars.size() - 1]; auto &prev_grad = vars[vars.size() - 1];
op_handle->AddInput(prev_grad); op_handle->AddInput(prev_grad.get());
auto &var = vars[vars.size()]; vars.emplace_back(new VarHandle);
var.place_ = p; auto &var = vars.back();
var.name_ = og; var->place_ = p;
var.version_ = vars.size() - 1; var->name_ = og;
var->version_ = vars.size() - 1;
op_handle->AddOutput(&var); op_handle->AddOutput(var.get());
} }
#else #else
PADDLE_ENFORCE("Not implemented"); PADDLE_ENFORCE("Not implemented");
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include <map> #include <map>
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/framework/details/var_handle.h"
...@@ -24,7 +26,9 @@ namespace framework { ...@@ -24,7 +26,9 @@ namespace framework {
namespace details { namespace details {
struct SSAGraph { struct SSAGraph {
std::vector<std::unordered_map<std::string, std::map<int, VarHandle>>> vars_; std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
vars_;
// aux variables to represent dependency. Useful to resolve data hazard. // aux variables to represent dependency. Useful to resolve data hazard.
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_; std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
std::vector<std::unique_ptr<OpHandleBase>> ops_; std::vector<std::unique_ptr<OpHandleBase>> ops_;
......
...@@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { ...@@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
auto it_old = name_pair.second.rbegin(); auto it_old = name_pair.second.rbegin();
++it_old; ++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
auto *write_op = it_new->second.generated_op_; auto *write_op = (*it_new)->generated_op_;
auto &read_ops = it_old->second.pending_ops_; auto &read_ops = (*it_old)->pending_ops_;
for (auto *read_op : read_ops) { for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op; // Manually add a dependency var from read_op to write_op;
...@@ -54,14 +54,15 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( ...@@ -54,14 +54,15 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
auto &var_holder = var_holders[each_var_name]; auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr; VarHandle *var = nullptr;
if (var_holder.empty()) { if (var_holder.empty()) {
var_holder.emplace_back(new VarHandle);
auto &init_var = var_holder[0]; auto &init_var = var_holder[0];
init_var.place_ = place; init_var->place_ = place;
init_var.name_ = each_var_name; init_var->name_ = each_var_name;
init_var.generated_op_ = nullptr; init_var->generated_op_ = nullptr;
init_var.version_ = 0; init_var->version_ = 0;
var = &init_var; var = init_var.get();
} else { } else {
var = &var_holder.rbegin()->second; var = var_holder.rbegin()->get();
} }
return var; return var;
} }
...@@ -72,11 +73,12 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, ...@@ -72,11 +73,12 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
size_t place_offset) { size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name]; auto &vars = graph->vars_[place_offset][each_var_name];
size_t version = vars.size(); size_t version = vars.size();
auto &var = vars[version]; vars.emplace_back(new VarHandle());
var.version_ = version; auto &var = vars.back();
var.name_ = each_var_name; var->version_ = version;
var.place_ = place; var->name_ = each_var_name;
op_handle->AddOutput(&var); var->place_ = place;
op_handle->AddOutput(var.get());
} }
template <typename Callback> template <typename Callback>
...@@ -84,7 +86,7 @@ void IterAllVar(const SSAGraph &graph, Callback callback) { ...@@ -84,7 +86,7 @@ void IterAllVar(const SSAGraph &graph, Callback callback) {
for (auto &each : graph.vars_) { for (auto &each : graph.vars_) {
for (auto &pair1 : each) { for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) { for (auto &pair2 : pair1.second) {
callback(pair2.second); callback(*pair2);
} }
} }
} }
......
...@@ -69,7 +69,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -69,7 +69,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->vars_) { for (auto &var_map : graph_->vars_) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
InsertPendingVar(version_pair.second); InsertPendingVar(*version_pair);
} }
} }
} }
...@@ -95,7 +95,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -95,7 +95,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->vars_) { for (auto &var_map : graph_->vars_) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second); fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册