diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc
index 63a6ed90828edb16d29c4510cbd0fe6ad1552f00..1413f7bd9ac515ae7dceee62de8f3bc74e3a2efc 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc
+++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc
@@ -96,7 +96,8 @@ struct TestBroadcastOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("input");
 
-    std::unique_ptr<ir::Node> n(new ir::Node("node0"));
+    std::unique_ptr<ir::Node> n(
+        new ir::Node("node0", ir::Node::Type::kOperation));
     if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
       op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
@@ -114,7 +115,8 @@ struct TestBroadcastOpHandle {
 #endif
     }
 
-    std::unique_ptr<ir::Node> v(new ir::Node("node1"));
+    std::unique_ptr<ir::Node> v(
+        new ir::Node("node1", ir::Node::Type::kVariable));
     auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
                                         gpu_list_[input_scope_idx]);
     vars_.emplace_back(in_var_handle);
@@ -122,7 +124,8 @@ struct TestBroadcastOpHandle {
 
     // add dummy var
-    std::unique_ptr<ir::Node> v2(new ir::Node("node2"));
+    std::unique_ptr<ir::Node> v2(
+        new ir::Node("node2", ir::Node::Type::kVariable));
     vars_.emplace_back(new DummyVarHandle(v2.get()));
     DummyVarHandle* dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
@@ -133,7 +136,8 @@ struct TestBroadcastOpHandle {
       if (!use_gpu_) {
         op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
       }
-      std::unique_ptr<ir::Node> v3(new ir::Node("node3"));
+      std::unique_ptr<ir::Node> v3(
+          new ir::Node("node3", ir::Node::Type::kVariable));
       VarHandle* out_var_handle =
           new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
       vars_.emplace_back(out_var_handle);
@@ -141,7 +145,8 @@ struct TestBroadcastOpHandle {
     }
 
     // add dummy var
-    std::unique_ptr<ir::Node> v4(new ir::Node("node4"));
+    std::unique_ptr<ir::Node> v4(
+        new ir::Node("node4", ir::Node::Type::kVariable));
     vars_.emplace_back(new DummyVarHandle(v4.get()));
     DummyVarHandle* out_dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc
index e3806ac5e141dd2b6cd846fd45b8e1c484169e06..c9b94d1e1039df6ff27f9ffe225b2a50c35a5c50 100644
--- a/paddle/fluid/framework/details/gather_op_handle_test.cc
+++ b/paddle/fluid/framework/details/gather_op_handle_test.cc
@@ -82,13 +82,13 @@ struct TestGatherOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("out");
 
-    nodes.emplace_back(new ir::Node("node"));
+    nodes.emplace_back(new ir::Node("node", ir::Node::Type::kOperation));
     op_handle_.reset(
         new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
     // add input
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
-      nodes.emplace_back(new ir::Node("node1"));
+      nodes.emplace_back(new ir::Node("node1", ir::Node::Type::kVariable));
       auto* in_var_handle =
           new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
       vars_.emplace_back(in_var_handle);
@@ -96,7 +96,7 @@ struct TestGatherOpHandle {
     }
 
     // add dummy var
-    nodes.emplace_back(new ir::Node("node2"));
+    nodes.emplace_back(new ir::Node("node2", ir::Node::Type::kVariable));
     vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
     DummyVarHandle* in_dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
@@ -104,14 +104,14 @@ struct TestGatherOpHandle {
     op_handle_->AddInput(in_dummy_var_handle);
 
     // add output
-    nodes.emplace_back(new ir::Node("node3"));
+    nodes.emplace_back(new ir::Node("node3", ir::Node::Type::kVariable));
     auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
                                          "out", gpu_list_[input_scope_idx]);
     vars_.emplace_back(out_var_handle);
     op_handle_->AddOutput(out_var_handle);
 
     // add dummy var
-    nodes.emplace_back(new ir::Node("node4"));
+    nodes.emplace_back(new ir::Node("node4", ir::Node::Type::kVariable));
     vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
     DummyVarHandle* dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 1e7ec95342e406523b2444918220780b8bbb62be..c52980472de8d48e8c21e7c1e53813aa4847cece 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -80,7 +80,14 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
   }
 
   for (ir::Node *output : node->outputs) {
-    CreateOpOutput(result, op_handle, output, p, place_id);
+    ir::Node *new_node = nullptr;
+    if (output->Var()) {
+      new_node = result->CreateVarNode(output->Var());
+    } else {
+      new_node =
+          result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
+    }
+    CreateOpOutput(result, op_handle, new_node, p, place_id);
   }
 }
 
@@ -246,7 +253,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
       if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
         node->Op()->SetAttr("throw_eof_exp", false);
         CreateComputationalOps(&result, node.get(), places_.size());
-        // TODO(panyx0718): builder shouldn't depend on the out logic of
+        // TODO(paddle-dev): builder shouldn't depend on the out logic of
         // a specific op.
         const auto &data_var_names = node->Op()->Output("Out");
         InsertDataBalanceOp(&result, data_var_names);
@@ -354,11 +361,13 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
                                                 const std::string &p_name,
                                                 size_t src_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
-  auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"),
-                                          local_scopes_, places_, nccl_ctxs_);
+  auto *op_handle = new BroadcastOpHandle(
+      result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
+      local_scopes_, places_, nccl_ctxs_);
 #else
-  auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"),
-                                          local_scopes_, places_);
+  auto *op_handle = new BroadcastOpHandle(
+      result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
+      local_scopes_, places_);
 #endif
 
   result->Get<GraphOps>("ops").emplace_back(op_handle);
@@ -370,8 +379,9 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
    auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
-    auto *out_var = new VarHandle(result->CreateEmptyNode(p_name), vars.size(),
-                                  i, p_name, p);
+    auto *out_var = new VarHandle(
+        result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
+        i, p_name, p);
     vars.emplace_back(out_var);
     op_handle->AddOutput(out_var);
   }
@@ -389,12 +399,13 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
 void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
                                                 const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(
-      new AllReduceOpHandle(result->CreateEmptyNode("allreduce"), local_scopes_,
-                            places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
+      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+      local_scopes_, places_, nccl_ctxs_));
 #else
   result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce"), local_scopes_, places_)); + result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), + local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -407,7 +418,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, op_handle->AddInput(prev_grad.get()); auto var = - new VarHandle(result->CreateEmptyNode(og), vars.size(), i, og, p); + new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), + vars.size(), i, og, p); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -416,12 +428,13 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, void MultiDevSSAGraphBuilder::InsertDataBalanceOp( Graph *result, const std::vector &datas) const { #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back( - new DataBalanceOpHandle(result->CreateEmptyNode("data_balance"), - local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back(new DataBalanceOpHandle( + result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), + local_scopes_, places_, nccl_ctxs_)); #else result->Get("ops").emplace_back(new DataBalanceOpHandle( - result->CreateEmptyNode("data_balance"), local_scopes_, places_)); + result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), + local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); for (size_t i = 0; i < places_.size(); ++i) { @@ -431,8 +444,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( auto &vars = result->Get("vars")[i][d_name]; PADDLE_ENFORCE(!vars.empty()); op_handle->AddInput(vars.back().get()); - auto var = new VarHandle(result->CreateEmptyNode(d_name), vars.size(), i, - d_name, p); + auto var = new VarHandle( + result->CreateEmptyNode(d_name, ir::Node::Type::kVariable), + vars.size(), i, d_name, p); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -487,8 +501,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); #endif auto *op_handle = new ScaleLossGradOpHandle( - result->CreateEmptyNode("scale_loss_grad"), local_scopes_.size(), - local_scopes_[i], places_[i], communication_dev_ctx); + result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), + local_scopes_.size(), local_scopes_[i], places_[i], + communication_dev_ctx); result->Get("ops").emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale @@ -497,14 +512,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { // loss->pending_ops_.emplace_back(op_handle); // op_handle->inputs_.emplace_back(loss); - // TODO(panyx0718): GradVarName(loss_var_name_) - const std::string grad_var_name = GradVarName(loss_var_name_); - auto &vars = result->Get("vars")[i][grad_var_name]; - size_t version = vars.size(); - auto var = new VarHandle(result->CreateEmptyNode(grad_var_name), version, i, - grad_var_name, places_[i]); - vars.emplace_back(var); - op_handle->AddOutput(var); + CreateOpOutput(result, op_handle, + result->CreateEmptyNode(GradVarName(loss_var_name_), + ir::Node::Type::kVariable), + places_[i], i); } } @@ -525,10 +536,12 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, int dst_dev_id) const { #ifdef PADDLE_WITH_CUDA result->Get("ops").emplace_back(new ReduceOpHandle( - result->CreateEmptyNode("reduce"), local_scopes_, places_, nccl_ctxs_)); + result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), + local_scopes_, places_, nccl_ctxs_)); #else 
result->Get("ops").emplace_back(new ReduceOpHandle( - result->CreateEmptyNode("reduce"), local_scopes_, places_)); + result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), + local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -541,8 +554,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, op_handle->AddInput(prev_grad.get()); } auto &vars = result->Get("vars")[dst_dev_id][og]; - auto var = new VarHandle(result->CreateEmptyNode(og), vars.size(), dst_dev_id, - og, places_[dst_dev_id]); + auto var = + new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), + vars.size(), dst_dev_id, og, places_[dst_dev_id]); vars.emplace_back(var); op_handle->AddOutput(var); return var; @@ -554,7 +568,8 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { for (auto &prev_op : result->Get("ops")) { if (prev_op->Name() == prev_op_name) { - auto *dep_var = new DummyVarHandle(result->CreateEmptyNode("dummy")); + auto *dep_var = new DummyVarHandle( + result->CreateEmptyNode("dummy", ir::Node::Type::kVariable)); prev_op->AddOutput(dep_var); result->Get("dep_vars").emplace(dep_var); op->AddInput(dep_var); diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 7de4426de8814551e779959f2bdc563c52bc51ed..7bc130ef6e8d2e0caf6e445d12950b87e6dd4dbd 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -37,7 +37,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { continue; } - auto *dep_var = new DummyVarHandle(graph->CreateEmptyNode("dummy")); + auto *dep_var = new DummyVarHandle( + graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable)); read_op->AddOutput(dep_var); write_op->AddInput(dep_var); graph->Get("dep_vars").emplace(dep_var); @@ -54,12 +55,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( auto &var_holder = var_holders[node->Name()]; VarHandle *var = nullptr; if (var_holder.empty()) { - if (node->NodeType() == ir::Node::Type::kVariable) { + if (node->Var()) { var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset, node->Name(), place); } else { - var = new VarHandle(graph->CreateEmptyNode(node->Name()), 0, place_offset, - node->Name(), place); + var = new VarHandle( + graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), 0, + place_offset, node->Name(), place); } var_holder.emplace_back(var); } else { @@ -69,13 +71,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( } void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, - ir::Node *node, + ir::Node *new_node, const platform::Place &place, size_t place_offset) { - auto &vars = graph->Get("vars")[place_offset][node->Name()]; + auto &vars = graph->Get("vars")[place_offset][new_node->Name()]; size_t version = vars.size(); - auto var = new VarHandle(graph->CreateVarNode(node->Var()), version, - place_offset, node->Name(), place); + auto var = + new VarHandle(new_node, version, place_offset, new_node->Name(), place); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -85,7 +87,8 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { if (!op->Outputs().empty()) { continue; } - auto *dummy_leaf = new DummyVarHandle(graph->CreateEmptyNode("dummy")); + auto *dummy_leaf = new DummyVarHandle( + graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable)); 
graph->Get("dep_vars").emplace(dummy_leaf); op->AddOutput(dummy_leaf); } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 87749009efe0c540a74b792cf233d10a7bf4ea69..e8e8acdb38f893302fb92c47d6f1cb2d38453e0f 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -73,7 +73,7 @@ class SSAGraphBuilder : public ir::Pass { // Add an output variable (each_var_name, place, place_offset) to op_handle, // which belongs to graph static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle, - ir::Node *node, const platform::Place &place, + ir::Node *new_node, const platform::Place &place, size_t place_offset); static void AddOutputToLeafOps(Graph *graph); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index ac777883657e19cdb34f287f4a2d2dd1385fbf8a..38cde13fe279d264c51baff71cffcab7b6ebb227 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -173,7 +173,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( auto &var_name = fetch_tensors[i]; auto &vars = fetched_vars.at(var_name); - temp_nodes->emplace_back(new ir::Node("fetch")); + temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i, &local_scopes_); fetch_ops->emplace_back(op); @@ -186,7 +186,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( op->AddInput(var); } - temp_nodes->emplace_back(new ir::Node("fetch")); + temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get()); op->AddOutput(fetch_dummy); fetch_dependencies->emplace(fetch_dummy); diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index f8381af985a129d153a60daefabe2e4d3fafa375..d384ac0d50d1832cfec62cad096bfc9a71e02c70 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -41,7 +41,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { // TODO(paddle-dev): Seems some assumption doesn't hold? LOG(ERROR) << op->Type() << " input var not in all_var list: " << each_var_name; - var = CreateEmptyNode(each_var_name); + var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable); var_nodes[each_var_name] = var; } node->inputs.push_back(var); diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 93db573559799faadb3c6f245e5ead133735f510..3c268682afebe8a831b7cfd5f484c5fd8814c6dc 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -67,8 +67,8 @@ class Graph { // TODO(paddle-dev): There shouldn't be kNone nodes in the ir::Graph. // node should either be a executable kOperation or a kVariable. kNone // node is a temporary solution. 
-  ir::Node* CreateEmptyNode(const std::string& name) {
-    nodes.emplace_back(new ir::Node(name));
+  ir::Node* CreateEmptyNode(const std::string& name, ir::Node::Type type) {
+    nodes.emplace_back(new ir::Node(name, type));
     return nodes.back().get();
   }
 
diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h
index cb1d524c34cf6a9263e47826d29eebe6f983f32b..38080b4ec578c23ae77ba77ea0b8aa8c40214617 100644
--- a/paddle/fluid/framework/ir/node.h
+++ b/paddle/fluid/framework/ir/node.h
@@ -26,12 +26,9 @@ namespace ir {
 
 class Node {
  public:
-  enum class Type { kNone, kOperation, kVariable };
-  explicit Node(const std::string& name)
-      : name_(name),
-        var_desc_(nullptr),
-        op_desc_(nullptr),
-        type_(Type::kNone) {}
+  enum class Type { kOperation, kVariable };
+  explicit Node(const std::string& name, Type type)
+      : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
 
   explicit Node(VarDesc* var_desc)
       : name_(var_desc->Name()),
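
Usage note (not part of the patch): after this change every ir::Node carries an explicit ir::Node::Type, the kNone default is removed, and Graph::CreateEmptyNode takes the node type as a second argument. The sketch below shows the intended call pattern; it assumes a regular PaddlePaddle build with the headers touched above on the include path, and the function and variable names are illustrative only.

// Illustrative sketch only; assumes the PaddlePaddle sources from this patch
// are available on the include path and the framework library is linked.
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"

void NodeTypeExample() {
  using paddle::framework::ProgramDesc;
  using paddle::framework::ir::Graph;
  using paddle::framework::ir::Node;

  ProgramDesc program;  // an empty program is enough for this sketch
  Graph graph(program);

  // Placeholder operation node, e.g. one backing a synthetic op handle.
  Node* op_node = graph.CreateEmptyNode("broadcast", Node::Type::kOperation);

  // Placeholder variable node, e.g. one backing a dummy dependency VarHandle.
  Node* var_node = graph.CreateEmptyNode("dummy", Node::Type::kVariable);

  // Free-standing nodes must state their type as well; the old single-argument
  // constructor, which defaulted to kNone, no longer exists.
  Node fetch_node("fetch", Node::Type::kOperation);

  (void)op_node;
  (void)var_node;
  (void)fetch_node;
}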