diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc index 90ee3f7d93e7e51f48ef491437511841b4c8e448..1609b5965c03d0fddabf67c264c0065feb5e3551 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc @@ -96,7 +96,7 @@ struct TestBroadcastOpHandle { } param_scopes_[input_scope_idx]->Var("input"); - std::unique_ptr n(new ir::Node(ir::Node::Type::kOperation)); + std::unique_ptr n(new ir::Node()); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, @@ -114,7 +114,7 @@ struct TestBroadcastOpHandle { #endif } - std::unique_ptr v(new ir::Node(ir::Node::Type::kVariable)); + std::unique_ptr v(new ir::Node()); auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input", gpu_list_[input_scope_idx]); vars_.emplace_back(in_var_handle); @@ -122,7 +122,7 @@ struct TestBroadcastOpHandle { // add dummy var - std::unique_ptr v2(new ir::Node(ir::Node::Type::kVariable)); + std::unique_ptr v2(new ir::Node()); vars_.emplace_back(new DummyVarHandle(v2.get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); @@ -133,7 +133,7 @@ struct TestBroadcastOpHandle { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - std::unique_ptr v3(new ir::Node(ir::Node::Type::kVariable)); + std::unique_ptr v3(new ir::Node()); VarHandle* out_var_handle = new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]); vars_.emplace_back(out_var_handle); @@ -141,7 +141,7 @@ struct TestBroadcastOpHandle { } // add dummy var - std::unique_ptr v4(new ir::Node(ir::Node::Type::kVariable)); + std::unique_ptr v4(new ir::Node()); vars_.emplace_back(new DummyVarHandle(v4.get())); DummyVarHandle* out_dummy_var_handle = static_cast(vars_.back().get()); diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc index 5b11f8cdc7479be5e44bccaedb42c656834168b3..f80cabf501028e10cad7c46a2cbc68c66db235b1 100644 --- a/paddle/fluid/framework/details/gather_op_handle_test.cc +++ b/paddle/fluid/framework/details/gather_op_handle_test.cc @@ -82,13 +82,13 @@ struct TestGatherOpHandle { } param_scopes_[input_scope_idx]->Var("out"); - nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); + nodes.emplace_back(new ir::Node()); op_handle_.reset( new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); // add input for (size_t j = 0; j < gpu_list_.size(); ++j) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); - nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + nodes.emplace_back(new ir::Node()); auto* in_var_handle = new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); vars_.emplace_back(in_var_handle); @@ -96,7 +96,7 @@ struct TestGatherOpHandle { } // add dummy var - nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + nodes.emplace_back(new ir::Node()); vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* in_dummy_var_handle = static_cast(vars_.back().get()); @@ -104,14 +104,14 @@ struct TestGatherOpHandle { op_handle_->AddInput(in_dummy_var_handle); // add output - nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + nodes.emplace_back(new ir::Node()); auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, "out", gpu_list_[input_scope_idx]); vars_.emplace_back(out_var_handle); 
op_handle_->AddOutput(out_var_handle); // add dummy var - nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + nodes.emplace_back(new ir::Node()); vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index cb2ab905163ab9673e42e3c23a5d12ae55fb201a..d66bc40090fa4c09ba4365ec88a59c6974f1bc10 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -67,30 +67,31 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( } } -void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, const OpDesc &op, +void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node, size_t place_id) const { auto p = places_[place_id]; auto *op_handle = result->Get("ops").back().get(); op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); - for (auto &each_var_name : op.InputArgumentNames()) { - VarHandle *var = - CreateOrGetLatestVarHandle(result, each_var_name, p, place_id); + for (ir::Node *input : node->inputs) { + VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id); op_handle->AddInput(var); } - for (auto &each_var_name : op.OutputArgumentNames()) { - CreateOpOutput(result, op_handle, each_var_name, p, place_id); + for (ir::Node *output : node->outputs) { + CreateOpOutput(result, op_handle, output, p, place_id); } } std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( - const ProgramDesc &program) const { + const std::vector> &nodes) const { std::vector send_vars; // since parameters are all in block 0, // it's enough to only scan send ops in block 0 - for (auto *op : program.Block(0).AllOps()) { + for (auto &node : nodes) { + if (!node->Op()) continue; + OpDesc *op = node->Op(); // TODO(Yancey1989): use a graceful method to find send op, // instead of the the hard code string if (op->Type() == "send") { @@ -104,9 +105,11 @@ std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( } std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( - const ProgramDesc &program) const { + const std::vector> &nodes) const { std::vector recv_vars; - for (auto *op : program.Block(0).AllOps()) { + for (auto &node : nodes) { + if (!node->Op()) continue; + OpDesc *op = node->Op(); // TODO(Yancey1989): use a graceful method to find recv op, // instead of the hard code string if (op->Type() == "recv") { @@ -120,7 +123,7 @@ std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( } bool MultiDevSSAGraphBuilder::IsDistTrainOp( - const OpDesc &op, const std::vector &send_vars, + ir::Node *node, const std::vector &send_vars, const std::vector &recv_vars) const { if (send_vars.size() == 0 || recv_vars.size() == 0) { return false; @@ -143,8 +146,17 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( return false; }; - return checker(op.OutputArgumentNames(), send_vars) || - checker(op.InputArgumentNames(), recv_vars); + std::vector input_var_names; + std::vector output_var_names; + for (ir::Node *input : node->inputs) { + input_var_names.push_back(input->Var()->Name()); + } + for (ir::Node *output : node->outputs) { + output_var_names.push_back(output->Var()->Name()); + } + + return checker(output_var_names, send_vars) || + checker(input_var_names, recv_vars); } size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( @@ -167,11 +179,16 @@ size_t 
MultiDevSSAGraphBuilder::GetAppropriateDeviceID( return dev_id; } -std::unique_ptr MultiDevSSAGraphBuilder::Build( +std::unique_ptr MultiDevSSAGraphBuilder::Apply( std::unique_ptr graph) const { - const ProgramDesc &program = graph->Program(); - for (auto *var : program.Block(0).AllVars()) { - all_vars_.emplace(var->Name(), var); + auto nodes = std::move(graph->nodes); + graph->nodes.clear(); + LOG(ERROR) << "origin nodes count " << nodes.size(); + + for (auto &node : nodes) { + if (node->Var()) { + all_vars_.emplace(node->Var()->Name(), node->Var()); + } } Graph &result = *graph; @@ -181,10 +198,11 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( result.Set("vars", new GraphVars(places_.size())); result.Set("dep_vars", new GraphDepVars); result.Set("ops", new GraphOps); + // find send/recv vars so that we can place the distributed training // realted op in the place 0 - auto send_vars = FindDistTrainSendVars(program); - auto recv_vars = FindDistTrainRecvVars(program); + auto send_vars = FindDistTrainSendVars(nodes); + auto recv_vars = FindDistTrainRecvVars(nodes); std::vector> bcast_var_name_set; bcast_var_name_set.resize(places_.size()); @@ -192,14 +210,16 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( size_t cur_device_id = 0; bool is_forwarding = true; - for (auto *op : program.Block(0).AllOps()) { + // TODO(panyx0718): FIXME: nodes should be sorted by "program" order. + for (auto &node : nodes) { + if (!node->Op()) continue; if (boost::get( - op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == static_cast(OpRole::kRPC)) { - CreateRPCOp(&result, *op); - } else if (IsDistTrainOp(*op, send_vars, recv_vars)) { - CreateDistTrainOp(&result, *op); - } else if (IsScaleLossOp(*op)) { + CreateRPCOp(&result, node.get()); + } else if (IsDistTrainOp(node.get(), send_vars, recv_vars)) { + CreateDistTrainOp(&result, node.get()); + } else if (IsScaleLossOp(node.get())) { // user can customize loss@grad if not use_default_grad_scale_ if (strategy_.gradient_scale_ != BuildStrategy::GradientScaleStrategy::kCustomized) { @@ -211,33 +231,35 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( // the block. is_forwarding = false; } else { - int op_dev_id = GetOpDeviceID(*op); + int op_dev_id = GetOpDeviceID(node.get()); if (op_dev_id != -1) { // This op only runs on one specific device. - CreateComputationalOp(&result, *op, op_dev_id); - for (auto &var_name : op->OutputArgumentNames()) { - var_name_on_devices_.emplace(var_name, op_dev_id); + CreateComputationalOp(&result, node.get(), op_dev_id); + for (ir::Node *n : node->outputs) { + var_name_on_devices_.emplace(n->Var()->Name(), op_dev_id); } } else { // This op runs on all devices, and its output may have parameter's // gradients. - if (op->Type() == "read" && strategy_.enable_data_balance_) { - op->SetAttr("throw_eof_exp", false); - CreateComputationalOps(&result, *op, places_.size()); - const auto &data_var_names = op->Output("Out"); + if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) { + node->Op()->SetAttr("throw_eof_exp", false); + CreateComputationalOps(&result, node.get(), places_.size()); + // TODO(panyx0718): builder shouldn't depend on the out logic of + // a specific op. 
+ const auto &data_var_names = node->Op()->Output("Out"); InsertDataBalanceOp(&result, data_var_names); } else { - CreateComputationalOps(&result, *op, places_.size()); + CreateComputationalOps(&result, node.get(), places_.size()); } if (!is_forwarding && places_.size() > 1) { // Currently, we assume that once gradient is generated, it can be // broadcast, and each gradient is only broadcast once. - if (static_cast(boost::get(op->GetAttr( + if (static_cast(boost::get(node->Op()->GetAttr( OpProtoAndCheckerMaker::OpRoleAttrName())) & static_cast(OpRole::kBackward))) { try { - auto backward_vars = - boost::get>(op->GetNullableAttr( + auto backward_vars = boost::get>( + node->Op()->GetNullableAttr( OpProtoAndCheckerMaker::OpRoleVarAttrName())); PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); @@ -328,13 +350,12 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext( void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, const std::string &p_name, size_t src_dev_id) const { - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef PADDLE_WITH_CUDA - auto *op_handle = new BroadcastOpHandle(result->nodes.back().get(), + auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_); #else - auto *op_handle = - new BroadcastOpHandle(result->nodes.back().get(), local_scopes_, places_); + auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr), + local_scopes_, places_); #endif result->Get("ops").emplace_back(op_handle); @@ -345,33 +366,31 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); auto &vars = result->Get("vars").at(i).at(p_name); auto *out_var = - new VarHandle(result->nodes.back().get(), vars.size(), i, p_name, p); + new VarHandle(result->CreateVarNode(p_name), vars.size(), i, p_name, p); vars.emplace_back(out_var); op_handle->AddOutput(out_var); } } void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, - const OpDesc &op, + ir::Node *node, int dev_id) const { - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); - result->Get("ops").emplace_back(new ComputationOpHandle( - result->nodes.back().get(), op, local_scopes_[dev_id], places_[dev_id])); - CreateOpHandleIOs(result, op, dev_id); + result->Get("ops").emplace_back( + new ComputationOpHandle(result->CreateOpNode(node->Op()), *node->Op(), + local_scopes_[dev_id], places_[dev_id])); + CreateOpHandleIOs(result, node, dev_id); } void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, const std::string &og) const { - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef PADDLE_WITH_CUDA result->Get("ops").emplace_back(new AllReduceOpHandle( - result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_)); + result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_)); #else result->Get("ops").emplace_back(new AllReduceOpHandle( - result->nodes.back().get(), local_scopes_, places_)); + result->CreateOpNode(nullptr), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -383,8 +402,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - auto var = new VarHandle(result->nodes.back().get(), vars.size(), i, og, p); + 
 auto var = new VarHandle(result->CreateVarNode(og), vars.size(), i, og, p);
     vars.emplace_back(var);
     op_handle->AddOutput(var);
   }
@@ -392,13 +410,12 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
     Graph *result, const std::vector<std::string> &datas) const {
-  result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
 #ifdef PADDLE_WITH_CUDA
   result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
-      result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_));
+      result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_));
 #else
   result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
-      result->nodes.back().get(), local_scopes_, places_));
+      result->CreateOpNode(nullptr), local_scopes_, places_));
 #endif
   auto *op_handle = result->Get<GraphOps>("ops").back().get();
   for (size_t i = 0; i < places_.size(); ++i) {
@@ -408,9 +425,8 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
       auto &vars = result->Get<GraphVars>("vars")[i][d_name];
       PADDLE_ENFORCE(!vars.empty());
       op_handle->AddInput(vars.back().get());
-      result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
-      auto var =
-          new VarHandle(result->nodes.back().get(), vars.size(), i, d_name, p);
+      auto var = new VarHandle(result->CreateVarNode(d_name), vars.size(), i,
+                               d_name, p);
       vars.emplace_back(var);
       op_handle->AddOutput(var);
     }
@@ -429,17 +445,17 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
   return is_pg_once;
 }
 
-int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
+int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
   if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
     return -1;
   }
   int op_role = boost::get<int>(
-      op.GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
+      node->Op()->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
   if (op_role != static_cast<int>(framework::OpRole::kOptimize)) {
     return -1;
   }
   auto param_grad = boost::get<std::vector<std::string>>(
-      op.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+      node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
 
   PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
   int dev_id = GetVarDeviceID(param_grad[1]);
@@ -464,9 +480,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
     auto *communication_dev_ctx =
         platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
 #endif
-    result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
     auto *op_handle = new ScaleLossGradOpHandle(
-        result->nodes.back().get(), local_scopes_.size(), local_scopes_[i],
+        result->CreateOpNode(nullptr), local_scopes_.size(), local_scopes_[i],
         places_[i], communication_dev_ctx);
     result->Get<GraphOps>("ops").emplace_back(op_handle);
@@ -476,34 +491,38 @@
     // loss->pending_ops_.emplace_back(op_handle);
     // op_handle->inputs_.emplace_back(loss);
 
-    CreateOpOutput(result, op_handle, GradVarName(loss_var_name_), places_[i],
-                   i);
+    // TODO(panyx0718): GradVarName(loss_var_name_)
+    const std::string grad_var_name = GradVarName(loss_var_name_);
+    auto &vars = result->Get<GraphVars>("vars")[i][grad_var_name];
+    size_t version = vars.size();
+    auto var = new VarHandle(result->CreateVarNode(grad_var_name), version, i,
+                             grad_var_name, places_[i]);
+    vars.emplace_back(var);
+    op_handle->AddOutput(var);
   }
 }
 
 void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
-                                                     const OpDesc &op,
+                                                     ir::Node *node,
                                                      size_t num_places) const {
   for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); - result->Get("ops").emplace_back( - new ComputationOpHandle(result->nodes.back().get(), op, s, p)); - CreateOpHandleIOs(result, op, scope_idx); + result->Get("ops").emplace_back(new ComputationOpHandle( + result->CreateOpNode(node->Op()), *node->Op(), s, p)); + CreateOpHandleIOs(result, node, scope_idx); } } VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, const std::string &og, int dst_dev_id) const { - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef PADDLE_WITH_CUDA result->Get("ops").emplace_back(new ReduceOpHandle( - result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_)); + result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_)); #else - result->Get("ops").emplace_back( - new ReduceOpHandle(result->nodes.back().get(), local_scopes_, places_)); + result->Get("ops").emplace_back(new ReduceOpHandle( + result->CreateOpNode(nullptr), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -516,8 +535,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, op_handle->AddInput(prev_grad.get()); } auto &vars = result->Get("vars")[dst_dev_id][og]; - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - auto var = new VarHandle(result->nodes.back().get(), vars.size(), dst_dev_id, + auto var = new VarHandle(result->CreateVarNode(og), vars.size(), dst_dev_id, og, places_[dst_dev_id]); vars.emplace_back(var); op_handle->AddOutput(var); @@ -530,8 +548,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { for (auto &prev_op : result->Get("ops")) { if (prev_op->Name() == prev_op_name) { - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - auto *dep_var = new DummyVarHandle(result->nodes.back().get()); + auto *dep_var = new DummyVarHandle(result->CreateVarNode("dummy")); prev_op->AddOutput(dep_var); result->Get("dep_vars").emplace(dep_var); op->AddInput(dep_var); @@ -540,22 +557,32 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, } void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, - const OpDesc &op) const { + ir::Node *node) const { int op_dev_id = -1; - if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") { - op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); + std::vector input_var_names; + std::vector output_var_names; + for (ir::Node *input : node->inputs) { + input_var_names.push_back(input->Var()->Name()); + } + for (ir::Node *output : node->outputs) { + output_var_names.push_back(output->Var()->Name()); + } + + if (node->Op()->Type() == "split_byref" || + node->Op()->Type() == "split_selected_rows") { + op_dev_id = GetVarDeviceID(input_var_names[0]); if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { - op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); - for (auto &varname : op.InputArgumentNames()) { + op_dev_id = GetAppropriateDeviceID(input_var_names); + for (auto &varname : input_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } } - for (auto &varname : op.OutputArgumentNames()) { + for (auto &varname : output_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } - } else if (op.Type() == "concat") { - op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); - for (auto &varname : op.OutputArgumentNames()) { + } else if 
(node->Op()->Type() == "concat") { + op_dev_id = GetVarDeviceID(input_var_names[0]); + for (auto &varname : output_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } } else { @@ -565,35 +592,43 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, } PADDLE_ENFORCE(op_dev_id != -1, - "can not find right place for distributed op: %s", op.Type()); + "can not find right place for distributed op: %s", + node->Op()->Type()); - CreateComputationalOp(result, op, op_dev_id); - if (op.Type() == "concat") { + CreateComputationalOp(result, node, op_dev_id); + if (node->Op()->Type() == "concat") { ConnectOp(result, result->Get("ops").back().get(), "fetch_barrier"); } } // Create RPC related op handles that connects its in ops and out ops. -void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, - const OpDesc &op) const { +void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const { int op_dev_id = -1; - if (op.Type() == "send") { - op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); + if (node->Op()->Type() == "send") { + op_dev_id = GetVarDeviceID(node->inputs[0]->Var()->Name()); // the variable name which contains .block means it was splited by // split_byref op // so that we can balance the variable blocks to all the pserver // instances. if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce && - op.InputArgumentNames()[0].find(".block") == std::string::npos) { - op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); - for (auto &varname : op.InputArgumentNames()) { + node->inputs[0]->Var()->Name().find(".block") == std::string::npos) { + std::vector input_var_names; + for (ir::Node *n : node->inputs) { + input_var_names.push_back(n->Var()->Name()); + } + op_dev_id = GetAppropriateDeviceID(input_var_names); + for (auto &varname : input_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } } - } else if (op.Type() == "recv") { - op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames()); - for (auto &varname : op.OutputArgumentNames()) { + } else if (node->Op()->Type() == "recv") { + std::vector output_var_names; + for (ir::Node *n : node->outputs) { + output_var_names.push_back(n->Var()->Name()); + } + op_dev_id = GetAppropriateDeviceID(output_var_names); + for (auto &varname : output_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } } else { @@ -602,21 +637,20 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, } PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", - op.Type()); + node->Op()->Type()); - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); - result->Get("ops").emplace_back( - new RPCOpHandle(result->nodes.back().get(), op, local_scopes_[op_dev_id], - op.Type(), places_[op_dev_id])); + result->Get("ops").emplace_back(new RPCOpHandle( + result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id], + node->Op()->Type(), places_[op_dev_id])); - if (op.Type() == "send_barrier") { + if (node->Op()->Type() == "send_barrier") { ConnectOp(result, result->Get("ops").back().get(), "send"); - } else if (op.Type() == "recv") { + } else if (node->Op()->Type() == "recv") { ConnectOp(result, result->Get("ops").back().get(), "send_barrier"); - } else if (op.Type() == "fetch_barrier") { + } else if (node->Op()->Type() == "fetch_barrier") { ConnectOp(result, result->Get("ops").back().get(), "recv"); - } else if (op.Type() == "send") { + } else if (node->Op()->Type() == "send") { // do nothing } else { PADDLE_THROW( @@ -624,12 +658,12 @@ void 
MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, "send, send_barrier. recv, fetch_barrier]"); } - CreateOpHandleIOs(result, op, op_dev_id); + CreateOpHandleIOs(result, node, op_dev_id); } -bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { +bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { return boost::get( - op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == (static_cast(OpRole::kBackward) | static_cast(OpRole::kLoss)) && !loss_var_name_.empty(); // If loss_var is empty. This is test mode diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 248ea8ea62ba92270d12c895b1ffc8c353248cd6..2b7f4f586b4e750fde9245286c977258a9db6086 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -46,13 +46,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::vector &local_scopes, const BuildStrategy &strategy); #endif - - std::unique_ptr Build(std::unique_ptr graph) const override; + std::unique_ptr Apply(std::unique_ptr graph) const override; int GetVarDeviceID(const std::string &varname) const override; private: - void CreateOpHandleIOs(Graph *result, const OpDesc &op, - size_t device_id) const; + void CreateOpHandleIOs(Graph *result, ir::Node *node, size_t device_id) const; private: std::string loss_var_name_; @@ -64,40 +62,39 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { platform::NCCLContextMap *nccl_ctxs_; #endif - bool IsScaleLossOp(const OpDesc &op) const; + bool IsScaleLossOp(ir::Node *node) const; - void CreateRPCOp(Graph *result, const OpDesc &op) const; - void CreateDistTrainOp(Graph *result, const OpDesc &op) const; + void CreateRPCOp(Graph *result, ir::Node *node) const; + void CreateDistTrainOp(Graph *result, ir::Node *node) const; /** * Is this operator as the end-point operator before/after send operator. 
*/ - bool IsDistTrainOp(const OpDesc &op, - const std::vector &send_vars, + bool IsDistTrainOp(ir::Node *node, const std::vector &send_vars, const std::vector &recv_vars) const; std::vector FindDistTrainSendVars( - const ProgramDesc &program) const; + const std::vector> &nodes) const; std::vector FindDistTrainRecvVars( - const ProgramDesc &program) const; + const std::vector> &nodes) const; void ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const; - void CreateComputationalOps(Graph *result, const OpDesc &op, + void CreateComputationalOps(Graph *result, ir::Node *node, size_t num_places) const; void CreateScaleLossGradOp(Graph *result) const; VarHandle *CreateReduceOp(Graph *result, const std::string &og, int dst_dev_id) const; - void CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const; + void CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const; bool IsParameterGradientOnce( const std::string &og, std::unordered_set *og_has_been_broadcast) const; - int GetOpDeviceID(const OpDesc &op) const; + int GetOpDeviceID(ir::Node *node) const; void InsertAllReduceOp(Graph *result, const std::string &og) const; diff --git a/paddle/fluid/framework/details/reduce_op_handle_test.cc b/paddle/fluid/framework/details/reduce_op_handle_test.cc index d029dd9e15e2ecfab1d4d5b8674f16ed3f4cdf32..e7c83ffd320a2bb4ebe304b8557af7b777ededfe 100644 --- a/paddle/fluid/framework/details/reduce_op_handle_test.cc +++ b/paddle/fluid/framework/details/reduce_op_handle_test.cc @@ -97,7 +97,7 @@ struct TestReduceOpHandle { } param_scopes_[out_scope_idx]->Var("out"); - nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); + nodes.emplace_back(new ir::Node()); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_, @@ -121,7 +121,7 @@ struct TestReduceOpHandle { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + nodes.emplace_back(new ir::Node()); auto *in_var_handle = new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); in_var_handle->ClearGeneratedOp(); @@ -137,7 +137,7 @@ struct TestReduceOpHandle { op_handle_->AddInput(in_dummy_var_handle); // add output - nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + nodes.emplace_back(new ir::Node()); auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx, "out", gpu_list_[out_scope_idx]); vars_.emplace_back(out_var_handle); diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 846f98ddfa38ec68947a70cfffba4e0aa62bd690..6a8bd7875cb52e1383c814f26dbe299a6c87cb9c 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -37,8 +37,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { continue; } - graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - auto *dep_var = new DummyVarHandle(graph->nodes.back().get()); + auto *dep_var = new DummyVarHandle(graph->CreateVarNode("dummy")); read_op->AddOutput(dep_var); write_op->AddInput(dep_var); graph->Get("dep_vars").emplace(dep_var); @@ -49,15 +48,14 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { } VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( - Graph *graph, const std::string &each_var_name, - const platform::Place &place, size_t place_offset) { + Graph *graph, ir::Node 
*node, const platform::Place &place, + size_t place_offset) { auto &var_holders = graph->Get("vars")[place_offset]; - auto &var_holder = var_holders[each_var_name]; + auto &var_holder = var_holders[node->Var()->Name()]; VarHandle *var = nullptr; if (var_holder.empty()) { - graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - var = new VarHandle(graph->nodes.back().get(), 0, place_offset, - each_var_name, place); + var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset, + node->Var()->Name(), place); var_holder.emplace_back(var); } else { var = var_holder.rbegin()->get(); @@ -66,14 +64,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( } void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, - const std::string &each_var_name, + ir::Node *node, const platform::Place &place, size_t place_offset) { - auto &vars = graph->Get("vars")[place_offset][each_var_name]; + auto &vars = graph->Get("vars")[place_offset][node->Var()->Name()]; size_t version = vars.size(); - graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - auto var = new VarHandle(graph->nodes.back().get(), version, place_offset, - each_var_name, place); + auto var = new VarHandle(graph->CreateVarNode(node->Var()), version, + place_offset, node->Var()->Name(), place); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -85,8 +82,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { if (!op->Outputs().empty()) { continue; } - graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - auto *dummy_leaf = new DummyVarHandle(graph->nodes.back().get()); + auto *dummy_leaf = new DummyVarHandle(graph->CreateVarNode("dummy")); graph->Get("dep_vars").emplace(dummy_leaf); op->AddOutput(dummy_leaf); } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 4fbf036241d1b9936d2ef902628ba1f850cee5b8..9933bf32b7a1faa6841620662ba0781c19e47f54 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -23,6 +23,7 @@ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" namespace paddle { namespace framework { @@ -34,11 +35,11 @@ typedef std::vector< typedef std::unordered_set> GraphDepVars; typedef std::vector> GraphOps; -class SSAGraphBuilder { +class SSAGraphBuilder : public ir::Pass { public: SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {} - virtual std::unique_ptr Build(std::unique_ptr graph) const = 0; + virtual int GetVarDeviceID(const std::string &var_name) const = 0; DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); @@ -53,16 +54,15 @@ class SSAGraphBuilder { */ static void PolishGraphToSupportDataHazards(Graph *graph); - static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, - const std::string &each_var_name, + static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset); // Add an output variable (each_var_name, place, place_offset) to op_handle, // which belongs to graph static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle, - const std::string &each_var_name, - const platform::Place &place, size_t place_offset); + ir::Node *node, const platform::Place &place, + size_t place_offset); static void AddOutputToLeafOps(Graph *graph); }; diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 
2c8b2e13c5cc40ef799e2647fc50370363dcfe37..f1080610381128325ea0affba760ac66798fd948 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -28,10 +28,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { std::unique_ptr&& builder) : builder_(std::move(builder)) {} - std::unique_ptr Build(std::unique_ptr graph) const override { - auto new_graph = builder_->Build(std::move(graph)); + std::unique_ptr Apply(std::unique_ptr graph) const override { + auto new_graph = builder_->Apply(std::move(graph)); PADDLE_ENFORCE(IsValidGraph(new_graph.get())); - return new_graph; + return std::move(new_graph); } int GetVarDeviceID(const std::string& var_name) const override { diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index 35f2a1b4f0e3c478717007ea1a43e1e3d6820861..411be02988a82b3e35d56833f92fc6fe405a2c3d 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { stream_ptr_(std::move(sout)), stream_ref_(*stream_ptr_) {} - std::unique_ptr Build(std::unique_ptr graph) const override { - auto new_graph = builder_->Build(std::move(graph)); + std::unique_ptr Apply(std::unique_ptr graph) const override { + auto new_graph = builder_->Apply(std::move(graph)); printer_->Print(*new_graph, stream_ref_); - return new_graph; + return std::move(new_graph); } int GetVarDeviceID(const std::string& var_name) const override { diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 28ad4efc71933697b887e33bf669cfc0c3fbb71e..c1f8f917c4cceb0f10bae0728bcb71081459db4c 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -13,12 +13,45 @@ See the License for the specific language governing permissions and limitations under the License. 
 */
 
 #include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/var_desc.h"
 
 namespace paddle {
 namespace framework {
 
 std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc &program) {
   std::unique_ptr<Graph> graph(new Graph(program));
+
+  std::unordered_map<std::string, VarDesc *> all_vars;
+  for (auto *var : program.Block(0).AllVars()) {
+    all_vars.emplace(var->Name(), var);
+  }
+
+  for (auto *op : program.Block(0).AllOps()) {
+    ir::Node *node = graph->CreateOpNode(op);
+
+    for (auto &each_var_name : op->InputArgumentNames()) {
+      ir::Node *var = nullptr;
+      if (all_vars.count(each_var_name) != 0) {
+        var = graph->CreateVarNode(all_vars.at(each_var_name));
+      } else {
+        var = graph->CreateVarNode(each_var_name);
+      }
+      node->inputs.push_back(var);
+      var->outputs.push_back(node);
+    }
+
+    for (auto &each_var_name : op->OutputArgumentNames()) {
+      ir::Node *var = nullptr;
+      if (all_vars.count(each_var_name) != 0) {
+        var = graph->CreateVarNode(all_vars.at(each_var_name));
+      } else {
+        var = graph->CreateVarNode(each_var_name);
+      }
+      node->outputs.push_back(var);
+      var->inputs.push_back(node);
+    }
+  }
   return std::move(graph);
 }
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index e83cb5a82a3722313d16fe0517e90e58052a7f1a..ff4f31fb7afd300d5658c9a47a35443d185bd3c5 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -39,8 +39,6 @@ class Graph {
     attr_dels_.clear();
   }
 
-  const ProgramDesc& Program() const { return program_; }
-
   template <typename AttrType>
   AttrType& Get(const std::string& attr_name) const {
     return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
@@ -63,11 +61,30 @@ class Graph {
     return attr;
   }
 
+  ir::Node* CreateVarNode(VarDesc* var_desc) {
+    nodes.emplace_back(new ir::Node(var_desc));
+    return nodes.back().get();
+  }
+
+  ir::Node* CreateOpNode(OpDesc* op_desc) {
+    nodes.emplace_back(new ir::Node(op_desc));
+    return nodes.back().get();
+  }
+
+  // TODO(panyx0718): Need to handle CreateOpNode(nullptr).
+  ir::Node* CreateVarNode(const std::string& var_name) {
+    var_descs_.emplace_back(new VarDesc(var_name));
+    nodes.emplace_back(new ir::Node(var_descs_.back().get()));
+    return nodes.back().get();
+  }
+
   std::vector<ir::Node*> inputs;
   std::vector<ir::Node*> outputs;
   std::vector<std::unique_ptr<ir::Node>> nodes;
+  std::vector<std::unique_ptr<VarDesc>> var_descs_;
 
 private:
+  // NOTE: program_ shouldn't be exposed to user.
  const ProgramDesc& program_;
  std::map<std::string, boost::any> attrs_;
  std::map<std::string, std::function<void(void)>> attr_dels_;
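Editor's note: the sketch below is not part of the patch. Assuming graph.h and node.h look as reconstructed above, it illustrates how the new CreateOpNode/CreateVarNode factories and the Op()/Var() accessors might be used to inspect a graph produced by ProgramToGraph. The helper name CountNodes is hypothetical.

#include <cstddef>

#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {

// Hypothetical helper, not part of the patch: classify the nodes owned by a
// graph using the Op()/Var() accessors added to ir::Node.
void CountNodes(const Graph &graph, size_t *num_ops, size_t *num_vars) {
  *num_ops = 0;
  *num_vars = 0;
  for (const auto &node : graph.nodes) {
    if (node->Op() != nullptr) {
      ++*num_ops;   // created via CreateOpNode(op_desc)
    } else if (node->Var() != nullptr) {
      ++*num_vars;  // created via CreateVarNode(var_desc) or CreateVarNode(name)
    }
    // Nodes created via CreateOpNode(nullptr), which the builder uses for
    // synthetic ops such as broadcast and all-reduce, report neither; the
    // TODO in graph.h about handling that case applies here as well.
  }
}

}  // namespace framework
}  // namespace paddle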
diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h
index 94ace92953a0f7fc7e585df66b48b35f95e2998f..0e0b81a7b149379008c5a9956e1c77b60249ffc1 100644
--- a/paddle/fluid/framework/ir/node.h
+++ b/paddle/fluid/framework/ir/node.h
@@ -21,6 +21,8 @@ limitations under the License. */
 #include <map>
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/platform/macros.h"
 #include "paddle/fluid/platform/variant.h"
@@ -32,10 +34,12 @@ class Node {
  public:
   enum class Type { kNone = -1, kOperation, kVariable };
 
+  Node() : type_(Type::kNone) {}
+
   explicit Node(Type type) : type_(type) {}
 
   virtual ~Node() {
-    for (auto &attr : attrs_) {
+    for (auto& attr : attrs_) {
       if (attr_dels_.find(attr.first) != attr_dels_.end()) {
         attr_dels_[attr.first]();
       }
@@ -47,23 +51,34 @@ class Node {
   Type NodeType() const { return type_; }
 
   template <typename AttrType>
-  void Set(const std::string &name, AttrType attr) {
+  void Set(const std::string& name, AttrType attr) {
     attrs_[name] = attr;
   }
 
   template <typename AttrType>
-  void Set(const std::string &name, AttrType *attr,
+  void Set(const std::string& name, AttrType* attr,
            std::function<void()> attr_del) {
     attrs_[name] = attr;
     attr_dels_[name] = attr_del;
   }
 
-  std::vector<Node *> inputs;
-  std::vector<Node *> outputs;
+  VarDesc* Var() { return var_desc_; }
+  OpDesc* Op() { return op_desc_; }
+
+  explicit Node(VarDesc* var_desc)
+      : var_desc_(var_desc), op_desc_(nullptr), type_(Type::kVariable) {}
+
+  explicit Node(OpDesc* op_desc)
+      : var_desc_(nullptr), op_desc_(op_desc), type_(Type::kOperation) {}
+
+  std::vector<Node*> inputs;
+  std::vector<Node*> outputs;
 
  protected:
   std::map<std::string, boost::any> attrs_;
   std::map<std::string, std::function<void(void)>> attr_dels_;
+  VarDesc* var_desc_;
+  OpDesc* op_desc_;
   Type type_;
 
  private:
diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h
index 2fc26c053f0b57790e2c385f424301b357d6b2d8..3f0fcff857ed8d8099d4d52d9a0898841577dbba 100644
--- a/paddle/fluid/framework/ir/pass.h
+++ b/paddle/fluid/framework/ir/pass.h
@@ -20,15 +20,15 @@ limitations under the License. */
 
 namespace paddle {
 namespace framework {
+namespace ir {
 
 class Pass {
  public:
   Pass() = default;
   virtual ~Pass() {}
 
-  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) {
-    return std::move(graph);
-  }
-};
+  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
+};
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index d30aba07a018841636a9e6e9ae747ca1ef5df6fe..c9014ffdf500e13158bd6e34fce9aa9f46b10904 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -131,13 +131,10 @@ ParallelExecutor::ParallelExecutor(
     PADDLE_THROW("Not compiled with CUDA.");
 #endif
   }
-
   builder_ = builder_factory.Create();
-  std::unique_ptr<Graph> graph = builder_->Build(ProgramToGraph(main_program));
-
+  std::unique_ptr<Graph> graph = builder_->Apply(ProgramToGraph(main_program));
   member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
       exec_strategy, member_->local_scopes_, places, std::move(graph)));
-
   member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
       exec_strategy, member_->local_scopes_, std::move(var_infos),
       member_->places_, std::move(member_->executor_)));
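Editor's note: the sketch below is not from the patch. With ir::Pass::Apply now pure virtual, and SSAGraphBuilder deriving from ir::Pass (see ssa_graph_builder.h below), every concrete builder or pass has to implement Apply. IdentityPass is an invented example that simply hands the graph back, mirroring the default body this patch removes.

#include <memory>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

// Hypothetical pass: takes ownership of the graph and returns it untouched.
class IdentityPass : public Pass {
 public:
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
    // A real pass would create or rewire nodes here, e.g. via
    // graph->CreateOpNode(...) and graph->CreateVarNode(...).
    return graph;
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle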
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index 10028a8c6e33edcea27650d925ca7378b770f143..59789c6defa499d2245762e2059be95f7927acbb 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -148,6 +148,7 @@ class ParallelExecutor(object):
                 lambda var: var.persistable and var.type != core.VarDesc.VarType.RAW,
                 main.list_vars())
         ]
+        sys.stderr.write('!!!!!!!!before\n')
 
         self.executor = core.ParallelExecutor(
             self._places,
@@ -158,6 +159,7 @@ class ParallelExecutor(object):
             set(self.persistable_vars), main.desc, loss_name
             if loss_name else '', scope, local_scopes, exec_strategy,
             build_strategy, num_trainers, trainer_id)
+        sys.stderr.write('!!!!!!!!after\n')
         self.scope = scope
 
     def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
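Editor's note: not part of the patch. The sketch below shows how the renamed Apply entry point might compose with the wrapper builders from ssa_graph_checker.h earlier in this diff. The function name, local variables, and the MultiDevSSAGraphBuilder constructor argument order (taken from the non-CUDA declaration visible in multi_devices_graph_builder.h) are assumptions for illustration only.

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace details {

// Hypothetical helper: build the SSA graph and validate it in one call.
std::unique_ptr<Graph> BuildCheckedGraph(
    const ProgramDesc &main_program,
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name,
    const std::unordered_set<std::string> &params,
    const std::vector<Scope *> &local_scopes, const BuildStrategy &strategy) {
  std::unique_ptr<SSAGraphBuilder> builder(new MultiDevSSAGraphBuilder(
      places, loss_var_name, params, local_scopes, strategy));
  // Wrap the builder so the produced graph is validated, mirroring
  // SSAGraghBuilderWithChecker::Apply shown earlier in this diff.
  builder.reset(new SSAGraghBuilderWithChecker(std::move(builder)));
  return builder->Apply(ProgramToGraph(main_program));
}

}  // namespace details
}  // namespace framework
}  // namespace paddle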