diff --git a/doc/fluid/design/ir/draft.md b/doc/fluid/design/ir/draft.md index e141ce095910ac9cf6061b99c7fabdcaf89d1fa6..c29337cba1fe859e4968cb800e4e7d9ff6a54d31 100644 --- a/doc/fluid/design/ir/draft.md +++ b/doc/fluid/design/ir/draft.md @@ -64,6 +64,41 @@ can also contain other things that describe some properties of the `Graph` or `Graph` nodes. `Attribute` can be passed across `Pass`. However, it should be used with care. +```cpp +class Graph { + public: + explicit Graph(const ProgramDesc &program); + + bool Has(const std::string &attr_name) const; + + template + AttrType &Get(const std::string &attr_name) const; + + template + void Set(const std::string &attr_name, AttrType *attr); + const std::unordered_set &Nodes() const; + + // Create a normal variable with non-null VarDesc. + ir::Node *CreateVarNode(VarDesc *var_desc); + + // Create a normal runnable operator with OpDesc. + ir::Node *CreateOpNode(OpDesc *op_desc); + + // Create a control dependency var that connects 2 operations. The + // var doesn't hold any data. Other than that, it's no different from + // other var, considering dependency analysis. + ir::Node *CreateControlDepVar(); + + // A more free style way of creating a graph node. Mostly use for test + // or "copy" from another node. Avoid using it if possible. + ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type); + + // Clear all node information of the graph and return the ownership of the + // nodes. + std::vector> ReleaseNodes(); +}; +``` + #### Pass `Pass` represents a transformation of `Graph`. Its input @@ -101,13 +136,15 @@ class Pass { // In my_pass.cc class MyPass : public Pass { - public: - std::unique_ptr Apply(std::unique_ptr graph) const override { + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const override { // do something. return graph; } } -REGISTER_PASS(my_pass, MyPass); +REGISTER_PASS(my_pass, MyPass) +.RequirePassAttr("places") +.RequireGraphAttr("dep_vars"); // To use the pass. @@ -132,4 +169,17 @@ maintaining the original modeling logic. * Graph is transformed from raw model logic to a form that is efficient to execute. -Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor +``` +// Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor +auto graph = Graph(program); +graph = PassRegistry::Instance().Get("op_fuse_pass").Apply(std::move(grah)); +// For more complex Pass, Optimize Process can provide Pass attributes. +auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass"); +mem_opt_pass.SetNotOwned("optimize_level", 1); +mem_opt_pass->Apply(std::move(graph)); +graph = PassRegistry::Instance().Get("multi_device_pass").Apply(std::move(grah)); +graph = PassRegistry::Instance().Get("multi_device_check_pass").Apply(std::move(grah)); +Executor exe; +exe.Run(graph); + +``` diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index a6bdd12b63d71b6bdf9fcfb72e29bb6a530f2896..bf7d76a8a6e173e648cea5aaba9b7202d787173b 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -1,7 +1,9 @@ cc_library(node SRCS node.cc DEPS proto_desc) cc_library(graph SRCS graph.cc DEPS node) cc_library(graph_helper SRCS graph_helper.cc DEPS graph) -cc_library(pass SRCS pass.cc DEPS graph node) +cc_library(pass SRCS pass.cc DEPS graph node graph_helper) cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper) -cc_test(graph_test SRCS graph_test.cc DEPS graph op_registry) -cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph_helper op_registry) + +cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) +cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) +cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 78094e46fb1175f6c0d2cb74348eb4ab6e97dea6..c9d55fbf525a1a476ac469e8e57462169a7db2da 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -53,7 +53,8 @@ class Graph { template void Set(const std::string &attr_name, AttrType *attr) { - PADDLE_ENFORCE(attrs_.count(attr_name) == 0); + PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph", + attr_name); attrs_[attr_name] = attr; attr_dels_[attr_name] = [attr, attr_name]() { VLOG(3) << "deleting " << attr_name; diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 2ebc3c7430cc342cfba0c1fbd441177aa70edf62..d7158eba62686be57499df697466797e4034ea8f 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -13,23 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/graph_helper.h" namespace paddle { namespace framework { namespace ir { std::unique_ptr Pass::Apply(std::unique_ptr graph) const { + PADDLE_ENFORCE(!applied_, "Pass can only Apply() once."); + PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty."); for (const std::string& attr : required_pass_attrs_) { PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(), - "Required pass atrribute %s not registered.", attr); + "Required pass atrribute %s not set.", attr); } for (const std::string& attr : required_graph_attrs_) { - PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not exist.", + PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not set.", attr); } auto applied_graph = ApplyImpl(std::move(graph)); // TODO(panyx0718): Add more verifications. PADDLE_ENFORCE(!HasCircle(*applied_graph), "Illegal Pass. Generated graph shouldn't has cycle."); + applied_ = true; return applied_graph; } diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 3f65794fab2e6b4370fca1cc1469eaf2b1aef1aa..0f14083d259172f5b5f1ed80c7d38312d711beb5 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -19,7 +19,6 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/variant.h" @@ -56,7 +55,8 @@ class Pass { // Set a pointer to the attribute. Pass takes ownership of the attribute. template void Set(const std::string &attr_name, AttrType *attr) { - PADDLE_ENFORCE(attrs_.count(attr_name) == 0); + PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass", + attr_name); attrs_[attr_name] = attr; attr_dels_[attr_name] = [attr, attr_name]() { VLOG(3) << "deleting " << attr_name; @@ -89,6 +89,7 @@ class Pass { required_graph_attrs_.insert(attrs.begin(), attrs.end()); } + mutable bool applied_{false}; std::unordered_set required_pass_attrs_; std::unordered_set required_graph_attrs_; std::map attrs_; @@ -118,14 +119,15 @@ class PassRegistry { return map_.find(pass_type) != map_.end(); } - void Insert(const std::string &type, const PassCreator &pass_creator) { - PADDLE_ENFORCE(!Has(type), "Pass %s has been registered", type); - map_.insert({type, pass_creator}); + void Insert(const std::string &pass_type, const PassCreator &pass_creator) { + PADDLE_ENFORCE(!Has(pass_type), "Pass %s has been registered", pass_type); + map_.insert({pass_type, pass_creator}); } - std::unique_ptr Get(const std::string &type) const { - PADDLE_ENFORCE(Has(type), "Pass %s has not been registered", type); - return map_.at(type)(); + std::unique_ptr Get(const std::string &pass_type) const { + PADDLE_ENFORCE(Has(pass_type), "Pass %s has not been registered", + pass_type); + return map_.at(pass_type)(); } private: diff --git a/paddle/fluid/framework/ir/pass_test.cc b/paddle/fluid/framework/ir/pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5b5011412ed39e033a7a65921e9c64ce2d54c638 --- /dev/null +++ b/paddle/fluid/framework/ir/pass_test.cc @@ -0,0 +1,112 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/pass.h" +#include +#include "gtest/gtest.h" +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework { +namespace ir { +void BuildCircleGraph(Graph* g) { + ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation); + ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation); + ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable); + ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable); + + o1->outputs.push_back(v1); + o2->inputs.push_back(v1); + v1->inputs.push_back(o1); + v1->outputs.push_back(o2); + + o2->outputs.push_back(v2); + o1->inputs.push_back(v2); + v2->inputs.push_back(o2); + v2->outputs.push_back(o1); +} + +class TestPass : public Pass { + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const { + graph->Set("copy_test_pass_attr", new int); + graph->Set("copy_test_graph_attr", new int); + + int test_pass_attr = this->Get("test_pass_attr"); + graph->Get("copy_test_pass_attr") = test_pass_attr + 1; + + int test_graph_attr = graph->Get("test_graph_attr"); + graph->Get("copy_test_graph_attr") = test_graph_attr + 1; + return graph; + } +}; + +TEST(PassTest, TestPassAttrCheck) { + ProgramDesc prog; + auto pass = PassRegistry::Instance().Get("test_pass"); + std::unique_ptr graph(new Graph(prog)); + std::string exception; + try { + graph = pass->Apply(std::move(graph)); + } catch (paddle::platform::EnforceNotMet e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("test_pass_attr not set") != exception.npos); + + int val = 1; + graph.reset(new Graph(prog)); + pass->SetNotOwned("test_pass_attr", &val); + + try { + graph = pass->Apply(std::move(graph)); + } catch (paddle::platform::EnforceNotMet e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("test_graph_attr not set") != exception.npos); + + graph.reset(new Graph(prog)); + graph->Set("test_graph_attr", new int); + graph->Get("test_graph_attr") = 1; + graph = pass->Apply(std::move(graph)); + ASSERT_EQ(graph->Get("copy_test_pass_attr"), 2); + ASSERT_EQ(graph->Get("copy_test_graph_attr"), 2); + + try { + graph = pass->Apply(std::move(graph)); + } catch (paddle::platform::EnforceNotMet e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("Pass can only Apply() once") != exception.npos); + + pass = PassRegistry::Instance().Get("test_pass"); + pass->SetNotOwned("test_pass_attr", &val); + graph.reset(new Graph(prog)); + BuildCircleGraph(graph.get()); + graph->Set("test_graph_attr", new int); + graph->Get("test_graph_attr") = 2; + try { + auto tmp = pass->Apply(std::move(graph)); + } catch (paddle::platform::EnforceNotMet e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("shouldn't has cycle") != exception.npos); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(test_pass, paddle::framework::ir::TestPass) + .RequirePassAttr("test_pass_attr") + .RequireGraphAttr("test_graph_attr"); diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 112b48ca31e59f1fa50044d4dc892a8afc42ddbd..b5f01a9a2b76472063658f1a051a2ee3c65559b7 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -44,7 +44,10 @@ std::unique_ptr ApplyParallelExecutorPass( #else const BuildStrategy &strategy) { #endif + // Convert the program to graph. std::unique_ptr graph(new ir::Graph(main_program)); + + // Apply a graph viz pass to record a graph. if (!strategy.debug_graphviz_path_.empty()) { auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); const std::string graph_path = string::Sprintf( @@ -53,6 +56,7 @@ std::unique_ptr ApplyParallelExecutorPass( graph = viz_pass->Apply(std::move(graph)); } + // Convert graph to run on multi-devices. auto multi_device_pass = ir::PassRegistry::Instance().Get("multi_device_pass"); multi_device_pass->SetNotOwned>("places", @@ -71,6 +75,7 @@ std::unique_ptr ApplyParallelExecutorPass( #endif graph = multi_device_pass->Apply(std::move(graph)); + // Apply a graph print pass to record a graph with device info. if (!strategy.debug_graphviz_path_.empty()) { auto multi_device_print_pass = ir::PassRegistry::Instance().Get("multi_device_print_pass"); @@ -81,17 +86,10 @@ std::unique_ptr ApplyParallelExecutorPass( graph = multi_device_print_pass->Apply(std::move(graph)); } + // Verify that the graph is correct for multi-device executor. auto multi_device_check_pass = ir::PassRegistry::Instance().Get("multi_device_check_pass"); graph = multi_device_check_pass->Apply(std::move(graph)); - - if (!strategy.debug_graphviz_path_.empty()) { - auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); - const std::string graph_path = string::Sprintf( - "%s%s", strategy.debug_graphviz_path_.c_str(), "_before_exec"); - viz_pass->Set("graph_viz_path", new std::string(graph_path)); - graph = viz_pass->Apply(std::move(graph)); - } return graph; }