diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 4cc6e5727b2bc1b3e1bba03b0a1de40125f1ef54..4050424e7012ae6463d4fea1f99aa94de0fa838b 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -68,7 +68,8 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
   }
 }
 
-void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
+void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
+                                                ir::Node *node,
                                                 size_t place_id) const {
   auto p = places_[place_id];
   auto *op_handle = result->Get("ops").back().get();
@@ -192,8 +193,9 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
 // to parameter/gradients before optimizer ops, topo sort is insufficient. (
 // some optimizer ops might not depend on any nodes), we manually move all
 // optimizer nodes after last backward nodes.
-std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) {
-  std::vector<ir::Node *> ret = ir::TopologySort(graph);
+// However, the assumption by SSAGraphBuilder should be relaxed in the future.
+std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
+  std::vector<ir::Node *> ret = ir::TopologySortOperations(graph);
   size_t last_backward = 0;
   std::vector<ir::Node *> optimize_ops;
   std::vector<ir::Node *> sorted_ret;
@@ -232,8 +234,8 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) {
   return sorted_ret;
 }
 
-std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
-    std::unique_ptr<Graph> graph) const {
+std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
+    std::unique_ptr<ir::Graph> graph) const {
   // Rebuild the graph structure.
   std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
   auto nodes = std::move(graph->nodes);
@@ -245,7 +247,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
     }
   }
 
-  Graph &result = *graph;
+  ir::Graph &result = *graph;
   std::unordered_set og_has_been_broadcast;
 
   // We cannot invoke resize. It is a bug of GCC 4.8
@@ -397,7 +399,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
 #endif
 }
 
-void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
+void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
                                                 const std::string &p_name,
                                                 size_t src_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
@@ -427,7 +429,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
   }
 }
 
-void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
+void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
                                                     ir::Node *node,
                                                     int dev_id) const {
   result->Get("ops").emplace_back(
@@ -436,7 +438,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
   CreateOpHandleIOs(result, node, dev_id);
 }
 
-void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
+void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
                                                 const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
   result->Get("ops").emplace_back(new AllReduceOpHandle(
@@ -466,7 +468,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
 }
 
 void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
-    Graph *result, const std::vector<std::string> &datas) const {
+    ir::Graph *result, const std::vector<std::string> &datas) const {
 #ifdef PADDLE_WITH_CUDA
   result->Get("ops").emplace_back(new DataBalanceOpHandle(
       result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
@@ -529,7 +531,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
   return got == var_name_on_devices_.end() ? -1 : got->second;
 }
 
-void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
+void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
   for (size_t i = 0; i < places_.size(); ++i) {
     // Insert ScaleCost OpHandle
 #ifdef PADDLE_WITH_CUDA
@@ -559,7 +561,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
   }
 }
 
-void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
+void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
                                                      ir::Node *node,
                                                      size_t num_places) const {
   for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
@@ -571,7 +573,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
   }
 }
 
-VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
+VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
                                                    const std::string &og,
                                                    int dst_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
@@ -604,7 +606,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
 
 // Find the first occurence of `prev_op_name` and make current `op` depend
 // on it.
-void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
+void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
                                         const std::string &prev_op_name) const {
   for (auto &prev_op : result->Get("ops")) {
     if (prev_op->Name() == prev_op_name) {
@@ -617,7 +619,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
   }
 }
 
-void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
+void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
                                                 ir::Node *node) const {
   int op_dev_id = -1;
   std::vector input_var_names;
@@ -664,7 +666,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
 }
 
 // Create RPC related op handles that connects its in ops and out ops.
-void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const {
+void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
+                                          ir::Node *node) const {
   int op_dev_id = -1;
   if (node->Op()->Type() == "send") {
     op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index 2b7f4f586b4e750fde9245286c977258a9db6086..c2c764bb9443ab932f4460341c3abfd403e7b5eb 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -46,11 +46,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::vector &local_scopes,
                           const BuildStrategy &strategy);
 #endif
-  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override;
+  std::unique_ptr<ir::Graph> Apply(
+      std::unique_ptr<ir::Graph> graph) const override;
   int GetVarDeviceID(const std::string &varname) const override;
 
  private:
-  void CreateOpHandleIOs(Graph *result, ir::Node *node, size_t device_id) const;
+  void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
+                         size_t device_id) const;
 
  private:
   std::string loss_var_name_;
@@ -64,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
   bool IsScaleLossOp(ir::Node *node) const;
 
-  void CreateRPCOp(Graph *result, ir::Node *node) const;
-  void CreateDistTrainOp(Graph *result, ir::Node *node) const;
+  void CreateRPCOp(ir::Graph *result, ir::Node *node) const;
+  void CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
 
   /**
    * Is this operator as the end-point operator before/after send operator.
@@ -79,16 +81,17 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   std::vector<std::string> FindDistTrainRecvVars(
       const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
 
-  void ConnectOp(Graph *result, OpHandleBase *op,
+  void ConnectOp(ir::Graph *result, OpHandleBase *op,
                  const std::string &prev_op_name) const;
 
-  void CreateComputationalOps(Graph *result, ir::Node *node,
+  void CreateComputationalOps(ir::Graph *result, ir::Node *node,
                               size_t num_places) const;
 
-  void CreateScaleLossGradOp(Graph *result) const;
-  VarHandle *CreateReduceOp(Graph *result, const std::string &og,
+  void CreateScaleLossGradOp(ir::Graph *result) const;
+  VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
                             int dst_dev_id) const;
-  void CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const;
+  void CreateComputationalOp(ir::Graph *result, ir::Node *node,
+                             int dev_id) const;
 
   bool IsParameterGradientOnce(
       const std::string &og,
@@ -96,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
   int GetOpDeviceID(ir::Node *node) const;
 
-  void InsertAllReduceOp(Graph *result, const std::string &og) const;
+  void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
 
-  void InsertDataBalanceOp(Graph *result,
+  void InsertDataBalanceOp(ir::Graph *result,
                            const std::vector<std::string> &datas) const;
 
-  void CreateBroadcastOp(Graph *result, const std::string &p_name,
+  void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
                          size_t src_dev_id) const;
 
   bool IsSparseGradient(const std::string &og) const;
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc
index 2be4bb009eff2866f39f08f11052822eb1fdea5a..3c579f427e5a04ce3c144e3709451247b29c6fcc 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.cc
+++ b/paddle/fluid/framework/details/ssa_graph_builder.cc
@@ -17,7 +17,7 @@
 namespace paddle {
 namespace framework {
 namespace details {
-void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
+void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
   for (auto &var_map : graph->Get("vars")) {
     for (auto &name_pair : var_map) {
       if (name_pair.second.size() <= 1) {
@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
 }
 
 VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
-    Graph *graph, ir::Node *node, const platform::Place &place,
+    ir::Graph *graph, ir::Node *node, const platform::Place &place,
     size_t place_offset) {
   auto &var_holders = graph->Get("vars")[place_offset];
   auto &var_holder = var_holders[node->Name()];
@@ -81,7 +81,7 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
   return var;
 }
 
-void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
+void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
                                      ir::Node *new_node,
                                      const platform::Place &place,
                                      size_t place_offset) {
@@ -93,7 +93,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
   op_handle->AddOutput(var);
 }
 
-void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
+void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
   for (auto &op : graph->Get("ops")) {
     if (!op->Outputs().empty()) {
       continue;
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h
index e8e8acdb38f893302fb92c47d6f1cb2d38453e0f..f64445b470a76766f5a8e6e106418a14f352ef11 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.h
+++ b/paddle/fluid/framework/details/ssa_graph_builder.h
@@ -64,19 +64,19 @@ class SSAGraphBuilder : public ir::Pass {
    *
    * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
    */
-  static void PolishGraphToSupportDataHazards(Graph *graph);
+  static void PolishGraphToSupportDataHazards(ir::Graph *graph);
 
-  static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, ir::Node *node,
+  static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
                                                const platform::Place &place,
                                                size_t place_offset);
 
   // Add an output variable (each_var_name, place, place_offset) to op_handle,
   // which belongs to graph
-  static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
+  static void CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
                              ir::Node *new_node, const platform::Place &place,
                              size_t place_offset);
 
-  static void AddOutputToLeafOps(Graph *graph);
+  static void AddOutputToLeafOps(ir::Graph *graph);
 };
 } // namespace details
 } // namespace framework
diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc
index 7c79d7f1e881c67514634d56caa715c41927dbce..0438b096109a287366610d06ef2bd14c765a8f43 100644
--- a/paddle/fluid/framework/details/ssa_graph_checker.cc
+++ b/paddle/fluid/framework/details/ssa_graph_checker.cc
@@ -20,7 +20,7 @@
 namespace paddle {
 namespace framework {
 namespace details {
-bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const {
+bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
   std::unordered_map pending_ops;
   std::unordered_set pending_vars;
   std::unordered_set ready_vars;
diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h
index 2e397e86825a41765a360d31fa8595d17027f3ec..51ce6e5ecad755613551aa6525b5cfbe4a8933ef 100644
--- a/paddle/fluid/framework/details/ssa_graph_checker.h
+++ b/paddle/fluid/framework/details/ssa_graph_checker.h
@@ -28,7 +28,8 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
                              std::unique_ptr<SSAGraphBuilder>&& builder)
       : builder_(std::move(builder)) {}
 
-  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
+  std::unique_ptr<ir::Graph> Apply(
+      std::unique_ptr<ir::Graph> graph) const override {
     auto new_graph = builder_->Apply(std::move(graph));
     PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
     return new_graph;
@@ -38,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
     return builder_->GetVarDeviceID(var_name);
   }
 
-  bool IsValidGraph(const Graph* graph) const;
+  bool IsValidGraph(const ir::Graph* graph) const;
 
  private:
   std::unique_ptr<SSAGraphBuilder> builder_;
diff --git a/paddle/fluid/framework/details/ssa_graph_printer.cc b/paddle/fluid/framework/details/ssa_graph_printer.cc
index 6dd6fd262e35a192ba85eb3aa16660526d2ebca2..20aab1464400aa9bb1bd6af11c06269c688a8308 100644
--- a/paddle/fluid/framework/details/ssa_graph_printer.cc
+++ b/paddle/fluid/framework/details/ssa_graph_printer.cc
@@ -21,7 +21,7 @@ namespace framework {
 namespace details {
 
 template <typename Callback>
-static inline void IterAllVar(const Graph &graph, Callback callback) {
+static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
   for (auto &each : graph.Get("vars")) {
     for (auto &pair1 : each) {
       for (auto &pair2 : pair1.second) {
@@ -35,7 +35,7 @@ static inline void IterAllVar(const Graph &graph, Callback callback) {
   }
 }
 
-void GraphvizSSAGraphPrinter::Print(const Graph &graph,
+void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
                                     std::ostream &sout) const {
   size_t var_id = 0;
   std::unordered_map vars;
diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h
index cd72162f44ca76aa6340606cf79a73601eae89af..a77c1bad3f15bca9064ded860696eb68b033b090 100644
--- a/paddle/fluid/framework/details/ssa_graph_printer.h
+++ b/paddle/fluid/framework/details/ssa_graph_printer.h
@@ -25,12 +25,12 @@ namespace details {
 class SSAGraphPrinter {
  public:
   virtual ~SSAGraphPrinter() {}
-  virtual void Print(const Graph& graph, std::ostream& sout) const = 0;
+  virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
 };
 
 class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
  public:
-  void Print(const Graph& graph, std::ostream& sout) const override;
+  void Print(const ir::Graph& graph, std::ostream& sout) const override;
 };
 
 class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
@@ -50,7 +50,8 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
         stream_ptr_(std::move(sout)),
         stream_ref_(*stream_ptr_) {}
 
-  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
+  std::unique_ptr<ir::Graph> Apply(
+      std::unique_ptr<ir::Graph> graph) const override {
     auto new_graph = builder_->Apply(std::move(graph));
     printer_->Print(*new_graph, stream_ref_);
     return new_graph;
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index f85c62dd6c4a8033a037b1e001ece6a9cc90ca98..c19f74476f9a1498a7d61f5faf204e9966aea155 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -21,7 +21,8 @@ namespace framework {
 namespace details {
 ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector &local_scopes,
-    const std::vector<platform::Place> &places, std::unique_ptr<Graph> &&graph)
+    const std::vector<platform::Place> &places,
+    std::unique_ptr<ir::Graph> &&graph)
     : graph_(std::move(graph)),
       pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
                                        : nullptr),
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
index bf7c0a367a19ff4ac9462334516f1577672faa68..3d67daa45e20fdea52689684397ad01f2f4cd783 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
@@ -40,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
                            const std::vector &local_scopes,
                            const std::vector &places,
-                           std::unique_ptr<Graph> &&graph);
+                           std::unique_ptr<ir::Graph> &&graph);
 
   // Run a SSAGraph by a thread pool
   // Use topological sort algorithm
@@ -53,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
                     details::OpHandleBase *op);
 
  private:
-  std::unique_ptr<Graph> graph_;
+  std::unique_ptr<ir::Graph> graph_;
   std::unique_ptr<::ThreadPool> pool_;
   std::vector local_scopes_;
   std::vector places_;
diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc
index 46640fedcce16079abdebe61bc3c8fb87f8822eb..18211f2e2b0e2b50c51b378f0b36a0ece7d5926f 100644
--- a/paddle/fluid/framework/ir/graph.cc
+++ b/paddle/fluid/framework/ir/graph.cc
@@ -22,6 +22,7 @@ limitations under the License. */
 
 namespace paddle {
 namespace framework {
+namespace ir {
 /*
 namespace {
 void SortHelper(
@@ -41,7 +42,7 @@ void SortHelper(
   ret->push_back(node);
 }
 
-std::vector TopologySort(
+std::vector TopologySortOperations(
     const std::map> &adj_list) {
   std::unordered_set visited;
   std::vector ret;
@@ -156,7 +157,7 @@ bool HasCircle(const std::map>
   return false;
 }
 
-std::map> BuildAdjList(
+std::map> BuildOperationAdjList(
     const std::vector &nodes) {
   std::map> adj_list;
 
@@ -178,17 +179,17 @@ std::map> BuildAdjList(
   return adj_list;
 }
 
-std::vector TopologySortOperationFromInToOut(
+std::vector TopologySortOperationsOperationFromInToOut(
     const std::vector> &nodes) {
   std::vector tmp;
   for (auto& n : nodes) {
     tmp.push_back(n.get());
   }
 
  std::map> adj_list =
-BuildAdjList(tmp);
+BuildOperationAdjList(tmp);
 
   PADDLE_ENFORCE(!HasCircle(adj_list));
-  std::vector ret = TopologySort(adj_list);
+  std::vector ret = TopologySortOperations(adj_list);
   ir::Node *last_backward = nullptr;
   std::vector optimize_ops;
@@ -235,5 +236,6 @@ BuildAdjList(tmp);
   return ret;
 }*/
 
+} // namespace ir
 } // namespace framework
 } // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index b4ac135b029005b723abca2cb9b9a9aa175eda40..a1e39b08a4da703e494dfed87578179bfed17597 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -26,13 +26,13 @@ limitations under the License. */
 
 namespace paddle {
 namespace framework {
-
+namespace ir {
 class Graph {
  public:
-  explicit Graph(const ProgramDesc& program);
+  explicit Graph(const ProgramDesc &program);
 
   virtual ~Graph() {
-    for (auto& attr : attrs_) {
+    for (auto &attr : attrs_) {
       attr_dels_[attr.first]();
     }
     attrs_.clear();
@@ -40,12 +40,12 @@ class Graph {
   }
 
   template <typename AttrType>
-  AttrType& Get(const std::string& attr_name) const {
-    return *boost::any_cast<AttrType*>(attrs_.at(attr_name));
+  AttrType &Get(const std::string &attr_name) const {
+    return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
   }
 
   template <typename AttrType>
-  void Set(const std::string& attr_name, AttrType* attr) {
+  void Set(const std::string &attr_name, AttrType *attr) {
     PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
     attrs_[attr_name] = attr;
     attr_dels_[attr_name] = [attr, attr_name]() {
@@ -54,17 +54,17 @@ class Graph {
     };
   }
 
-  ir::Node* CreateVarNode(VarDesc* var_desc) {
+  ir::Node *CreateVarNode(VarDesc *var_desc) {
     nodes.emplace_back(new ir::Node(var_desc));
     return nodes.back().get();
   }
 
-  ir::Node* CreateOpNode(OpDesc* op_desc) {
+  ir::Node *CreateOpNode(OpDesc *op_desc) {
     nodes.emplace_back(new ir::Node(op_desc));
     return nodes.back().get();
   }
 
-  ir::Node* CreateEmptyNode(const std::string& name, ir::Node::Type type) {
+  ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
     nodes.emplace_back(new ir::Node(name, type));
     return nodes.back().get();
   }
@@ -73,10 +73,10 @@ class Graph {
 
  private:
   // NOTE: program_ shouldn't be exposed to user.
-  const ProgramDesc& program_;
+  const ProgramDesc &program_;
   std::map attrs_;
   std::map> attr_dels_;
 };
-
+} // namespace ir
 } // namespace framework
 } // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc
index ecd90f4f3ec699298793a58ec3ce3d2d4a41ab03..76458be135018692b1b766dedbadada8340a0423 100644
--- a/paddle/fluid/framework/ir/graph_helper.cc
+++ b/paddle/fluid/framework/ir/graph_helper.cc
@@ -64,7 +64,7 @@ bool HasCircleHelper(
 
 bool HasCircle(const Graph &graph) {
   std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
-      BuildAdjList(graph);
+      BuildOperationAdjList(graph);
 
   std::unordered_set<ir::Node *> visited;
   std::unordered_set<ir::Node *> in_trace;
@@ -76,9 +76,9 @@ bool HasCircle(const Graph &graph) {
   return false;
 }
 
-std::vector<ir::Node *> TopologySort(const Graph &graph) {
+std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
   std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
-      BuildAdjList(graph);
+      BuildOperationAdjList(graph);
   std::unordered_set<ir::Node *> visited;
   std::vector<ir::Node *> ret;
   for (auto adj : adj_list) {
@@ -89,7 +89,7 @@ std::vector<ir::Node *> TopologySort(const Graph &graph) {
   return ret;
 }
 
-std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList(
+std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
     const Graph &graph) {
   std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
 
diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h
index b8714eb5be03143657c22903b0283af7d76f83bd..55b2e3f5ca67a46d692b409e3bc64793883bb05a 100644
--- a/paddle/fluid/framework/ir/graph_helper.h
+++ b/paddle/fluid/framework/ir/graph_helper.h
@@ -26,9 +26,9 @@ namespace framework {
 namespace ir {
 bool HasCircle(const Graph &graph);
 
-std::vector<ir::Node *> TopologySort(const Graph &graph);
+std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
 
-std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList(
+std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
     const Graph &graph);
 } // namespace ir
 } // namespace framework
diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc
index 186047b370c778a43a2828d249716ea0bafee39e..f8fbd2242d5922591c07621a812ccd2531410861 100644
--- a/paddle/fluid/framework/ir/graph_test.cc
+++ b/paddle/fluid/framework/ir/graph_test.cc
@@ -93,7 +93,7 @@ TEST(GraphTest, Basic) {
   ASSERT_EQ(proto::VarType::LOD_TENSOR,
             prog.MutableBlock(0)->Var("test_out")->GetType());
 
-  std::unique_ptr<Graph> g(new Graph(prog));
+  std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
   ASSERT_EQ(g->nodes[0]->Name(), "sum");
   ASSERT_EQ(g->nodes[0]->inputs[0]->Name(), "test_a");
   ASSERT_EQ(g->nodes[0]->inputs[1]->Name(), "test_b");
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 1e5bba62b53025dacdbf2d74b35f266cf4e422c2..02c836bea194553bb9c4bc5677cc408dd302e9ce 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -132,7 +132,7 @@ ParallelExecutor::ParallelExecutor(
 #endif
   }
   builder_ = builder_factory.Create();
-  std::unique_ptr<Graph> graph(new Graph(main_program));
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
   graph = builder_->Apply(std::move(graph));
   member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
       exec_strategy, member_->local_scopes_, places, std::move(graph)));
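
Note for reviewers on SortOpsAndDelayOptimizeOp: the comment in multi_devices_graph_builder.cc says that a plain topological sort is not enough, so optimizer ops are manually moved after the last backward op. The sketch below is illustrative only; it uses simplified, hypothetical types (a plain Op struct) rather than the actual ir::Node-based code, and just shows the reordering idea under that assumption.

// Illustrative sketch only; simplified stand-ins for ir::Node and for the
// real SortOpsAndDelayOptimizeOp in multi_devices_graph_builder.cc.
#include <string>
#include <vector>

struct Op {
  std::string name;
  bool is_optimize;  // e.g. parameter-update ops such as sgd/adam
  bool is_backward;  // ops generated by the backward pass
};

// Given ops already in topological order, move every optimizer op so that it
// runs after the last backward op, preserving the relative order otherwise.
std::vector<Op> DelayOptimizeOps(const std::vector<Op> &topo_sorted) {
  std::vector<Op> sorted;
  std::vector<Op> optimize_ops;
  size_t last_backward = 0;
  for (const Op &op : topo_sorted) {
    if (op.is_optimize) {
      optimize_ops.push_back(op);
    } else {
      sorted.push_back(op);
      if (op.is_backward) last_backward = sorted.size();
    }
  }
  // Splice the optimizer ops back in right after the last backward op.
  sorted.insert(sorted.begin() + last_backward, optimize_ops.begin(),
                optimize_ops.end());
  return sorted;
}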