From 10786a243ebe33e425a6202bd541a180bc17c510 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Fri, 13 Jul 2018 19:29:43 +0800 Subject: [PATCH] polish graph --- .../details/broadcast_op_handle_test.cc | 10 +-- .../details/computation_op_handle.cc | 6 +- .../framework/details/computation_op_handle.h | 3 +- .../details/gather_op_handle_test.cc | 10 +-- .../details/multi_devices_graph_builder.cc | 83 ++++++++++--------- .../details/reduce_op_handle_test.cc | 6 +- .../framework/details/ssa_graph_builder.cc | 19 +++-- .../details/threaded_ssa_graph_executor.cc | 4 +- paddle/fluid/framework/ir/graph.cc | 10 +-- paddle/fluid/framework/ir/graph.h | 6 +- paddle/fluid/framework/ir/node.h | 58 ++++++------- python/paddle/fluid/parallel_executor.py | 2 - 12 files changed, 104 insertions(+), 113 deletions(-) diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc index 1609b5965c0..63a6ed90828 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc @@ -96,7 +96,7 @@ struct TestBroadcastOpHandle { } param_scopes_[input_scope_idx]->Var("input"); - std::unique_ptr n(new ir::Node()); + std::unique_ptr n(new ir::Node("node0")); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, @@ -114,7 +114,7 @@ struct TestBroadcastOpHandle { #endif } - std::unique_ptr v(new ir::Node()); + std::unique_ptr v(new ir::Node("node1")); auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input", gpu_list_[input_scope_idx]); vars_.emplace_back(in_var_handle); @@ -122,7 +122,7 @@ struct TestBroadcastOpHandle { // add dummy var - std::unique_ptr v2(new ir::Node()); + std::unique_ptr v2(new ir::Node("node2")); vars_.emplace_back(new DummyVarHandle(v2.get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); @@ -133,7 +133,7 @@ struct TestBroadcastOpHandle { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - std::unique_ptr v3(new ir::Node()); + std::unique_ptr v3(new ir::Node("node3")); VarHandle* out_var_handle = new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]); vars_.emplace_back(out_var_handle); @@ -141,7 +141,7 @@ struct TestBroadcastOpHandle { } // add dummy var - std::unique_ptr v4(new ir::Node()); + std::unique_ptr v4(new ir::Node("node4")); vars_.emplace_back(new DummyVarHandle(v4.get())); DummyVarHandle* out_dummy_var_handle = static_cast(vars_.back().get()); diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index 16ad30d491b..b6282debdb4 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -19,10 +19,10 @@ namespace paddle { namespace framework { namespace details { -ComputationOpHandle::ComputationOpHandle(ir::Node *node, const OpDesc &op_desc, - Scope *scope, platform::Place place) +ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, + platform::Place place) : OpHandleBase(node), - op_(framework::OpRegistry::CreateOp(op_desc)), + op_(framework::OpRegistry::CreateOp(*node->Op())), scope_(scope), place_(place) {} diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index 9ca1d927b8d..d9fcd92427e 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -28,8 +28,7 @@ namespace framework { namespace details { struct ComputationOpHandle : public OpHandleBase { public: - ComputationOpHandle(ir::Node *node, const OpDesc &op_desc, Scope *scope, - platform::Place place); + ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place); std::string Name() const override; diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc index f80cabf5010..e3806ac5e14 100644 --- a/paddle/fluid/framework/details/gather_op_handle_test.cc +++ b/paddle/fluid/framework/details/gather_op_handle_test.cc @@ -82,13 +82,13 @@ struct TestGatherOpHandle { } param_scopes_[input_scope_idx]->Var("out"); - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node")); op_handle_.reset( new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); // add input for (size_t j = 0; j < gpu_list_.size(); ++j) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node1")); auto* in_var_handle = new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); vars_.emplace_back(in_var_handle); @@ -96,7 +96,7 @@ struct TestGatherOpHandle { } // add dummy var - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node2")); vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* in_dummy_var_handle = static_cast(vars_.back().get()); @@ -104,14 +104,14 @@ struct TestGatherOpHandle { op_handle_->AddInput(in_dummy_var_handle); // add output - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node3")); auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, "out", gpu_list_[input_scope_idx]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); // add dummy var - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node4")); vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index d66bc40090f..035fb629a89 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -90,7 +90,7 @@ std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( // since parameters are all in block 0, // it's enough to only scan send ops in block 0 for (auto &node : nodes) { - if (!node->Op()) continue; + if (node->NodeType() != ir::Node::Type::kOperation) continue; OpDesc *op = node->Op(); // TODO(Yancey1989): use a graceful method to find send op, // instead of the the hard code string @@ -108,7 +108,7 @@ std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( const std::vector> &nodes) const { std::vector recv_vars; for (auto &node : nodes) { - if (!node->Op()) continue; + if (node->NodeType() != ir::Node::Type::kOperation) continue; OpDesc *op = node->Op(); // TODO(Yancey1989): use a graceful method to find recv op, // instead of the hard code string @@ -149,10 +149,10 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( std::vector input_var_names; std::vector output_var_names; for (ir::Node *input : node->inputs) { - input_var_names.push_back(input->Var()->Name()); + input_var_names.push_back(input->Name()); } for (ir::Node *output : node->outputs) { - output_var_names.push_back(output->Var()->Name()); + output_var_names.push_back(output->Name()); } return checker(output_var_names, send_vars) || @@ -181,13 +181,13 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( std::unique_ptr MultiDevSSAGraphBuilder::Apply( std::unique_ptr graph) const { + // Rebuild the graph structure. auto nodes = std::move(graph->nodes); graph->nodes.clear(); - LOG(ERROR) << "origin nodes count " << nodes.size(); for (auto &node : nodes) { - if (node->Var()) { - all_vars_.emplace(node->Var()->Name(), node->Var()); + if (node->NodeType() == ir::Node::Type::kVariable) { + all_vars_.emplace(node->Name(), node->Var()); } } @@ -212,7 +212,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( // TODO(panyx0718): FIXME: nodes should be sorted by "program" order. for (auto &node : nodes) { - if (!node->Op()) continue; + if (node->NodeType() != ir::Node::Type::kOperation) continue; if (boost::get( node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == static_cast(OpRole::kRPC)) { @@ -235,7 +235,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( if (op_dev_id != -1) { // This op only runs on one specific device. CreateComputationalOp(&result, node.get(), op_dev_id); for (ir::Node *n : node->outputs) { - var_name_on_devices_.emplace(n->Var()->Name(), op_dev_id); + var_name_on_devices_.emplace(n->Name(), op_dev_id); } } else { // This op runs on all devices, and its output may have parameter's @@ -351,10 +351,10 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, const std::string &p_name, size_t src_dev_id) const { #ifdef PADDLE_WITH_CUDA - auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr), + auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"), local_scopes_, places_, nccl_ctxs_); #else - auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr), + auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"), local_scopes_, places_); #endif result->Get("ops").emplace_back(op_handle); @@ -367,8 +367,8 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, auto &p = places_[i]; SetCommunicationContext(op_handle, p); auto &vars = result->Get("vars").at(i).at(p_name); - auto *out_var = - new VarHandle(result->CreateVarNode(p_name), vars.size(), i, p_name, p); + auto *out_var = new VarHandle(result->CreateEmptyNode(p_name), vars.size(), + i, p_name, p); vars.emplace_back(out_var); op_handle->AddOutput(out_var); } @@ -378,7 +378,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const { result->Get("ops").emplace_back( - new ComputationOpHandle(result->CreateOpNode(node->Op()), *node->Op(), + new ComputationOpHandle(result->CreateOpNode(node->Op()), local_scopes_[dev_id], places_[dev_id])); CreateOpHandleIOs(result, node, dev_id); } @@ -386,11 +386,12 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back(new AllReduceOpHandle( - result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back( + new AllReduceOpHandle(result->CreateEmptyNode("allreduce"), local_scopes_, + places_, nccl_ctxs_)); #else result->Get("ops").emplace_back(new AllReduceOpHandle( - result->CreateOpNode(nullptr), local_scopes_, places_)); + result->CreateEmptyNode("allreduce"), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -402,7 +403,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); - auto var = new VarHandle(result->CreateVarNode(og), vars.size(), i, og, p); + auto var = + new VarHandle(result->CreateEmptyNode(og), vars.size(), i, og, p); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -411,11 +413,12 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, void MultiDevSSAGraphBuilder::InsertDataBalanceOp( Graph *result, const std::vector &datas) const { #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back(new DataBalanceOpHandle( - result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back( + new DataBalanceOpHandle(result->CreateEmptyNode("data_balance"), + local_scopes_, places_, nccl_ctxs_)); #else result->Get("ops").emplace_back(new DataBalanceOpHandle( - result->CreateOpNode(nullptr), local_scopes_, places_)); + result->CreateEmptyNode("data_balance"), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); for (size_t i = 0; i < places_.size(); ++i) { @@ -425,7 +428,7 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( auto &vars = result->Get("vars")[i][d_name]; PADDLE_ENFORCE(!vars.empty()); op_handle->AddInput(vars.back().get()); - auto var = new VarHandle(result->CreateVarNode(d_name), vars.size(), i, + auto var = new VarHandle(result->CreateEmptyNode(d_name), vars.size(), i, d_name, p); vars.emplace_back(var); op_handle->AddOutput(var); @@ -455,12 +458,12 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { return -1; } auto param_grad = boost::get>( - node->Op()->.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); PADDLE_ENFORCE_EQ(param_grad.size(), 2U); int dev_id = GetVarDeviceID(param_grad[1]); - PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", op.Type(), - param_grad[0]); + PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", + node->Op()->Type(), param_grad[0]); return dev_id; } @@ -481,8 +484,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); #endif auto *op_handle = new ScaleLossGradOpHandle( - result->CreateOpNode(nullptr), local_scopes_.size(), local_scopes_[i], - places_[i], communication_dev_ctx); + result->CreateEmptyNode("scale_loss_grad"), local_scopes_.size(), + local_scopes_[i], places_[i], communication_dev_ctx); result->Get("ops").emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale @@ -495,7 +498,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { const std::string grad_var_name = GradVarName(loss_var_name_); auto &vars = result->Get("vars")[i][grad_var_name]; size_t version = vars.size(); - auto var = new VarHandle(result->CreateVarNode(grad_var_name), version, i, + auto var = new VarHandle(result->CreateEmptyNode(grad_var_name), version, i, grad_var_name, places_[i]); vars.emplace_back(var); op_handle->AddOutput(var); @@ -508,8 +511,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - result->Get("ops").emplace_back(new ComputationOpHandle( - result->CreateOpNode(node->Op()), *node->Op(), s, p)); + result->Get("ops").emplace_back( + new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p)); CreateOpHandleIOs(result, node, scope_idx); } } @@ -519,10 +522,10 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, int dst_dev_id) const { #ifdef PADDLE_WITH_CUDA result->Get("ops").emplace_back(new ReduceOpHandle( - result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_)); + result->CreateEmptyNode("reduce"), local_scopes_, places_, nccl_ctxs_)); #else result->Get("ops").emplace_back(new ReduceOpHandle( - result->CreateOpNode(nullptr), local_scopes_, places_)); + result->CreateEmptyNode("reduce"), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -535,7 +538,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, op_handle->AddInput(prev_grad.get()); } auto &vars = result->Get("vars")[dst_dev_id][og]; - auto var = new VarHandle(result->CreateVarNode(og), vars.size(), dst_dev_id, + auto var = new VarHandle(result->CreateEmptyNode(og), vars.size(), dst_dev_id, og, places_[dst_dev_id]); vars.emplace_back(var); op_handle->AddOutput(var); @@ -548,7 +551,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { for (auto &prev_op : result->Get("ops")) { if (prev_op->Name() == prev_op_name) { - auto *dep_var = new DummyVarHandle(result->CreateVarNode("dummy")); + auto *dep_var = new DummyVarHandle(result->CreateEmptyNode("dummy")); prev_op->AddOutput(dep_var); result->Get("dep_vars").emplace(dep_var); op->AddInput(dep_var); @@ -562,10 +565,10 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, std::vector input_var_names; std::vector output_var_names; for (ir::Node *input : node->inputs) { - input_var_names.push_back(input->Var()->Name()); + input_var_names.push_back(input->Name()); } for (ir::Node *output : node->outputs) { - output_var_names.push_back(output->Var()->Name()); + output_var_names.push_back(output->Name()); } if (node->Op()->Type() == "split_byref" || @@ -606,16 +609,16 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const { int op_dev_id = -1; if (node->Op()->Type() == "send") { - op_dev_id = GetVarDeviceID(node->inputs[0]->Var()->Name()); + op_dev_id = GetVarDeviceID(node->inputs[0]->Name()); // the variable name which contains .block means it was splited by // split_byref op // so that we can balance the variable blocks to all the pserver // instances. if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce && - node->inputs[0]->Var()->Name().find(".block") == std::string::npos) { + node->inputs[0]->Name().find(".block") == std::string::npos) { std::vector input_var_names; for (ir::Node *n : node->inputs) { - input_var_names.push_back(n->Var()->Name()); + input_var_names.push_back(n->Name()); } op_dev_id = GetAppropriateDeviceID(input_var_names); for (auto &varname : input_var_names) { @@ -625,7 +628,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const { } else if (node->Op()->Type() == "recv") { std::vector output_var_names; for (ir::Node *n : node->outputs) { - output_var_names.push_back(n->Var()->Name()); + output_var_names.push_back(n->Name()); } op_dev_id = GetAppropriateDeviceID(output_var_names); for (auto &varname : output_var_names) { diff --git a/paddle/fluid/framework/details/reduce_op_handle_test.cc b/paddle/fluid/framework/details/reduce_op_handle_test.cc index e7c83ffd320..3a9a5841239 100644 --- a/paddle/fluid/framework/details/reduce_op_handle_test.cc +++ b/paddle/fluid/framework/details/reduce_op_handle_test.cc @@ -97,7 +97,7 @@ struct TestReduceOpHandle { } param_scopes_[out_scope_idx]->Var("out"); - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node")); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_, @@ -121,7 +121,7 @@ struct TestReduceOpHandle { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node1")); auto *in_var_handle = new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); in_var_handle->ClearGeneratedOp(); @@ -137,7 +137,7 @@ struct TestReduceOpHandle { op_handle_->AddInput(in_dummy_var_handle); // add output - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node2")); auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx, "out", gpu_list_[out_scope_idx]); vars_.emplace_back(out_var_handle); diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 6a8bd7875cb..884fc645555 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -37,7 +37,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { continue; } - auto *dep_var = new DummyVarHandle(graph->CreateVarNode("dummy")); + auto *dep_var = new DummyVarHandle(graph->CreateEmptyNode("dummy")); read_op->AddOutput(dep_var); write_op->AddInput(dep_var); graph->Get("dep_vars").emplace(dep_var); @@ -51,11 +51,16 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset) { auto &var_holders = graph->Get("vars")[place_offset]; - auto &var_holder = var_holders[node->Var()->Name()]; + auto &var_holder = var_holders[node->Name()]; VarHandle *var = nullptr; if (var_holder.empty()) { - var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset, - node->Var()->Name(), place); + if (node->NodeType() == ir::Node::Type::kVariable) { + var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset, + node->Name(), place); + } else { + var = new VarHandle(graph->CreateEmptyNode(node->Name()), 0, place_offset, + node->Name(), place); + } var_holder.emplace_back(var); } else { var = var_holder.rbegin()->get(); @@ -67,10 +72,10 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, ir::Node *node, const platform::Place &place, size_t place_offset) { - auto &vars = graph->Get("vars")[place_offset][node->Var()->Name()]; + auto &vars = graph->Get("vars")[place_offset][node->Name()]; size_t version = vars.size(); auto var = new VarHandle(graph->CreateVarNode(node->Var()), version, - place_offset, node->Var()->Name(), place); + place_offset, node->Name(), place); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -82,7 +87,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { if (!op->Outputs().empty()) { continue; } - auto *dummy_leaf = new DummyVarHandle(graph->CreateVarNode("dummy")); + auto *dummy_leaf = new DummyVarHandle(graph->CreateEmptyNode("dummy")); graph->Get("dep_vars").emplace(dummy_leaf); op->AddOutput(dummy_leaf); } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 9a2413118e2..8c9cb7cabb8 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -173,7 +173,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( auto &var_name = fetch_tensors[i]; auto &vars = fetched_vars.at(var_name); - ir::Node *fetch_n = new ir::Node(ir::Node::Type::kOperation); + ir::Node *fetch_n = new ir::Node("fetch"); auto *op = new FetchOpHandle(fetch_n, fetch_data, i, &local_scopes_); temp_nodes->emplace_back(fetch_n); fetch_ops->emplace_back(op); @@ -186,7 +186,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( op->AddInput(var); } - ir::Node *dummy_n = new ir::Node(ir::Node::Type::kVariable); + ir::Node *dummy_n = new ir::Node("fetch"); auto *fetch_dummy = new DummyVarHandle(dummy_n); op->AddOutput(fetch_dummy); fetch_dependencies->emplace(fetch_dummy); diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index c1f8f917c4c..14d697c509d 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -35,19 +35,15 @@ std::unique_ptr ProgramToGraph(const ProgramDesc &program) { if (all_vars.count(each_var_name) != 0) { var = graph->CreateVarNode(all_vars.at(each_var_name)); } else { - var = graph->CreateVarNode(each_var_name); + LOG(ERROR) << "input var not in all_var list: " << each_var_name; + var = graph->CreateEmptyNode(each_var_name); } node->inputs.push_back(var); var->outputs.push_back(node); } for (auto &each_var_name : op->OutputArgumentNames()) { - ir::Node *var = nullptr; - if (all_vars.count(each_var_name) != 0) { - var = graph->CreateVarNode(all_vars.at(each_var_name)); - } else { - var = graph->CreateVarNode(each_var_name); - } + ir::Node *var = graph->CreateVarNode(all_vars.at(each_var_name)); node->outputs.push_back(var); var->inputs.push_back(node); } diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index ff4f31fb7af..8b185f96254 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -72,16 +72,14 @@ class Graph { } // TODO(panyx0718): Need to handle CreateOpNode(nullptr). - ir::Node* CreateVarNode(const std::string& var_name) { - var_descs_.emplace_back(new VarDesc(var_name)); - nodes.emplace_back(new ir::Node(var_descs_.back().get())); + ir::Node* CreateEmptyNode(const std::string& name) { + nodes.emplace_back(new ir::Node(name)); return nodes.back().get(); } std::vector inputs; std::vector outputs; std::vector> nodes; - std::vector> var_descs_; private: // NOTE: program_ shouldn't be exposed to user. diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 0e0b81a7b14..d2d08bc461a 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -32,51 +32,43 @@ namespace ir { class Node { public: - enum class Type { kNone = -1, kOperation, kVariable }; + enum class Type { kNone, kOperation, kVariable }; + explicit Node(const std::string& name) + : name_(name), + var_desc_(nullptr), + op_desc_(nullptr), + type_(Type::kNone) {} - Node() : type_(Type::kNone) {} - - explicit Node(Type type) : type_(type) {} + explicit Node(VarDesc* var_desc) + : name_(var_desc->Name()), + var_desc_(var_desc), + op_desc_(nullptr), + type_(Type::kVariable) {} - virtual ~Node() { - for (auto& attr : attrs_) { - if (attr_dels_.find(attr.first) != attr_dels_.end()) { - attr_dels_[attr.first](); - } - } - attr_dels_.clear(); - attrs_.clear(); - } + explicit Node(OpDesc* op_desc) + : name_(op_desc->Type()), + var_desc_(nullptr), + op_desc_(op_desc), + type_(Type::kOperation) {} Type NodeType() const { return type_; } - template - void Set(const std::string& name, AttrType attr) { - attrs_[name] = attr; - } + std::string Name() const { return name_; } - template - void Set(const std::string& name, AttrType* attr, - std::function attr_del) { - attrs_[name] = attr; - attr_dels_[name] = attr_del; + VarDesc* Var() { + PADDLE_ENFORCE(type_ == Type::kVariable); + return var_desc_; + } + OpDesc* Op() { + PADDLE_ENFORCE(type_ == Type::kOperation); + return op_desc_; } - - VarDesc* Var() { return var_desc_; } - OpDesc* Op() { return op_desc_; } - - explicit Node(VarDesc* var_desc) - : var_desc_(var_desc), op_desc_(nullptr), type_(Type::kVariable) {} - - explicit Node(OpDesc* op_desc) - : var_desc_(nullptr), op_desc_(op_desc), type_(Type::kOperation) {} std::vector inputs; std::vector outputs; protected: - std::map attrs_; - std::map> attr_dels_; + const std::string name_; VarDesc* var_desc_; OpDesc* op_desc_; Type type_; diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 59789c6defa..10028a8c6e3 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -148,7 +148,6 @@ class ParallelExecutor(object): lambda var: var.persistable and var.type != core.VarDesc.VarType.RAW, main.list_vars()) ] - sys.stderr.write('!!!!!!!!before\n') self.executor = core.ParallelExecutor( self._places, @@ -159,7 +158,6 @@ class ParallelExecutor(object): set(self.persistable_vars), main.desc, loss_name if loss_name else '', scope, local_scopes, exec_strategy, build_strategy, num_trainers, trainer_id) - sys.stderr.write('!!!!!!!!after\n') self.scope = scope def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True): -- GitLab