From 9b9603306c1b2c8fc0ec6ea54b0b289ab974b97b Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Wed, 11 Jul 2018 19:59:28 +0800
Subject: [PATCH] graph attrs

---
 .../details/multi_devices_graph_builder.cc    | 115 +++++++-----------
 .../framework/details/ssa_graph_builder.cc    |  17 +--
 paddle/fluid/framework/ir/graph.h             |  65 +++++++++-
 3 files changed, 111 insertions(+), 86 deletions(-)

diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index da0272d48e9..9ac961f1b15 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -70,8 +70,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
 void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, const OpDesc &op,
                                                 size_t place_id) const {
   auto p = places_[place_id];
-  auto *op_handle =
-      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
+  auto *op_handle = result->Get<GraphOps>("ops").back().get();
 
   op_handle->SetDeviceContext(p,
                               platform::DeviceContextPool::Instance().Get(p));
@@ -179,13 +178,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   std::unordered_set<std::string> og_has_been_broadcast;
 
   // We cannot invoke resize. It is a bug of GCC 4.8
-  result.attrs["vars"] = new std::vector<
-      std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
-      places_.size());
-  result.attrs["dep_vars"] =
-      new std::unordered_set<std::unique_ptr<VarHandleBase>>();
-  result.attrs["ops"] = new std::vector<std::unique_ptr<OpHandleBase>>();
-
+  result.Set("vars", new GraphVars(places_.size()));
+  result.Set("dep_vars", new GraphDepVars);
+  result.Set("ops", new GraphOps);
   // find send/recv vars so that we can place the distributed training
   // realted op in the place 0
   auto send_vars = FindDistTrainSendVars(program);
@@ -308,13 +303,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   AddOutputToLeafOps(&result);
 
   std::unique_ptr<SSAGraph> ssa_graph(new SSAGraph);
-  ssa_graph->vars_ =
-      std::move(*boost::any_cast<GraphVars *>(graph->attrs["vars"]));
-  ssa_graph->ops_ =
-      std::move(*boost::any_cast<GraphOps *>(graph->attrs["ops"]));
-  ssa_graph->dep_vars_ =
-      std::move(*boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"]));
-
+  ssa_graph->vars_ = std::move(*graph->Erase<GraphVars>("vars"));
+  ssa_graph->ops_ = std::move(*graph->Erase<GraphOps>("ops"));
+  ssa_graph->dep_vars_ = std::move(*graph->Erase<GraphDepVars>("dep_vars"));
   return std::move(ssa_graph);
 }
 
@@ -347,20 +338,15 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
 #else
   auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
 #endif
-
-  boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(op_handle);
-  auto *in = boost::any_cast<GraphVars *>(result->attrs["vars"])
-                 ->at(src_dev_id)
-                 .at(p_name)
-                 .back()
-                 .get();
+  result->Get<GraphOps>("ops").emplace_back(op_handle);
+  auto *in =
+      result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
   op_handle->AddInput(in);
 
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars =
-        boost::any_cast<GraphVars *>(result->attrs["vars"])->at(i).at(p_name);
+    auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
     auto *out_var = new VarHandle(vars.size(), i, p_name, p);
     vars.emplace_back(out_var);
     op_handle->AddOutput(out_var);
@@ -370,28 +356,26 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
 void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
                                                     const OpDesc &op,
                                                     int dev_id) const {
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(
-          new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
+  result->Get<GraphOps>("ops").emplace_back(
+      new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
   CreateOpHandleIOs(result, op, dev_id);
 }
 
 void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
                                                 const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(
+      new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(new AllReduceOpHandle(local_scopes_, places_));
+  result->Get<GraphOps>("ops").emplace_back(
+      new AllReduceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle =
-      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
+  auto *op_handle = result->Get<GraphOps>("ops").back().get();
 
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][og];
+    auto &vars = result->Get<GraphVars>("vars")[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
@@ -405,21 +389,18 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
 void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
     Graph *result, const std::vector<std::string> &datas) const {
 #ifdef PADDLE_WITH_CUDA
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(
-          new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(
+      new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(new DataBalanceOpHandle(local_scopes_, places_));
+  result->Get<GraphOps>("ops").emplace_back(
+      new DataBalanceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle =
-      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
+  auto *op_handle = result->Get<GraphOps>("ops").back().get();
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
     for (const std::string &d_name : datas) {
-      auto &vars =
-          (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][d_name];
+      auto &vars = result->Get<GraphVars>("vars")[i][d_name];
       PADDLE_ENFORCE(!vars.empty());
       op_handle->AddInput(vars.back().get());
       auto var = new VarHandle(vars.size(), i, d_name, p);
@@ -480,7 +461,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
     auto *op_handle = new ScaleLossGradOpHandle(local_scopes_.size(),
                                                 local_scopes_[i], places_[i],
                                                 communication_dev_ctx);
-    boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(op_handle);
+    result->Get<GraphOps>("ops").emplace_back(op_handle);
 
     // FIXME: Currently ScaleLossGradOp only use device_count as scale
     // factor. So it does not depend on any other operators.
@@ -499,8 +480,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
   for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
-    boost::any_cast<GraphOps *>(result->attrs["ops"])
-        ->emplace_back(new ComputationOpHandle(op, s, p));
+    result->Get<GraphOps>("ops").emplace_back(
+        new ComputationOpHandle(op, s, p));
     CreateOpHandleIOs(result, op, scope_idx);
   }
 }
@@ -509,25 +490,23 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
                                                    const std::string &og,
                                                    int dst_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(
+      new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(new ReduceOpHandle(local_scopes_, places_));
+  result->Get<GraphOps>("ops").emplace_back(
+      new ReduceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle =
-      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
+  auto *op_handle = result->Get<GraphOps>("ops").back().get();
 
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][og];
+    auto &vars = result->Get<GraphVars>("vars")[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
   }
-  auto &vars =
-      (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[dst_dev_id][og];
+  auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
   auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
@@ -538,12 +517,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
 // on it.
 void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
                                         const std::string &prev_op_name) const {
-  for (auto &prev_op : (*boost::any_cast<GraphOps *>(result->attrs["ops"]))) {
+  for (auto &prev_op : result->Get<GraphOps>("ops")) {
     if (prev_op->Name() == prev_op_name) {
       auto *dep_var = new DummyVarHandle();
       prev_op->AddOutput(dep_var);
-      boost::any_cast<GraphDepVars *>(result->attrs["dep_vars"])
-          ->emplace(dep_var);
+      result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
       op->AddInput(dep_var);
     }
   }
 }
@@ -579,8 +557,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
 
   CreateComputationalOp(result, op, op_dev_id);
   if (op.Type() == "concat") {
-    ConnectOp(result,
-              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+    ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
               "fetch_barrier");
   }
 }
@@ -615,22 +592,16 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result,
   PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
                  op.Type());
 
-  boost::any_cast<GraphOps *>(result->attrs["ops"])
-      ->emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], op.Type(),
-                                     places_[op_dev_id]));
+  result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
+      op, local_scopes_[op_dev_id], op.Type(), places_[op_dev_id]));
 
   if (op.Type() == "send_barrier") {
-    ConnectOp(result,
-              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
-              "send");
+    ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
   } else if (op.Type() == "recv") {
-    ConnectOp(result,
-              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+    ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
               "send_barrier");
   } else if (op.Type() == "fetch_barrier") {
-    ConnectOp(result,
-              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
-              "recv");
+    ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv");
   } else if (op.Type() == "send") {
     // do nothing
   } else {
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc
index 2c0873cc878..2508ed0296d 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.cc
+++ b/paddle/fluid/framework/details/ssa_graph_builder.cc
@@ -18,7 +18,7 @@ namespace paddle {
 namespace framework {
 namespace details {
 void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
-  for (auto &var_map : *boost::any_cast<GraphVars *>(graph->attrs["vars"])) {
+  for (auto &var_map : graph->Get<GraphVars>("vars")) {
     for (auto &name_pair : var_map) {
       if (name_pair.second.size() <= 1) {
         continue;
@@ -40,8 +40,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
         auto *dep_var = new DummyVarHandle();
         read_op->AddOutput(dep_var);
         write_op->AddInput(dep_var);
-        boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"])
-            ->emplace(dep_var);
+        graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
       }
     }
   }
@@ -51,8 +50,7 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
     Graph *graph, const std::string &each_var_name,
     const platform::Place &place, size_t place_offset) {
-  auto &var_holders =
-      (*boost::any_cast<GraphVars *>(graph->attrs["vars"]))[place_offset];
+  auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
   auto &var_holder = var_holders[each_var_name];
   VarHandle *var = nullptr;
   if (var_holder.empty()) {
@@ -68,9 +66,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
                                      const std::string &each_var_name,
                                      const platform::Place &place,
                                      size_t place_offset) {
-  auto &vars =
-      (*boost::any_cast<GraphVars *>(graph->attrs["vars"]))[place_offset]
-                                                           [each_var_name];
+  auto &vars = graph->Get<GraphVars>("vars")[place_offset][each_var_name];
   size_t version = vars.size();
   auto var = new VarHandle(version, place_offset, each_var_name, place);
   vars.emplace_back(var);
@@ -78,15 +74,14 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
 }
 
 void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
-  GraphOps &all_ops = *boost::any_cast<GraphOps *>(graph->attrs["ops"]);
+  GraphOps &all_ops = graph->Get<GraphOps>("ops");
   for (auto &op : all_ops) {
     if (!op->Outputs().empty()) {
       continue;
     }
     auto *dummy_leaf = new DummyVarHandle();
-    boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"])
-        ->emplace(dummy_leaf);
+    graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
     op->AddOutput(dummy_leaf);
   }
 }
 
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index 8996c2d43ae..f1de4d740d3 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -20,18 +20,77 @@ limitations under the License. */
 #include <vector>
 
 #include "paddle/fluid/framework/ir/node.h"
+#include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/variant.h"
 
 namespace paddle {
 namespace framework {
+class Graph;
+
+template <typename AttrType>
+struct AnyAttr {
+ public:
+  explicit AnyAttr(AttrType* attr) : attr_(attr) {}
+
+  AttrType& Get() { return *boost::any_cast<AttrType*>(attr_); }
+
+ private:
+  friend Graph;
+
+  AttrType* Release() {
+    released_ = true;
+    return boost::any_cast<AttrType*>(attr_);
+  }
+
+  void Delete() {
+    if (!released_) {
+      delete boost::any_cast<AttrType*>(attr_);
+    }
+  }
+
+  bool released_ = false;
+  boost::any attr_;
+};
+
 class Graph {
  public:
-  std::map<std::string, boost::any> attrs;
+  virtual ~Graph() {
+    for (auto& attr : attrs) {
+      attr_dels[attr.first]();
+    }
+    attrs.clear();
+    attr_dels.clear();
+  }
+
+  template <typename AttrType>
+  AttrType& Get(const std::string& attr_name) {
+    return boost::any_cast<AnyAttr<AttrType>>(attrs[attr_name]).Get();
+  }
+
+  template <typename AttrType>
+  void Set(const std::string& attr_name, AttrType* attr) {
+    AnyAttr<AttrType> any_attr = AnyAttr<AttrType>(attr);
+    attrs[attr_name] = any_attr;
+    // Capture the holder by value: a reference to the local would dangle
+    // once Set() returns and the deleter runs from the destructor.
+    attr_dels[attr_name] = [any_attr]() mutable { any_attr.Delete(); };
+  }
 
-  std::vector<ir::Node *> inputs;
-  std::vector<ir::Node *> outputs;
+  template <typename AttrType>
+  AttrType* Erase(const std::string& attr_name) {
+    AnyAttr<AttrType> attr_type =
+        boost::any_cast<AnyAttr<AttrType>>(attrs[attr_name]);
+    attrs.erase(attr_name);
+    attr_dels.erase(attr_name);
+    return attr_type.Release();
+  }
+
+  std::vector<ir::Node*> inputs;
+  std::vector<ir::Node*> outputs;
   std::vector<std::unique_ptr<ir::Node>> nodes;
+  std::map<std::string, boost::any> attrs;
+  std::map<std::string, std::function<void(void)>> attr_dels;
+
+ private:
 };
 
 }  // namespace framework
--
GitLab
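
Usage sketch (not part of the patch): the short program below exercises the Graph::Set/Get/Erase attribute API added to graph.h above, mirroring how the graph builders store and reclaim "vars"/"ops"/"dep_vars". The IntVector alias, the "my_attr" key, and the standalone main() are hypothetical illustrations only, and they assume the patched graph.h is on the include path.

// Minimal sketch of the attribute API introduced in this patch.
#include <iostream>
#include <vector>

#include "paddle/fluid/framework/ir/graph.h"

using paddle::framework::Graph;
using IntVector = std::vector<int>;  // illustrative attribute type

int main() {
  Graph graph;

  // Set() stores a heap-allocated attribute under a string key; the Graph
  // destructor deletes any attribute that has not been erased.
  graph.Set("my_attr", new IntVector({1, 2, 3}));

  // Get<T>() returns a reference to the stored attribute.
  graph.Get<IntVector>("my_attr").push_back(4);
  std::cout << graph.Get<IntVector>("my_attr").size() << "\n";  // prints 4

  // Erase<T>() removes the attribute and hands ownership back to the
  // caller, who is then responsible for deleting it (as Build() does when
  // it moves "vars"/"ops"/"dep_vars" into the SSAGraph).
  IntVector* released = graph.Erase<IntVector>("my_attr");
  delete released;
  return 0;
}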