Commit 2fa8df1c authored by Xin Pan

separate graph building pass and graph-based pe builder

Parent 37e51443
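This commit splits graph construction from SSA-graph building: `ProgramToGraph` first turns the `ProgramDesc` into an `ir::Graph`, and the builder becomes an `ir::Pass` whose `Apply` rewrites that graph. A minimal sketch of the resulting call sequence, following the `ParallelExecutor` change further down in this diff (member names simplified):

    // Convert the program into a graph of ir::Node (op and variable nodes).
    std::unique_ptr<Graph> graph = ProgramToGraph(main_program);
    // Run the SSA-graph building pass; it consumes and returns the graph.
    graph = builder_->Apply(std::move(graph));
    // The executor takes ownership of the transformed graph.
    executor_.reset(new details::ThreadedSSAGraphExecutor(
        exec_strategy, local_scopes_, places, std::move(graph)));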
...@@ -96,7 +96,7 @@ struct TestBroadcastOpHandle { ...@@ -96,7 +96,7 @@ struct TestBroadcastOpHandle {
} }
param_scopes_[input_scope_idx]->Var("input"); param_scopes_[input_scope_idx]->Var("input");
std::unique_ptr<ir::Node> n(new ir::Node(ir::Node::Type::kOperation)); std::unique_ptr<ir::Node> n(new ir::Node());
if (use_gpu_) { if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
...@@ -114,7 +114,7 @@ struct TestBroadcastOpHandle { ...@@ -114,7 +114,7 @@ struct TestBroadcastOpHandle {
#endif #endif
} }
std::unique_ptr<ir::Node> v(new ir::Node(ir::Node::Type::kVariable)); std::unique_ptr<ir::Node> v(new ir::Node());
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input", auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
gpu_list_[input_scope_idx]); gpu_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
...@@ -122,7 +122,7 @@ struct TestBroadcastOpHandle { ...@@ -122,7 +122,7 @@ struct TestBroadcastOpHandle {
// add dummy var // add dummy var
std::unique_ptr<ir::Node> v2(new ir::Node(ir::Node::Type::kVariable)); std::unique_ptr<ir::Node> v2(new ir::Node());
vars_.emplace_back(new DummyVarHandle(v2.get())); vars_.emplace_back(new DummyVarHandle(v2.get()));
DummyVarHandle* dummy_var_handle = DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
...@@ -133,7 +133,7 @@ struct TestBroadcastOpHandle { ...@@ -133,7 +133,7 @@ struct TestBroadcastOpHandle {
if (!use_gpu_) { if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
} }
std::unique_ptr<ir::Node> v3(new ir::Node(ir::Node::Type::kVariable)); std::unique_ptr<ir::Node> v3(new ir::Node());
VarHandle* out_var_handle = VarHandle* out_var_handle =
new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]); new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
...@@ -141,7 +141,7 @@ struct TestBroadcastOpHandle { ...@@ -141,7 +141,7 @@ struct TestBroadcastOpHandle {
} }
// add dummy var // add dummy var
std::unique_ptr<ir::Node> v4(new ir::Node(ir::Node::Type::kVariable)); std::unique_ptr<ir::Node> v4(new ir::Node());
vars_.emplace_back(new DummyVarHandle(v4.get())); vars_.emplace_back(new DummyVarHandle(v4.get()));
DummyVarHandle* out_dummy_var_handle = DummyVarHandle* out_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
......
...@@ -82,13 +82,13 @@ struct TestGatherOpHandle { ...@@ -82,13 +82,13 @@ struct TestGatherOpHandle {
} }
param_scopes_[input_scope_idx]->Var("out"); param_scopes_[input_scope_idx]->Var("out");
nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); nodes.emplace_back(new ir::Node());
op_handle_.reset( op_handle_.reset(
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
// add input // add input
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); nodes.emplace_back(new ir::Node());
auto* in_var_handle = auto* in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
...@@ -96,7 +96,7 @@ struct TestGatherOpHandle { ...@@ -96,7 +96,7 @@ struct TestGatherOpHandle {
} }
// add dummy var // add dummy var
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); nodes.emplace_back(new ir::Node());
vars_.emplace_back(new DummyVarHandle(nodes.back().get())); vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* in_dummy_var_handle = DummyVarHandle* in_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
...@@ -104,14 +104,14 @@ struct TestGatherOpHandle { ...@@ -104,14 +104,14 @@ struct TestGatherOpHandle {
op_handle_->AddInput(in_dummy_var_handle); op_handle_->AddInput(in_dummy_var_handle);
// add output // add output
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); nodes.emplace_back(new ir::Node());
auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
"out", gpu_list_[input_scope_idx]); "out", gpu_list_[input_scope_idx]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
// add dummy var // add dummy var
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); nodes.emplace_back(new ir::Node());
vars_.emplace_back(new DummyVarHandle(nodes.back().get())); vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* dummy_var_handle = DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back().get());
......
...@@ -67,30 +67,31 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -67,30 +67,31 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
} }
} }
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, const OpDesc &op, void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
size_t place_id) const { size_t place_id) const {
auto p = places_[place_id]; auto p = places_[place_id];
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>("ops").back().get();
op_handle->SetDeviceContext(p, op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p)); platform::DeviceContextPool::Instance().Get(p));
for (auto &each_var_name : op.InputArgumentNames()) { for (ir::Node *input : node->inputs) {
VarHandle *var = VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id);
CreateOrGetLatestVarHandle(result, each_var_name, p, place_id);
op_handle->AddInput(var); op_handle->AddInput(var);
} }
for (auto &each_var_name : op.OutputArgumentNames()) { for (ir::Node *output : node->outputs) {
CreateOpOutput(result, op_handle, each_var_name, p, place_id); CreateOpOutput(result, op_handle, output, p, place_id);
} }
} }
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars( std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
const ProgramDesc &program) const { const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
std::vector<std::string> send_vars; std::vector<std::string> send_vars;
// since parameters are all in block 0, // since parameters are all in block 0,
// it's enough to only scan send ops in block 0 // it's enough to only scan send ops in block 0
for (auto *op : program.Block(0).AllOps()) { for (auto &node : nodes) {
if (!node->Op()) continue;
OpDesc *op = node->Op();
// TODO(Yancey1989): use a graceful method to find send op, // TODO(Yancey1989): use a graceful method to find send op,
// instead of the the hard code string // instead of the the hard code string
if (op->Type() == "send") { if (op->Type() == "send") {
...@@ -104,9 +105,11 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars( ...@@ -104,9 +105,11 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
} }
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars( std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
const ProgramDesc &program) const { const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
std::vector<std::string> recv_vars; std::vector<std::string> recv_vars;
for (auto *op : program.Block(0).AllOps()) { for (auto &node : nodes) {
if (!node->Op()) continue;
OpDesc *op = node->Op();
// TODO(Yancey1989): use a graceful method to find recv op, // TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard code string // instead of the hard code string
if (op->Type() == "recv") { if (op->Type() == "recv") {
...@@ -120,7 +123,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars( ...@@ -120,7 +123,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
} }
bool MultiDevSSAGraphBuilder::IsDistTrainOp( bool MultiDevSSAGraphBuilder::IsDistTrainOp(
const OpDesc &op, const std::vector<std::string> &send_vars, ir::Node *node, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const { const std::vector<std::string> &recv_vars) const {
if (send_vars.size() == 0 || recv_vars.size() == 0) { if (send_vars.size() == 0 || recv_vars.size() == 0) {
return false; return false;
...@@ -143,8 +146,17 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( ...@@ -143,8 +146,17 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
return false; return false;
}; };
return checker(op.OutputArgumentNames(), send_vars) || std::vector<std::string> input_var_names;
checker(op.InputArgumentNames(), recv_vars); std::vector<std::string> output_var_names;
for (ir::Node *input : node->inputs) {
input_var_names.push_back(input->Var()->Name());
}
for (ir::Node *output : node->outputs) {
output_var_names.push_back(output->Var()->Name());
}
return checker(output_var_names, send_vars) ||
checker(input_var_names, recv_vars);
} }
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
...@@ -167,11 +179,16 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( ...@@ -167,11 +179,16 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
return dev_id; return dev_id;
} }
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
std::unique_ptr<Graph> graph) const { std::unique_ptr<Graph> graph) const {
const ProgramDesc &program = graph->Program(); auto nodes = std::move(graph->nodes);
for (auto *var : program.Block(0).AllVars()) { graph->nodes.clear();
all_vars_.emplace(var->Name(), var); LOG(ERROR) << "origin nodes count " << nodes.size();
for (auto &node : nodes) {
if (node->Var()) {
all_vars_.emplace(node->Var()->Name(), node->Var());
}
} }
Graph &result = *graph; Graph &result = *graph;
...@@ -181,10 +198,11 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Build( ...@@ -181,10 +198,11 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Build(
result.Set("vars", new GraphVars(places_.size())); result.Set("vars", new GraphVars(places_.size()));
result.Set("dep_vars", new GraphDepVars); result.Set("dep_vars", new GraphDepVars);
result.Set("ops", new GraphOps); result.Set("ops", new GraphOps);
// find send/recv vars so that we can place the distributed training // find send/recv vars so that we can place the distributed training
// realted op in the place 0 // realted op in the place 0
auto send_vars = FindDistTrainSendVars(program); auto send_vars = FindDistTrainSendVars(nodes);
auto recv_vars = FindDistTrainRecvVars(program); auto recv_vars = FindDistTrainRecvVars(nodes);
std::vector<std::unordered_set<std::string>> bcast_var_name_set; std::vector<std::unordered_set<std::string>> bcast_var_name_set;
bcast_var_name_set.resize(places_.size()); bcast_var_name_set.resize(places_.size());
...@@ -192,14 +210,16 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Build( ...@@ -192,14 +210,16 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Build(
size_t cur_device_id = 0; size_t cur_device_id = 0;
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { // TODO(panyx0718): FIXME: nodes should be sorted by "program" order.
for (auto &node : nodes) {
if (!node->Op()) continue;
if (boost::get<int>( if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) { static_cast<int>(OpRole::kRPC)) {
CreateRPCOp(&result, *op); CreateRPCOp(&result, node.get());
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) { } else if (IsDistTrainOp(node.get(), send_vars, recv_vars)) {
CreateDistTrainOp(&result, *op); CreateDistTrainOp(&result, node.get());
} else if (IsScaleLossOp(*op)) { } else if (IsScaleLossOp(node.get())) {
// user can customize loss@grad if not use_default_grad_scale_ // user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ != if (strategy_.gradient_scale_ !=
BuildStrategy::GradientScaleStrategy::kCustomized) { BuildStrategy::GradientScaleStrategy::kCustomized) {
...@@ -211,33 +231,35 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Build( ...@@ -211,33 +231,35 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Build(
// the block. // the block.
is_forwarding = false; is_forwarding = false;
} else { } else {
int op_dev_id = GetOpDeviceID(*op); int op_dev_id = GetOpDeviceID(node.get());
if (op_dev_id != -1) { // This op only runs on one specific device. if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, *op, op_dev_id); CreateComputationalOp(&result, node.get(), op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) { for (ir::Node *n : node->outputs) {
var_name_on_devices_.emplace(var_name, op_dev_id); var_name_on_devices_.emplace(n->Var()->Name(), op_dev_id);
} }
} else { } else {
// This op runs on all devices, and its output may have parameter's // This op runs on all devices, and its output may have parameter's
// gradients. // gradients.
if (op->Type() == "read" && strategy_.enable_data_balance_) { if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
op->SetAttr("throw_eof_exp", false); node->Op()->SetAttr("throw_eof_exp", false);
CreateComputationalOps(&result, *op, places_.size()); CreateComputationalOps(&result, node.get(), places_.size());
const auto &data_var_names = op->Output("Out"); // TODO(panyx0718): builder shouldn't depend on the out logic of
// a specific op.
const auto &data_var_names = node->Op()->Output("Out");
InsertDataBalanceOp(&result, data_var_names); InsertDataBalanceOp(&result, data_var_names);
} else { } else {
CreateComputationalOps(&result, *op, places_.size()); CreateComputationalOps(&result, node.get(), places_.size());
} }
if (!is_forwarding && places_.size() > 1) { if (!is_forwarding && places_.size() > 1) {
// Currently, we assume that once gradient is generated, it can be // Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. // broadcast, and each gradient is only broadcast once.
if (static_cast<bool>(boost::get<int>(op->GetAttr( if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) & OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward))) { static_cast<int>(OpRole::kBackward))) {
try { try {
auto backward_vars = auto backward_vars = boost::get<std::vector<std::string>>(
boost::get<std::vector<std::string>>(op->GetNullableAttr( node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName())); OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
...@@ -328,13 +350,12 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext( ...@@ -328,13 +350,12 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
const std::string &p_name, const std::string &p_name,
size_t src_dev_id) const { size_t src_dev_id) const {
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto *op_handle = new BroadcastOpHandle(result->nodes.back().get(), auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr),
local_scopes_, places_, nccl_ctxs_); local_scopes_, places_, nccl_ctxs_);
#else #else
auto *op_handle = auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr),
new BroadcastOpHandle(result->nodes.back().get(), local_scopes_, places_); local_scopes_, places_);
#endif #endif
result->Get<GraphOps>("ops").emplace_back(op_handle); result->Get<GraphOps>("ops").emplace_back(op_handle);
...@@ -345,33 +366,31 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, ...@@ -345,33 +366,31 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name); auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
auto *out_var = auto *out_var =
new VarHandle(result->nodes.back().get(), vars.size(), i, p_name, p); new VarHandle(result->CreateVarNode(p_name), vars.size(), i, p_name, p);
vars.emplace_back(out_var); vars.emplace_back(out_var);
op_handle->AddOutput(out_var); op_handle->AddOutput(out_var);
} }
} }
void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
const OpDesc &op, ir::Node *node,
int dev_id) const { int dev_id) const {
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); result->Get<GraphOps>("ops").emplace_back(
result->Get<GraphOps>("ops").emplace_back(new ComputationOpHandle( new ComputationOpHandle(result->CreateOpNode(node->Op()), *node->Op(),
result->nodes.back().get(), op, local_scopes_[dev_id], places_[dev_id])); local_scopes_[dev_id], places_[dev_id]));
CreateOpHandleIOs(result, op, dev_id); CreateOpHandleIOs(result, node, dev_id);
} }
void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
const std::string &og) const { const std::string &og) const {
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle( result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_)); result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_));
#else #else
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle( result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
result->nodes.back().get(), local_scopes_, places_)); result->CreateOpNode(nullptr), local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>("ops").back().get();
...@@ -383,8 +402,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, ...@@ -383,8 +402,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); auto var = new VarHandle(result->CreateVarNode(og), vars.size(), i, og, p);
auto var = new VarHandle(result->nodes.back().get(), vars.size(), i, og, p);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
...@@ -392,13 +410,12 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, ...@@ -392,13 +410,12 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
void MultiDevSSAGraphBuilder::InsertDataBalanceOp( void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
Graph *result, const std::vector<std::string> &datas) const { Graph *result, const std::vector<std::string> &datas) const {
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle( result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_)); result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_));
#else #else
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle( result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
result->nodes.back().get(), local_scopes_, places_)); result->CreateOpNode(nullptr), local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>("ops").back().get();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
...@@ -408,9 +425,8 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( ...@@ -408,9 +425,8 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
auto &vars = result->Get<GraphVars>("vars")[i][d_name]; auto &vars = result->Get<GraphVars>("vars")[i][d_name];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
op_handle->AddInput(vars.back().get()); op_handle->AddInput(vars.back().get());
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); auto var = new VarHandle(result->CreateVarNode(d_name), vars.size(), i,
auto var = d_name, p);
new VarHandle(result->nodes.back().get(), vars.size(), i, d_name, p);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
...@@ -429,17 +445,17 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( ...@@ -429,17 +445,17 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return is_pg_once; return is_pg_once;
} }
int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const { int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1; return -1;
} }
int op_role = boost::get<int>( int op_role = boost::get<int>(
op.GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName())); node->Op()->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
if (op_role != static_cast<int>(framework::OpRole::kOptimize)) { if (op_role != static_cast<int>(framework::OpRole::kOptimize)) {
return -1; return -1;
} }
auto param_grad = boost::get<std::vector<std::string>>( auto param_grad = boost::get<std::vector<std::string>>(
op.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U); PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(param_grad[1]); int dev_id = GetVarDeviceID(param_grad[1]);
...@@ -464,9 +480,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { ...@@ -464,9 +480,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
auto *communication_dev_ctx = auto *communication_dev_ctx =
platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
#endif #endif
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
auto *op_handle = new ScaleLossGradOpHandle( auto *op_handle = new ScaleLossGradOpHandle(
result->nodes.back().get(), local_scopes_.size(), local_scopes_[i], result->CreateOpNode(nullptr), local_scopes_.size(), local_scopes_[i],
places_[i], communication_dev_ctx); places_[i], communication_dev_ctx);
result->Get<GraphOps>("ops").emplace_back(op_handle); result->Get<GraphOps>("ops").emplace_back(op_handle);
...@@ -476,34 +491,38 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { ...@@ -476,34 +491,38 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
// loss->pending_ops_.emplace_back(op_handle); // loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss); // op_handle->inputs_.emplace_back(loss);
CreateOpOutput(result, op_handle, GradVarName(loss_var_name_), places_[i], // TODO(panyx0718): GradVarName(loss_var_name_)
i); const std::string grad_var_name = GradVarName(loss_var_name_);
auto &vars = result->Get<GraphVars>("vars")[i][grad_var_name];
size_t version = vars.size();
auto var = new VarHandle(result->CreateVarNode(grad_var_name), version, i,
grad_var_name, places_[i]);
vars.emplace_back(var);
op_handle->AddOutput(var);
} }
} }
void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
const OpDesc &op, ir::Node *node,
size_t num_places) const { size_t num_places) const {
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx]; auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx]; auto s = local_scopes_[scope_idx];
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); result->Get<GraphOps>("ops").emplace_back(new ComputationOpHandle(
result->Get<GraphOps>("ops").emplace_back( result->CreateOpNode(node->Op()), *node->Op(), s, p));
new ComputationOpHandle(result->nodes.back().get(), op, s, p)); CreateOpHandleIOs(result, node, scope_idx);
CreateOpHandleIOs(result, op, scope_idx);
} }
} }
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
const std::string &og, const std::string &og,
int dst_dev_id) const { int dst_dev_id) const {
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle( result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_)); result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_));
#else #else
result->Get<GraphOps>("ops").emplace_back( result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
new ReduceOpHandle(result->nodes.back().get(), local_scopes_, places_)); result->CreateOpNode(nullptr), local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>("ops").back().get();
...@@ -516,8 +535,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, ...@@ -516,8 +535,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
} }
auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og]; auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); auto var = new VarHandle(result->CreateVarNode(og), vars.size(), dst_dev_id,
auto var = new VarHandle(result->nodes.back().get(), vars.size(), dst_dev_id,
og, places_[dst_dev_id]); og, places_[dst_dev_id]);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
...@@ -530,8 +548,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, ...@@ -530,8 +548,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const { const std::string &prev_op_name) const {
for (auto &prev_op : result->Get<GraphOps>("ops")) { for (auto &prev_op : result->Get<GraphOps>("ops")) {
if (prev_op->Name() == prev_op_name) { if (prev_op->Name() == prev_op_name) {
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); auto *dep_var = new DummyVarHandle(result->CreateVarNode("dummy"));
auto *dep_var = new DummyVarHandle(result->nodes.back().get());
prev_op->AddOutput(dep_var); prev_op->AddOutput(dep_var);
result->Get<GraphDepVars>("dep_vars").emplace(dep_var); result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
op->AddInput(dep_var); op->AddInput(dep_var);
...@@ -540,22 +557,32 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, ...@@ -540,22 +557,32 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
} }
void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
const OpDesc &op) const { ir::Node *node) const {
int op_dev_id = -1; int op_dev_id = -1;
if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") { std::vector<std::string> input_var_names;
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); std::vector<std::string> output_var_names;
for (ir::Node *input : node->inputs) {
input_var_names.push_back(input->Var()->Name());
}
for (ir::Node *output : node->outputs) {
output_var_names.push_back(output->Var()->Name());
}
if (node->Op()->Type() == "split_byref" ||
node->Op()->Type() == "split_selected_rows") {
op_dev_id = GetVarDeviceID(input_var_names[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : op.InputArgumentNames()) { for (auto &varname : input_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); var_name_on_devices_.emplace(varname, op_dev_id);
} }
} }
for (auto &varname : op.OutputArgumentNames()) { for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); var_name_on_devices_.emplace(varname, op_dev_id);
} }
} else if (op.Type() == "concat") { } else if (node->Op()->Type() == "concat") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); op_dev_id = GetVarDeviceID(input_var_names[0]);
for (auto &varname : op.OutputArgumentNames()) { for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); var_name_on_devices_.emplace(varname, op_dev_id);
} }
} else { } else {
...@@ -565,35 +592,43 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, ...@@ -565,35 +592,43 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
} }
PADDLE_ENFORCE(op_dev_id != -1, PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place for distributed op: %s", op.Type()); "can not find right place for distributed op: %s",
node->Op()->Type());
CreateComputationalOp(result, op, op_dev_id); CreateComputationalOp(result, node, op_dev_id);
if (op.Type() == "concat") { if (node->Op()->Type() == "concat") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
"fetch_barrier"); "fetch_barrier");
} }
} }
// Create RPC related op handles that connects its in ops and out ops. // Create RPC related op handles that connects its in ops and out ops.
void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const {
const OpDesc &op) const {
int op_dev_id = -1; int op_dev_id = -1;
if (op.Type() == "send") { if (node->Op()->Type() == "send") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); op_dev_id = GetVarDeviceID(node->inputs[0]->Var()->Name());
// the variable name which contains .block means it was splited by // the variable name which contains .block means it was splited by
// split_byref op // split_byref op
// so that we can balance the variable blocks to all the pserver // so that we can balance the variable blocks to all the pserver
// instances. // instances.
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce && if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
op.InputArgumentNames()[0].find(".block") == std::string::npos) { node->inputs[0]->Var()->Name().find(".block") == std::string::npos) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); std::vector<std::string> input_var_names;
for (auto &varname : op.InputArgumentNames()) { for (ir::Node *n : node->inputs) {
input_var_names.push_back(n->Var()->Name());
}
op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); var_name_on_devices_.emplace(varname, op_dev_id);
} }
} }
} else if (op.Type() == "recv") { } else if (node->Op()->Type() == "recv") {
op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames()); std::vector<std::string> output_var_names;
for (auto &varname : op.OutputArgumentNames()) { for (ir::Node *n : node->outputs) {
output_var_names.push_back(n->Var()->Name());
}
op_dev_id = GetAppropriateDeviceID(output_var_names);
for (auto &varname : output_var_names) {
var_name_on_devices_.emplace(varname, op_dev_id); var_name_on_devices_.emplace(varname, op_dev_id);
} }
} else { } else {
...@@ -602,21 +637,20 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ...@@ -602,21 +637,20 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result,
} }
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
op.Type()); node->Op()->Type());
result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
result->Get<GraphOps>("ops").emplace_back( result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
new RPCOpHandle(result->nodes.back().get(), op, local_scopes_[op_dev_id], node->Op()->Type(), places_[op_dev_id]));
op.Type(), places_[op_dev_id]));
if (op.Type() == "send_barrier") { if (node->Op()->Type() == "send_barrier") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send"); ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
} else if (op.Type() == "recv") { } else if (node->Op()->Type() == "recv") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
"send_barrier"); "send_barrier");
} else if (op.Type() == "fetch_barrier") { } else if (node->Op()->Type() == "fetch_barrier") {
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv"); ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv");
} else if (op.Type() == "send") { } else if (node->Op()->Type() == "send") {
// do nothing // do nothing
} else { } else {
PADDLE_THROW( PADDLE_THROW(
...@@ -624,12 +658,12 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ...@@ -624,12 +658,12 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result,
"send, send_barrier. recv, fetch_barrier]"); "send, send_barrier. recv, fetch_barrier]");
} }
CreateOpHandleIOs(result, op, op_dev_id); CreateOpHandleIOs(result, node, op_dev_id);
} }
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
return boost::get<int>( return boost::get<int>(
op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
(static_cast<int>(OpRole::kBackward) | (static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss)) && static_cast<int>(OpRole::kLoss)) &&
!loss_var_name_.empty(); // If loss_var is empty. This is test mode !loss_var_name_.empty(); // If loss_var is empty. This is test mode
......
...@@ -46,13 +46,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -46,13 +46,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const BuildStrategy &strategy); const BuildStrategy &strategy);
#endif #endif
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override;
std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override;
int GetVarDeviceID(const std::string &varname) const override; int GetVarDeviceID(const std::string &varname) const override;
private: private:
void CreateOpHandleIOs(Graph *result, const OpDesc &op, void CreateOpHandleIOs(Graph *result, ir::Node *node, size_t device_id) const;
size_t device_id) const;
private: private:
std::string loss_var_name_; std::string loss_var_name_;
...@@ -64,40 +62,39 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -64,40 +62,39 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
platform::NCCLContextMap *nccl_ctxs_; platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
bool IsScaleLossOp(const OpDesc &op) const; bool IsScaleLossOp(ir::Node *node) const;
void CreateRPCOp(Graph *result, const OpDesc &op) const; void CreateRPCOp(Graph *result, ir::Node *node) const;
void CreateDistTrainOp(Graph *result, const OpDesc &op) const; void CreateDistTrainOp(Graph *result, ir::Node *node) const;
/** /**
* Is this operator as the end-point operator before/after send operator. * Is this operator as the end-point operator before/after send operator.
*/ */
bool IsDistTrainOp(const OpDesc &op, bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const; const std::vector<std::string> &recv_vars) const;
std::vector<std::string> FindDistTrainSendVars( std::vector<std::string> FindDistTrainSendVars(
const ProgramDesc &program) const; const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
std::vector<std::string> FindDistTrainRecvVars( std::vector<std::string> FindDistTrainRecvVars(
const ProgramDesc &program) const; const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
void ConnectOp(Graph *result, OpHandleBase *op, void ConnectOp(Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const; const std::string &prev_op_name) const;
void CreateComputationalOps(Graph *result, const OpDesc &op, void CreateComputationalOps(Graph *result, ir::Node *node,
size_t num_places) const; size_t num_places) const;
void CreateScaleLossGradOp(Graph *result) const; void CreateScaleLossGradOp(Graph *result) const;
VarHandle *CreateReduceOp(Graph *result, const std::string &og, VarHandle *CreateReduceOp(Graph *result, const std::string &og,
int dst_dev_id) const; int dst_dev_id) const;
void CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const; void CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const;
bool IsParameterGradientOnce( bool IsParameterGradientOnce(
const std::string &og, const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const; std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID(const OpDesc &op) const; int GetOpDeviceID(ir::Node *node) const;
void InsertAllReduceOp(Graph *result, const std::string &og) const; void InsertAllReduceOp(Graph *result, const std::string &og) const;
......
...@@ -97,7 +97,7 @@ struct TestReduceOpHandle { ...@@ -97,7 +97,7 @@ struct TestReduceOpHandle {
} }
param_scopes_[out_scope_idx]->Var("out"); param_scopes_[out_scope_idx]->Var("out");
nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); nodes.emplace_back(new ir::Node());
if (use_gpu_) { if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_, op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_,
...@@ -121,7 +121,7 @@ struct TestReduceOpHandle { ...@@ -121,7 +121,7 @@ struct TestReduceOpHandle {
if (!use_gpu_) { if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
} }
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); nodes.emplace_back(new ir::Node());
auto *in_var_handle = auto *in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
in_var_handle->ClearGeneratedOp(); in_var_handle->ClearGeneratedOp();
...@@ -137,7 +137,7 @@ struct TestReduceOpHandle { ...@@ -137,7 +137,7 @@ struct TestReduceOpHandle {
op_handle_->AddInput(in_dummy_var_handle); op_handle_->AddInput(in_dummy_var_handle);
// add output // add output
nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); nodes.emplace_back(new ir::Node());
auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx, auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx,
"out", gpu_list_[out_scope_idx]); "out", gpu_list_[out_scope_idx]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
......
...@@ -37,8 +37,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { ...@@ -37,8 +37,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
continue; continue;
} }
graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); auto *dep_var = new DummyVarHandle(graph->CreateVarNode("dummy"));
auto *dep_var = new DummyVarHandle(graph->nodes.back().get());
read_op->AddOutput(dep_var); read_op->AddOutput(dep_var);
write_op->AddInput(dep_var); write_op->AddInput(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var); graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
...@@ -49,15 +48,14 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { ...@@ -49,15 +48,14 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
} }
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
Graph *graph, const std::string &each_var_name, Graph *graph, ir::Node *node, const platform::Place &place,
const platform::Place &place, size_t place_offset) { size_t place_offset) {
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset]; auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
auto &var_holder = var_holders[each_var_name]; auto &var_holder = var_holders[node->Var()->Name()];
VarHandle *var = nullptr; VarHandle *var = nullptr;
if (var_holder.empty()) { if (var_holder.empty()) {
graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset,
var = new VarHandle(graph->nodes.back().get(), 0, place_offset, node->Var()->Name(), place);
each_var_name, place);
var_holder.emplace_back(var); var_holder.emplace_back(var);
} else { } else {
var = var_holder.rbegin()->get(); var = var_holder.rbegin()->get();
...@@ -66,14 +64,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( ...@@ -66,14 +64,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
} }
void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
const std::string &each_var_name, ir::Node *node,
const platform::Place &place, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
auto &vars = graph->Get<GraphVars>("vars")[place_offset][each_var_name]; auto &vars = graph->Get<GraphVars>("vars")[place_offset][node->Var()->Name()];
size_t version = vars.size(); size_t version = vars.size();
graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); auto var = new VarHandle(graph->CreateVarNode(node->Var()), version,
auto var = new VarHandle(graph->nodes.back().get(), version, place_offset, place_offset, node->Var()->Name(), place);
each_var_name, place);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
...@@ -85,8 +82,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { ...@@ -85,8 +82,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
if (!op->Outputs().empty()) { if (!op->Outputs().empty()) {
continue; continue;
} }
graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); auto *dummy_leaf = new DummyVarHandle(graph->CreateVarNode("dummy"));
auto *dummy_leaf = new DummyVarHandle(graph->nodes.back().get());
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf); graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
op->AddOutput(dummy_leaf); op->AddOutput(dummy_leaf);
} }
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -34,11 +35,11 @@ typedef std::vector< ...@@ -34,11 +35,11 @@ typedef std::vector<
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars; typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps; typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
class SSAGraphBuilder { class SSAGraphBuilder : public ir::Pass {
public: public:
SSAGraphBuilder() {} SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const = 0; virtual int GetVarDeviceID(const std::string &var_name) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
...@@ -53,16 +54,15 @@ class SSAGraphBuilder { ...@@ -53,16 +54,15 @@ class SSAGraphBuilder {
*/ */
static void PolishGraphToSupportDataHazards(Graph *graph); static void PolishGraphToSupportDataHazards(Graph *graph);
static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, ir::Node *node,
const std::string &each_var_name,
const platform::Place &place, const platform::Place &place,
size_t place_offset); size_t place_offset);
// Add an output variable (each_var_name, place, place_offset) to op_handle, // Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph // which belongs to graph
static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle, static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
const std::string &each_var_name, ir::Node *node, const platform::Place &place,
const platform::Place &place, size_t place_offset); size_t place_offset);
static void AddOutputToLeafOps(Graph *graph); static void AddOutputToLeafOps(Graph *graph);
}; };
......
...@@ -28,10 +28,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { ...@@ -28,10 +28,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
std::unique_ptr<SSAGraphBuilder>&& builder) std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {} : builder_(std::move(builder)) {}
std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override { std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
auto new_graph = builder_->Build(std::move(graph)); auto new_graph = builder_->Apply(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(new_graph.get())); PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
return new_graph; return std::move(new_graph);
} }
int GetVarDeviceID(const std::string& var_name) const override { int GetVarDeviceID(const std::string& var_name) const override {
......
...@@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { ...@@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
stream_ptr_(std::move(sout)), stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {} stream_ref_(*stream_ptr_) {}
std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override { std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
auto new_graph = builder_->Build(std::move(graph)); auto new_graph = builder_->Apply(std::move(graph));
printer_->Print(*new_graph, stream_ref_); printer_->Print(*new_graph, stream_ref_);
return new_graph; return std::move(new_graph);
} }
int GetVarDeviceID(const std::string& var_name) const override { int GetVarDeviceID(const std::string& var_name) const override {
......
...@@ -13,12 +13,45 @@ See the License for the specific language governing permissions and ...@@ -13,12 +13,45 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc &program) { std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc &program) {
std::unique_ptr<Graph> graph(new Graph(program)); std::unique_ptr<Graph> graph(new Graph(program));
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) {
all_vars.emplace(var->Name(), var);
}
for (auto *op : program.Block(0).AllOps()) {
ir::Node *node = graph->CreateOpNode(op);
for (auto &each_var_name : op->InputArgumentNames()) {
ir::Node *var = nullptr;
if (all_vars.count(each_var_name) != 0) {
var = graph->CreateVarNode(all_vars.at(each_var_name));
} else {
var = graph->CreateVarNode(each_var_name);
}
node->inputs.push_back(var);
var->outputs.push_back(node);
}
for (auto &each_var_name : op->OutputArgumentNames()) {
ir::Node *var = nullptr;
if (all_vars.count(each_var_name) != 0) {
var = graph->CreateVarNode(all_vars.at(each_var_name));
} else {
var = graph->CreateVarNode(each_var_name);
}
node->outputs.push_back(var);
var->inputs.push_back(node);
}
}
return std::move(graph); return std::move(graph);
} }
......
...@@ -39,8 +39,6 @@ class Graph { ...@@ -39,8 +39,6 @@ class Graph {
attr_dels_.clear(); attr_dels_.clear();
} }
const ProgramDesc& Program() const { return program_; }
template <typename AttrType> template <typename AttrType>
AttrType& Get(const std::string& attr_name) const { AttrType& Get(const std::string& attr_name) const {
return *boost::any_cast<AttrType*>(attrs_.at(attr_name)); return *boost::any_cast<AttrType*>(attrs_.at(attr_name));
...@@ -63,11 +61,30 @@ class Graph { ...@@ -63,11 +61,30 @@ class Graph {
return attr; return attr;
} }
ir::Node* CreateVarNode(VarDesc* var_desc) {
nodes.emplace_back(new ir::Node(var_desc));
return nodes.back().get();
}
ir::Node* CreateOpNode(OpDesc* op_desc) {
nodes.emplace_back(new ir::Node(op_desc));
return nodes.back().get();
}
// TODO(panyx0718): Need to handle CreateOpNode(nullptr).
ir::Node* CreateVarNode(const std::string& var_name) {
var_descs_.emplace_back(new VarDesc(var_name));
nodes.emplace_back(new ir::Node(var_descs_.back().get()));
return nodes.back().get();
}
std::vector<ir::Node*> inputs; std::vector<ir::Node*> inputs;
std::vector<ir::Node*> outputs; std::vector<ir::Node*> outputs;
std::vector<std::unique_ptr<ir::Node>> nodes; std::vector<std::unique_ptr<ir::Node>> nodes;
std::vector<std::unique_ptr<VarDesc>> var_descs_;
private: private:
// NOTE: program_ shouldn't be exposed to user.
const ProgramDesc& program_; const ProgramDesc& program_;
std::map<std::string, boost::any> attrs_; std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_; std::map<std::string, std::function<void(void)>> attr_dels_;
......
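With `Graph::CreateOpNode` and `Graph::CreateVarNode` above, passes no longer push `ir::Node` objects into `graph->nodes` by hand; the graph owns every node, and the op/var handles only keep raw pointers to them. A small sketch of the pattern used by `MultiDevSSAGraphBuilder` in this diff (the variable name and place are placeholders):

    // The graph owns both nodes; VarHandle/OpHandle merely reference them.
    ir::Node *op_node  = result->CreateOpNode(nullptr);      // synthetic op, no OpDesc
    ir::Node *var_node = result->CreateVarNode("some_var");  // backed by a fresh VarDesc
    auto *out_var = new VarHandle(var_node, /*version=*/0, /*scope_idx=*/0,
                                  "some_var", place);
    op_handle->AddOutput(out_var);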
...@@ -21,6 +21,8 @@ limitations under the License. */ ...@@ -21,6 +21,8 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/variant.h"
...@@ -32,10 +34,12 @@ class Node { ...@@ -32,10 +34,12 @@ class Node {
public: public:
enum class Type { kNone = -1, kOperation, kVariable }; enum class Type { kNone = -1, kOperation, kVariable };
Node() : type_(Type::kNone) {}
explicit Node(Type type) : type_(type) {} explicit Node(Type type) : type_(type) {}
virtual ~Node() { virtual ~Node() {
for (auto &attr : attrs_) { for (auto& attr : attrs_) {
if (attr_dels_.find(attr.first) != attr_dels_.end()) { if (attr_dels_.find(attr.first) != attr_dels_.end()) {
attr_dels_[attr.first](); attr_dels_[attr.first]();
} }
...@@ -47,23 +51,34 @@ class Node { ...@@ -47,23 +51,34 @@ class Node {
Type NodeType() const { return type_; } Type NodeType() const { return type_; }
template <typename AttrType> template <typename AttrType>
void Set(const std::string &name, AttrType attr) { void Set(const std::string& name, AttrType attr) {
attrs_[name] = attr; attrs_[name] = attr;
} }
template <typename AttrType> template <typename AttrType>
void Set(const std::string &name, AttrType *attr, void Set(const std::string& name, AttrType* attr,
std::function<void(void)> attr_del) { std::function<void(void)> attr_del) {
attrs_[name] = attr; attrs_[name] = attr;
attr_dels_[name] = attr_del; attr_dels_[name] = attr_del;
} }
std::vector<Node *> inputs; VarDesc* Var() { return var_desc_; }
std::vector<Node *> outputs; OpDesc* Op() { return op_desc_; }
explicit Node(VarDesc* var_desc)
: var_desc_(var_desc), op_desc_(nullptr), type_(Type::kVariable) {}
explicit Node(OpDesc* op_desc)
: var_desc_(nullptr), op_desc_(op_desc), type_(Type::kOperation) {}
std::vector<Node*> inputs;
std::vector<Node*> outputs;
protected: protected:
std::map<std::string, boost::any> attrs_; std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_; std::map<std::string, std::function<void(void)>> attr_dels_;
VarDesc* var_desc_;
OpDesc* op_desc_;
Type type_; Type type_;
private: private:
......
...@@ -20,15 +20,15 @@ limitations under the License. */ ...@@ -20,15 +20,15 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir {
class Pass { class Pass {
public: public:
Pass() = default; Pass() = default;
virtual ~Pass() {} virtual ~Pass() {}
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) {
return std::move(graph);
}
};
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
};
} // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
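Since `Apply` is now the single pure-virtual entry point of `ir::Pass`, graph transformations (including `SSAGraphBuilder` and its checker/printer wrappers) are written as passes. A hypothetical minimal pass, for illustration only and not part of this commit:

    // Trivial example pass: inspect the graph, then hand it back unchanged.
    class NodeCountPass : public ir::Pass {
     public:
      std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
        LOG(INFO) << "node count: " << graph->nodes.size();
        return std::move(graph);
      }
    };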
...@@ -131,13 +131,10 @@ ParallelExecutor::ParallelExecutor( ...@@ -131,13 +131,10 @@ ParallelExecutor::ParallelExecutor(
PADDLE_THROW("Not compiled with CUDA."); PADDLE_THROW("Not compiled with CUDA.");
#endif #endif
} }
builder_ = builder_factory.Create(); builder_ = builder_factory.Create();
std::unique_ptr<Graph> graph = builder_->Build(ProgramToGraph(main_program)); std::unique_ptr<Graph> graph = builder_->Apply(ProgramToGraph(main_program));
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph))); exec_strategy, member_->local_scopes_, places, std::move(graph)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos), exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->places_, std::move(member_->executor_))); member_->places_, std::move(member_->executor_)));
......
...@@ -148,6 +148,7 @@ class ParallelExecutor(object): ...@@ -148,6 +148,7 @@ class ParallelExecutor(object):
lambda var: var.persistable and var.type != core.VarDesc.VarType.RAW, lambda var: var.persistable and var.type != core.VarDesc.VarType.RAW,
main.list_vars()) main.list_vars())
] ]
sys.stderr.write('!!!!!!!!before\n')
self.executor = core.ParallelExecutor( self.executor = core.ParallelExecutor(
self._places, self._places,
...@@ -158,6 +159,7 @@ class ParallelExecutor(object): ...@@ -158,6 +159,7 @@ class ParallelExecutor(object):
set(self.persistable_vars), main.desc, loss_name set(self.persistable_vars), main.desc, loss_name
if loss_name else '', scope, local_scopes, exec_strategy, if loss_name else '', scope, local_scopes, exec_strategy,
build_strategy, num_trainers, trainer_id) build_strategy, num_trainers, trainer_id)
sys.stderr.write('!!!!!!!!after\n')
self.scope = scope self.scope = scope
def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True): def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
......