diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index 4e8bc94cf82627515a578f843590092941a416ed..a71420b8c111c397f710e0c93237ec85f148b1e8 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ -30,7 +30,7 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp cc_library(types_lite SRCS types.cc) cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite) -cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite) +cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite) cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite) add_subdirectory(mir) diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt index c66078645f8d3ca766202164d0198cbf9d6ee54e..26dc50ab73e165a1ab26ada668fbc03751f6c893 100644 --- a/paddle/fluid/lite/core/mir/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/CMakeLists.txt @@ -50,5 +50,9 @@ cc_test(test_variable_place_infrence_pass SRCS variable_place_inference_pass_tes ${test_variable_place_infrence_pass_DEPS}) cc_library(pattern_matcher_lite SRCS pattern_matcher.cc DEPS mir_node mir_ssa_graph op_lite) -cc_test(test_pattern_matcher_lite SRCS pattern_matcher_tester.cc DEPS pattern_matcher_lite) - +cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern_matcher_lite) + +cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite) +cc_test(test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS + pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite + mir_passes compatible_pb_lite program_lite ${ops_lite}) diff --git a/paddle/fluid/lite/core/mir/pass_manager.cc b/paddle/fluid/lite/core/mir/pass_manager.cc index 508c2fd5522519793af26973f711c4c7d2b7a7d3..e12246ca83985f2de8473d713de56b25e6da2613 100644 --- a/paddle/fluid/lite/core/mir/pass_manager.cc +++ b/paddle/fluid/lite/core/mir/pass_manager.cc @@ -16,10 +16,6 @@ namespace paddle { namespace lite { -namespace mir { - -PassManager::PassManager() {} - -} // namespace mir +namespace mir {} // namespace mir } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/mir/pass_manager.h b/paddle/fluid/lite/core/mir/pass_manager.h index 2fc4654d920583de19db02d0053bfe21282815f2..e80c0c851632add55b9013c6e51954365da13e91 100644 --- a/paddle/fluid/lite/core/mir/pass_manager.h +++ b/paddle/fluid/lite/core/mir/pass_manager.h @@ -30,7 +30,7 @@ class PassManager { return x; } - PassManager(); + PassManager() {} void Run(const std::unique_ptr& graph) { for (auto& pass : passes_) { diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.cc b/paddle/fluid/lite/core/mir/pattern_matcher.cc index 1f5a1c6e4e57278de5976058d2b7b98030baaf73..c7fa42ac5a786e5a8994a5fba3e2d427d752dcad 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher.cc +++ b/paddle/fluid/lite/core/mir/pattern_matcher.cc @@ -27,6 +27,30 @@ namespace mir { size_t PMPattern::id_ = 0UL; +PMNode &PMNode::operator>>(PMNode &right) { + pattern_->AddEdge(this, &right); + // automatically add out op link relation. + if (right.IsOp()) { + CHECK(!right.op_type_.empty()); + this->assert_is_op_input(right.op_type_); + } + + return right; +} + +PMNode &PMNode::operator>>(std::vector &nodes) { + for (auto *node : nodes) { + *this >> *node; + } + return *this; +} + +void operator>>(std::vector &others, PMNode &me) { + for (auto *o : others) { + *o >> me; + } +} + PMNode *PMPattern::NewNode(const std::string &name) { if (!name.empty()) { CHECK_EQ(node_map_.count(name), 0UL) @@ -122,9 +146,7 @@ void PatternMatcher::ValidateByNodeRole( // Collect the inlinks and outlinks. std::unordered_set ios; for (auto &item : subgraph) { - if (!item.first->IsIntermediate()) { - ios.insert(item.second); - } + ios.insert(item.second); } for (auto &item : subgraph) { if (item.first->IsIntermediate()) { @@ -400,6 +422,30 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) { return this; } +void GraphSafeRemoveNodes(SSAGraph *graph, + const std::unordered_set &nodes) { + for (auto *node : nodes) { + graph->RemoveNode(node); + } + + for (auto &node : graph->mutable_nodes()) { + for (auto it = node.inlinks.begin(); it != node.inlinks.end();) { + if (nodes.count(*it)) { + it = node.inlinks.erase(it); + } else { + it++; + } + } + for (auto it = node.outlinks.begin(); it != node.outlinks.end();) { + if (nodes.count(*it)) { + it = node.outlinks.erase(it); + } else { + it++; + } + } + } +} + } // namespace mir } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.h b/paddle/fluid/lite/core/mir/pattern_matcher.h index 8ea8f615aeb4a49ac233f8761372757dce1877f9..2241e71af3de9e9692b2fd740c1e91ee7839fa91 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher.h +++ b/paddle/fluid/lite/core/mir/pattern_matcher.h @@ -58,6 +58,15 @@ struct PMNode { PMNode& LinksTo(const std::vector& others); PMNode& LinksFrom(const std::vector& others); + // Link this to another node. + PMNode& operator>>(PMNode& right); + + // Link many nodes to this node. + friend void operator>>(std::vector& others, PMNode& me); + + // Link this to many other nodes. + PMNode& operator>>(std::vector& nodes); + bool Tell(const Node* node) const { if (teller_) return teller_(node); @@ -92,6 +101,20 @@ struct PMNode { return this; } + PMNode* AsVar() { + type_ = Type::kVar; + assert_is_var(); + return this; + } + + PMNode* AsOp(const std::string& op_type) { + type_ = Type::kOp; + assert_is_op(op_type); + return this; + } + + void set_op_type(const std::string& op_type) { op_type_ = op_type; } + bool IsIntermediate() const { return role_ == Role::kIntermediate; } bool IsInput() const { return role_ == Role::kInput; } bool IsOutput() const { return role_ == Role::kOutput; } @@ -141,6 +164,7 @@ struct PMNode { std::vector asserts_; PMPattern* pattern_; std::string name_; + std::string op_type_; Type type_; Role role_{Role::kUnknown}; }; @@ -273,6 +297,10 @@ class PatternMatcher { std::unordered_map> pmnodes2nodes_; }; +// Graph safely remove some nodes, will automatically clean up the edges. +void GraphSafeRemoveNodes(SSAGraph* graph, + const std::unordered_set& nodes); + // Some pre-defined patterns those can be reused in multiple passes. // The related Fluid Layer or Op should be one pattern here for better re-usage // across different fusion. diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..5dc929cda5ee296623ba12a0a2d355c2f71ae7c8 --- /dev/null +++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h" +#include + +namespace paddle { +namespace lite { +namespace mir { + +void FuseBase::PerformPatternMatcher(SSAGraph *graph) { + LOG(INFO) << "\n" << matcher_.pattern().DotString(); + // Get subgraphs and record the mir::Node pointers for each PMNode. + auto handler = [&](const PatternMatcher::subgraph_t &subgraph, SSAGraph *g) { + // get all the reigistered nodes. + key2nodes_.emplace_back(); + for (auto &item : nodes_) { + key2nodes_.back()[item.first] = subgraph.at(item.second); + } + }; + + matcher_(graph, handler); +} + +void FuseBase::DeleteInterNodes(SSAGraph *graph) { + std::set keys; + for (auto &node : nodes_) { + if (node.second->IsIntermediate()) { + keys.insert(node.first); + } + } + + LOG(INFO) << "keys.size " << keys.size(); + + std::unordered_set nodes2rm; + for (auto &matched : key2nodes_) { + LOG(INFO) << "get matched " << matched.size(); + for (const auto &key : keys) { + nodes2rm.insert(matched.at(key)); + } + } + + LOG(INFO) << "clean nodes " << nodes2rm.size(); + GraphSafeRemoveNodes(graph, nodes2rm); +} + +PMNode *FuseBase::GetOrCreateNode(const std::string &key) { + auto it = nodes_.find(key); + if (it != nodes_.end()) { + return it->second; + } + nodes_.emplace(key, + matcher_.mutable_pattern()->NewNode(patterns::UniqueKey(key))); + it = nodes_.find(key); + return it->second; +} + +PMNode *FuseBase::OpNode(const std::string &key, const std::string &op_type) { + GetOrCreateNode(key)->set_op_type(op_type); + GetOrCreateNode(key)->AsOp(op_type); + return GetOrCreateNode(key); +} + +PMNode *FuseBase::VarNode(const std::string &key) { + GetOrCreateNode(key)->AsVar(); + return GetOrCreateNode(key); +} + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h new file mode 100644 index 0000000000000000000000000000000000000000..645e33165f4c07c304554d1289c447c59526ea3c --- /dev/null +++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h @@ -0,0 +1,78 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/lite/core/mir/node.h" +#include "paddle/fluid/lite/core/mir/pattern_matcher.h" +#include "paddle/fluid/lite/core/mir/ssa_graph.h" + +namespace paddle { +namespace lite { +namespace mir { + +class FuseBase { + public: + using key2nodes_t = std::map; + + virtual ~FuseBase() = default; + + void operator()(SSAGraph* graph) { + BuildPattern(); + PerformPatternMatcher(graph); + + for (const auto& matched : key2nodes_) { + InsertNewNode(graph, matched); + } + + DeleteInterNodes(graph); + } + + // Build a PMPattern using PMNode. + virtual void BuildPattern() = 0; + + // Generate an operator desc with a matched subgraph. + virtual cpp::OpDesc GenOpDesc(const key2nodes_t& matched) = 0; + + PMNode* OpNode(const std::string& key, const std::string& op_type); + + PMNode* VarNode(const std::string& key); + + protected: + virtual void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) = 0; + + private: + void PerformPatternMatcher(SSAGraph* graph); + + // Delete nodes that are marked as Intermediate + void DeleteInterNodes(SSAGraph* graph); + + private: + PMNode* GetOrCreateNode(const std::string& key); + + protected: + PatternMatcher matcher_; + std::map nodes_; + std::vector key2nodes_; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc b/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..44f95dab754c70290470773f221255778280f0da --- /dev/null +++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc @@ -0,0 +1,154 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h" +#include +#include +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h" +#include "paddle/fluid/lite/core/program.h" + +namespace paddle { +namespace lite { +namespace mir { + +// An demo. +class FcFuser : public FuseBase { + public: + void BuildPattern() override { + // create nodes. + auto* x = VarNode("x"); + auto* W = VarNode("W"); + auto* b = VarNode("b"); + auto* mul = OpNode("mul", "mul"); + auto* mul_out = VarNode("mul_out"); + auto* add = OpNode("add", "elementwise_add"); + auto* Out = VarNode("Out"); + + // create topology. + // std::vector({W, x}) >> *mul >> *mul_out; + // std::vector({mul_out, b}) >> *add >> *Out; + *W >> *mul; + *x >> *mul >> *mul_out; + *b >> *add; + *mul_out >> *add >> *Out; + + // Some op specialities. + mul_out->AsIntermediate(); + mul->AsIntermediate(); + add->AsIntermediate(); + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + auto op_desc = GenOpDesc(matched); + auto fc_op = LiteOpRegistry::Global().Create("fc"); + auto mul = matched.at("mul")->stmt()->op; + auto* scope = mul->scope(); + auto& valid_places = mul->valid_places(); + fc_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places); + + IR_NODE_LINK_TO(matched.at("W"), new_op_node); + IR_NODE_LINK_TO(matched.at("x"), new_op_node); + IR_NODE_LINK_TO(matched.at("b"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("Out")); + } + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override { + cpp::OpDesc op_desc; + op_desc.SetType("fc"); + op_desc.SetInput("Input", {matched.at("x")->arg()->name}); + op_desc.SetInput("W", {matched.at("W")->arg()->name}); + op_desc.SetInput("Bias", {matched.at("b")->arg()->name}); + op_desc.SetOutput("Out", {matched.at("Out")->arg()->name}); + op_desc.SetAttr("in_num_col_dims", 1); + return op_desc; + } +}; + +std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, + const std::shared_ptr& scope, + const std::vector& valid_places) { + auto* main_block = program_desc->MutableBlock(0); + auto* mul = main_block->AppendOp(); + auto* add = main_block->AppendOp(); + main_block->Var("x"); + main_block->Var("b"); + main_block->Var("mul_out"); + main_block->Var("w"); + main_block->Var("out"); + main_block->Var("out1"); + + scope->Var("w")->GetMutable(); + scope->Var("b")->GetMutable(); + scope->Var("mul_out")->GetMutable(); + scope->Var("w")->GetMutable(); + scope->Var("out")->GetMutable(); + scope->Var("out1")->GetMutable(); + + mul->SetInput("X", {"x"}); + mul->SetInput("Y", {"w"}); + mul->SetOutput("Out", {"mul_out"}); + mul->SetType("mul"); + mul->SetAttr("x_num_col_dims", 1); + mul->SetAttr("y_num_col_dims", 1); + + add->SetInput("X", {"mul_out"}); + add->SetInput("Y", {"b"}); + add->SetOutput("Out", {"out"}); + add->SetType("elementwise_add"); + add->SetAttr("axis", 1); + + program_desc->Flush(); + + lite::Program program(*program_desc->Proto(), scope, valid_places); + auto graph = std::unique_ptr(new SSAGraph()); + graph->Build(program, valid_places); + + return graph; +} + +TEST(pattern_matcher2, graph_test) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + + ASSERT_EQ(graph->nodes().size(), + 8UL /*real nodes*/ + 2UL /*feed op + fetch op*/); + Visualize(graph.get()); +} + +TEST(pattern_matcher2, test) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + const int num_nodes = graph->nodes().size(); + FcFuser fuser; + fuser(graph.get()); + ASSERT_EQ(graph->nodes().size(), + num_nodes - 3UL /*nodes removed */ + 1UL /* fused fc node*/); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_LITE_OP(fc); +USE_LITE_OP(mul); +USE_LITE_OP(elementwise_add); diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_tester.cc b/paddle/fluid/lite/core/mir/pattern_matcher_test.cc similarity index 100% rename from paddle/fluid/lite/core/mir/pattern_matcher_tester.cc rename to paddle/fluid/lite/core/mir/pattern_matcher_test.cc diff --git a/paddle/fluid/lite/core/mir/ssa_graph.cc b/paddle/fluid/lite/core/mir/ssa_graph.cc index d3f33ab406159df68a6d99c45ed81becd03ec740..82507067c4726b271013cf4a69e95c5045b091a8 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.cc +++ b/paddle/fluid/lite/core/mir/ssa_graph.cc @@ -119,8 +119,7 @@ void SSAGraph::GraphCreateWeightVarNodes(const Program &program) { } Node *SSAGraph::GraphCreateInstructNode( - const Program &program, const std::shared_ptr &op, - const std::vector &valid_places) { + const std::shared_ptr &op, const std::vector &valid_places) { node_storage_.emplace_back(); // TODO(Superjomn) remove one valid_places here. op->SetValidPlaces(valid_places); @@ -141,7 +140,7 @@ void SSAGraph::Build(const Program &program, CHECK(CheckNodesRoleSet()); for (auto &op : program.ops()) { - auto *op_node = GraphCreateInstructNode(program, op, valid_places); + auto *op_node = GraphCreateInstructNode(op, valid_places); for (const std::string &name : op->op_info()->input_names()) { auto *arg = Argument(name); CHECK(arg->IsRoleSet()); @@ -162,6 +161,13 @@ void SSAGraph::Build(const Program &program, CheckValid(); } +void SSAGraph::RemoveNode(const mir::Node *node) { + auto pos = std::find_if(node_storage_.begin(), node_storage_.end(), + [&node](mir::Node &n) { return &n == node; }); + CHECK(pos != node_storage_.end()); + node_storage_.erase(pos); +} + mir::Node *SSAGraph::Argument(const std::string &name) { auto it = arguments_.find(name); CHECK(it != arguments_.end()) << "no argument called " << name; diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h index 7a80f28b6401e773cec805e4d4544c5f45364739..5cad1478c225a6551fcd653ca4e79b58360e3724 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.h +++ b/paddle/fluid/lite/core/mir/ssa_graph.h @@ -38,6 +38,7 @@ class SSAGraph : GraphBase { // @param program: the op program // @param valid_places: the valid places user set for the system. void Build(const Program &program, const std::vector &valid_places); + void RemoveNode(const mir::Node *node); mir::Node *Argument(const std::string &name); @@ -63,12 +64,12 @@ class SSAGraph : GraphBase { CHECK(CheckLinksRoleSet()); } + Node *GraphCreateInstructNode(const std::shared_ptr &op, + const std::vector &valid_places); + private: void GraphCreateTmpVarNodes(const Program &program); void GraphCreateWeightVarNodes(const Program &program); - Node *GraphCreateInstructNode(const Program &program, - const std::shared_ptr &op, - const std::vector &valid_places); // Check the bidirectional connection. bool CheckBidirectionalConnection(); diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc index bc30a00a497c307d6a327298b4664fb3d5bcb568..484d22abf52dda9832b524146114e2b2e093bb99 100644 --- a/paddle/fluid/lite/core/op_lite.cc +++ b/paddle/fluid/lite/core/op_lite.cc @@ -61,7 +61,7 @@ std::vector> OpLite::CreateKernels( targets.insert(place.target); } - CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_; + // CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_; VLOG(2) << "op " << op_type_ << " get " << kernels.size() << " kernels"; return kernels; }