Commit 2eeaa8d5 authored by Xin Pan

Graph in ParallelExecutor Builder

Parent 7781297c
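The diff below replaces SSAGraph's concrete members (vars_, ops_, dep_vars_) with type-erased entries in the new Graph::attrs map and recovers them with boost::any_cast. As a reading aid, here is a minimal standalone sketch of that storage pattern; it is not part of the commit and uses a simplified element type (std::string) instead of the real GraphVars/GraphDepVars/GraphOps containers defined further down in ssa_graph_builder.h.

// Minimal sketch of the type-erased attribute storage the new Graph relies on.
// Element types are simplified; the real code stores GraphVars/GraphDepVars/GraphOps.
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include <boost/any.hpp>

struct ToyGraph {
  std::map<std::string, boost::any> attrs;  // one boost::any per key
};

int main() {
  ToyGraph g;
  // Store a raw pointer to a heap-allocated container, as Build() does.
  g.attrs["ops"] = new std::vector<std::string>();

  // Any later consumer recovers the same container by casting the any back.
  auto *ops = boost::any_cast<std::vector<std::string> *>(g.attrs["ops"]);
  ops->emplace_back("all_reduce");
  ops->emplace_back("scale_loss_grad");

  std::cout << "ops: " << ops->size() << std::endl;  // prints "ops: 2"
  delete ops;  // ownership is manual in this sketch
  return 0;
}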
@@ -25,6 +25,7 @@
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/rpc_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
+#include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/scope.h"
@@ -66,11 +67,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
   }
 }
-void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
-                                                const OpDesc &op,
+void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, const OpDesc &op,
                                                 size_t place_id) const {
   auto p = places_[place_id];
-  auto *op_handle = result->ops_.back().get();
+  auto *op_handle =
+      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
   op_handle->SetDeviceContext(p,
                               platform::DeviceContextPool::Instance().Get(p));
@@ -169,18 +170,21 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
+  std::unique_ptr<Graph> graph(new Graph);
   for (auto *var : program.Block(0).AllVars()) {
     all_vars_.emplace(var->Name(), var);
   }
-  auto graph = new SSAGraph();
-  SSAGraph &result = *graph;
+  Graph &result = *graph;
   std::unordered_set<std::string> og_has_been_broadcast;
   // We cannot invoke resize. It is a bug of GCC 4.8
-  result.vars_ = std::vector<
+  result.attrs["vars"] = new std::vector<
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
       places_.size());
+  result.attrs["dep_vars"] =
+      new std::unordered_set<std::unique_ptr<VarHandleBase>>();
+  result.attrs["ops"] = new std::vector<std::unique_ptr<OpHandleBase>>();
   // find send/recv vars so that we can place the distributed training
   // realted op in the place 0
@@ -303,7 +307,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   */
   AddOutputToLeafOps(&result);
-  return std::unique_ptr<SSAGraph>(graph);
+  std::unique_ptr<SSAGraph> ssa_graph(new SSAGraph);
+  ssa_graph->vars_ =
+      std::move(*boost::any_cast<GraphVars *>(graph->attrs["vars"]));
+  ssa_graph->ops_ =
+      std::move(*boost::any_cast<GraphOps *>(graph->attrs["ops"]));
+  ssa_graph->dep_vars_ =
+      std::move(*boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"]));
+  return std::move(ssa_graph);
 }
 bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
@@ -327,7 +339,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
 #endif
 }
-void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
                                                 const std::string &p_name,
                                                 size_t src_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
@@ -336,42 +348,50 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
   auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
 #endif
-  result->ops_.emplace_back(op_handle);
-  auto *in = result->vars_.at(src_dev_id).at(p_name).back().get();
+  boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(op_handle);
+  auto *in = boost::any_cast<GraphVars *>(result->attrs["vars"])
+                 ->at(src_dev_id)
+                 .at(p_name)
+                 .back()
+                 .get();
   op_handle->AddInput(in);
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->vars_.at(i).at(p_name);
+    auto &vars =
+        boost::any_cast<GraphVars *>(result->attrs["vars"])->at(i).at(p_name);
     auto *out_var = new VarHandle(vars.size(), i, p_name, p);
     vars.emplace_back(out_var);
     op_handle->AddOutput(out_var);
   }
 }
-void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
                                                     const OpDesc &op,
                                                     int dev_id) const {
-  result->ops_.emplace_back(
-      new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(
+          new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
   CreateOpHandleIOs(result, op, dev_id);
 }
-void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
                                                 const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
-  result->ops_.emplace_back(
-      new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(new AllReduceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle = result->ops_.back().get();
+  auto *op_handle =
+      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->vars_[i][og];
+    auto &vars = (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
@@ -383,19 +403,23 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
 }
 void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
-    SSAGraph *result, const std::vector<std::string> &datas) const {
+    Graph *result, const std::vector<std::string> &datas) const {
 #ifdef PADDLE_WITH_CUDA
-  result->ops_.emplace_back(
-      new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(
+          new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(new DataBalanceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle = result->ops_.back().get();
+  auto *op_handle =
+      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
     for (const std::string &d_name : datas) {
-      auto &vars = result->vars_[i][d_name];
+      auto &vars =
+          (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][d_name];
       PADDLE_ENFORCE(!vars.empty());
       op_handle->AddInput(vars.back().get());
       auto var = new VarHandle(vars.size(), i, d_name, p);
@@ -441,7 +465,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
   return got == var_name_on_devices_.end() ? -1 : got->second;
 }
-void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
+void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
   for (size_t i = 0; i < places_.size(); ++i) {
 // Insert ScaleCost OpHandle
 #ifdef PADDLE_WITH_CUDA
@@ -456,7 +480,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
     auto *op_handle =
         new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i],
                                   places_[i], communication_dev_ctx);
-    result->ops_.emplace_back(op_handle);
+    boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(op_handle);
     // FIXME: Currently ScaleLossGradOp only use device_count as scale
     // factor. So it does not depend on any other operators.
@@ -469,37 +493,41 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
   }
 }
-void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
                                                      const OpDesc &op,
                                                      size_t num_places) const {
   for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
-    result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
+    boost::any_cast<GraphOps *>(result->attrs["ops"])
+        ->emplace_back(new ComputationOpHandle(op, s, p));
     CreateOpHandleIOs(result, op, scope_idx);
   }
 }
-VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
+VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
                                                    const std::string &og,
                                                    int dst_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
-  result->ops_.emplace_back(
-      new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  result->ops_.emplace_back(new ReduceOpHandle(local_scopes_, places_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(new ReduceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle = result->ops_.back().get();
+  auto *op_handle =
+      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->vars_[i][og];
+    auto &vars = (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
   }
-  auto &vars = result->vars_[dst_dev_id][og];
+  auto &vars =
+      (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[dst_dev_id][og];
   auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
@@ -508,19 +536,20 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
 // Find the first occurence of `prev_op_name` and make current `op` depend
 // on it.
-void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
+void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
                                         const std::string &prev_op_name) const {
-  for (auto &prev_op : result->ops_) {
+  for (auto &prev_op : (*boost::any_cast<GraphOps *>(result->attrs["ops"]))) {
     if (prev_op->Name() == prev_op_name) {
       auto *dep_var = new DummyVarHandle();
       prev_op->AddOutput(dep_var);
-      result->dep_vars_.emplace(dep_var);
+      boost::any_cast<GraphDepVars *>(result->attrs["dep_vars"])
+          ->emplace(dep_var);
       op->AddInput(dep_var);
     }
   }
 }
-void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
                                                 const OpDesc &op) const {
   int op_dev_id = -1;
   if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") {
@@ -550,12 +579,14 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
   CreateComputationalOp(result, op, op_dev_id);
   if (op.Type() == "concat") {
-    ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
+    ConnectOp(result,
+              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+              "fetch_barrier");
   }
 }
 // Create RPC related op handles that connects its in ops and out ops.
-void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result,
                                           const OpDesc &op) const {
   int op_dev_id = -1;
   if (op.Type() == "send") {
@@ -584,15 +615,22 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
   PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
                  op.Type());
-  result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
-                                            op.Type(), places_[op_dev_id]));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], op.Type(),
+                                     places_[op_dev_id]));
   if (op.Type() == "send_barrier") {
-    ConnectOp(result, result->ops_.back().get(), "send");
+    ConnectOp(result,
+              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+              "send");
   } else if (op.Type() == "recv") {
-    ConnectOp(result, result->ops_.back().get(), "send_barrier");
+    ConnectOp(result,
+              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+              "send_barrier");
   } else if (op.Type() == "fetch_barrier") {
-    ConnectOp(result, result->ops_.back().get(), "recv");
+    ConnectOp(result,
+              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+              "recv");
   } else if (op.Type() == "send") {
     // do nothing
   } else {
...
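Every container access in the builder above now spells out a boost::any_cast against result->attrs. Purely as an illustration (these helpers are hypothetical and not part of the commit), the repeated casts could be wrapped once, assuming the attrs were installed by Build() exactly as shown above:

// Hypothetical accessors, NOT in this commit. Assumes the typedefs and types
// from ssa_graph_builder.h and ir/graph.h (both shown in this diff) are in scope,
// and that attrs["ops"] / attrs["vars"] / attrs["dep_vars"] were already set.
namespace paddle {
namespace framework {
namespace details {

inline GraphOps &GetGraphOps(Graph *graph) {
  return *boost::any_cast<GraphOps *>(graph->attrs["ops"]);
}

inline GraphVars &GetGraphVars(Graph *graph) {
  return *boost::any_cast<GraphVars *>(graph->attrs["vars"]);
}

inline GraphDepVars &GetGraphDepVars(Graph *graph) {
  return *boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"]);
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

// With such helpers, the recurring pattern
//   boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
// would shrink to
//   GetGraphOps(result).back().get();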
@@ -19,6 +19,7 @@
 #include "paddle/fluid/framework/details/build_strategy.h"
 #include "paddle/fluid/framework/details/ssa_graph_builder.h"
+#include "paddle/fluid/framework/ir/graph.h"
 namespace paddle {
 namespace platform {
@@ -50,7 +51,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   int GetVarDeviceID(const std::string &varname) const override;
  private:
-  void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
+  void CreateOpHandleIOs(Graph *result, const OpDesc &op,
                          size_t device_id) const;
  private:
@@ -65,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   bool IsScaleLossOp(const OpDesc &op) const;
-  void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
-  void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateRPCOp(Graph *result, const OpDesc &op) const;
+  void CreateDistTrainOp(Graph *result, const OpDesc &op) const;
   /**
    * Is this operator as the end-point operator before/after send operator.
@@ -81,17 +82,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   std::vector<std::string> FindDistTrainRecvVars(
       const ProgramDesc &program) const;
-  void ConnectOp(SSAGraph *result, OpHandleBase *op,
+  void ConnectOp(Graph *result, OpHandleBase *op,
                  const std::string &prev_op_name) const;
-  void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
+  void CreateComputationalOps(Graph *result, const OpDesc &op,
                               size_t num_places) const;
-  void CreateScaleLossGradOp(SSAGraph *result) const;
-  VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og,
+  void CreateScaleLossGradOp(Graph *result) const;
+  VarHandle *CreateReduceOp(Graph *result, const std::string &og,
                             int dst_dev_id) const;
-  void CreateComputationalOp(SSAGraph *result, const OpDesc &op,
-                             int dev_id) const;
+  void CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const;
   bool IsParameterGradientOnce(
       const std::string &og,
@@ -99,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   int GetOpDeviceID(const OpDesc &op) const;
-  void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
-  void InsertDataBalanceOp(SSAGraph *result,
+  void InsertAllReduceOp(Graph *result, const std::string &og) const;
+  void InsertDataBalanceOp(Graph *result,
                            const std::vector<std::string> &datas) const;
-  void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
+  void CreateBroadcastOp(Graph *result, const std::string &p_name,
                          size_t src_dev_id) const;
   bool IsSparseGradient(const std::string &og) const;
...
@@ -17,8 +17,8 @@
 namespace paddle {
 namespace framework {
 namespace details {
-void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
-  for (auto &var_map : graph->vars_) {
+void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
+  for (auto &var_map : *boost::any_cast<GraphVars *>(graph->attrs["vars"])) {
     for (auto &name_pair : var_map) {
       if (name_pair.second.size() <= 1) {
         continue;
@@ -40,7 +40,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
         auto *dep_var = new DummyVarHandle();
         read_op->AddOutput(dep_var);
         write_op->AddInput(dep_var);
-        graph->dep_vars_.emplace(dep_var);
+        boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"])
+            ->emplace(dep_var);
       }
     }
   }
@@ -48,9 +49,10 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
 }
 VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
-    SSAGraph *graph, const std::string &each_var_name,
+    Graph *graph, const std::string &each_var_name,
     const platform::Place &place, size_t place_offset) {
-  auto &var_holders = graph->vars_[place_offset];
+  auto &var_holders =
+      (*boost::any_cast<GraphVars *>(graph->attrs["vars"]))[place_offset];
   auto &var_holder = var_holders[each_var_name];
   VarHandle *var = nullptr;
   if (var_holder.empty()) {
@@ -62,24 +64,29 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
   return var;
 }
-void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
+void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
                                      const std::string &each_var_name,
                                      const platform::Place &place,
                                      size_t place_offset) {
-  auto &vars = graph->vars_[place_offset][each_var_name];
+  auto &vars =
+      (*boost::any_cast<GraphVars *>(graph->attrs["vars"]))[place_offset]
+                                                           [each_var_name];
   size_t version = vars.size();
   auto var = new VarHandle(version, place_offset, each_var_name, place);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
 }
-void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
-  for (auto &op : graph->ops_) {
+void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
+  GraphOps &all_ops = *boost::any_cast<GraphOps *>(graph->attrs["ops"]);
+  for (auto &op : all_ops) {
     if (!op->Outputs().empty()) {
       continue;
    }
     auto *dummy_leaf = new DummyVarHandle();
-    graph->dep_vars_.emplace(dummy_leaf);
+    boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"])
+        ->emplace(dummy_leaf);
     op->AddOutput(dummy_leaf);
   }
 }
...
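CreateOrGetLatestVarHandle and CreateOpOutput above implement the SSA bookkeeping: each (place, name) pair owns a vector of VarHandles, every new write appends a handle whose version equals the vector's current size, and readers bind to the latest handle. A toy, dependency-free sketch of that versioning idea (simplified types, not part of the commit):

// Toy illustration of the builder's SSA versioning scheme: writes append a new
// versioned handle, reads see the most recent one.
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct ToyVarHandle {
  size_t version;
  std::string name;
};

int main() {
  std::map<std::string, std::vector<ToyVarHandle>> vars;

  auto write = [&](const std::string &name) -> ToyVarHandle & {
    auto &holder = vars[name];
    holder.push_back({holder.size(), name});  // version == size before the push
    return holder.back();
  };
  auto read_latest = [&](const std::string &name) -> ToyVarHandle & {
    return vars[name].back();
  };

  write("w");  // creates w@0
  write("w");  // a second writer creates w@1
  auto &latest = read_latest("w");
  std::cout << latest.name << "@" << latest.version << std::endl;  // prints w@1
  return 0;
}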
@@ -16,15 +16,24 @@
 #include <memory>
 #include <string>
+#include <vector>
 #include "paddle/fluid/framework/details/ssa_graph.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/platform/place.h"
+#include "paddle/fluid/framework/ir/graph.h"
 namespace paddle {
 namespace framework {
 namespace details {
+typedef std::vector<
+    std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
+    GraphVars;
+typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
+typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
 class SSAGraphBuilder {
  public:
   SSAGraphBuilder() {}
@@ -42,20 +51,20 @@ class SSAGraphBuilder {
   *
   * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
   */
-  static void PolishGraphToSupportDataHazards(SSAGraph *graph);
-  static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
+  static void PolishGraphToSupportDataHazards(Graph *graph);
+  static VarHandle *CreateOrGetLatestVarHandle(Graph *graph,
                                                const std::string &each_var_name,
                                                const platform::Place &place,
                                                size_t place_offset);
   // Add an output variable (each_var_name, place, place_offset) to op_handle,
   // which belongs to graph
-  static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
+  static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
                              const std::string &each_var_name,
                              const platform::Place &place, size_t place_offset);
-  static void AddOutputToLeafOps(SSAGraph *graph);
+  static void AddOutputToLeafOps(Graph *graph);
 };
 }  // namespace details
 }  // namespace framework
...
@@ -27,7 +27,7 @@ namespace framework {
 class Graph {
  public:
-  std::map<std::string, std::vector<boost::any>> attrs;
+  std::map<std::string, boost::any> attrs;
   std::vector<Node *> inputs;
   std::vector<Node *> outputs;
...
@@ -14,6 +14,27 @@ limitations under the License. */
 #pragma once
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/node.h"
+#include "paddle/fluid/framework/program_desc.h"
 namespace paddle {
-namespace framework {} // namespace framework
+namespace framework {
+
+class Pass {
+ public:
+  Pass() = default;
+  virtual ~Pass() {}
+
+  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) {
+    return std::move(graph);
+  }
+};
+
+std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc& program) {
+  std::unique_ptr<Graph> g(new Graph);
+  return std::move(g);
+}
+
+} // namespace framework
 } // namespace paddle
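The new Pass and ProgramToGraph pieces are still skeletons: ProgramToGraph ignores the program and Pass::Apply is an identity transform. They do establish the intended flow, though: turn a ProgramDesc into a Graph, run passes over it, then hand the result to a builder. A hedged usage sketch follows; DebugPrintPass and the counting logic are illustrative, the header path is assumed to be paddle/fluid/framework/ir/pass.h, and ProgramDesc is assumed to be default-constructible.

// Sketch of how the new Pass interface is meant to be driven; not part of the commit.
#include <iostream>
#include <memory>
#include <utility>

#include "paddle/fluid/framework/ir/pass.h"  // assumed location of Pass, ProgramToGraph

namespace paddle {
namespace framework {

// Illustrative no-op pass: inspects the graph and returns it unchanged.
class DebugPrintPass : public Pass {
 public:
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) override {
    std::cout << "graph attrs: " << graph->attrs.size() << std::endl;
    return graph;  // ownership flows straight through
  }
};

}  // namespace framework
}  // namespace paddle

int main() {
  paddle::framework::ProgramDesc program;  // empty program, for the sketch only
  auto graph = paddle::framework::ProgramToGraph(program);

  paddle::framework::DebugPrintPass pass;
  graph = pass.Apply(std::move(graph));  // prints "graph attrs: 0"
  return 0;
}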