提交 7dbdcddb 编写于 作者: Z Zhen Wang 提交者: GitHub

simplify fuse process. (#17954)

* simplify fuse process.

* update from upstream

* close the kernels not empty checker.
上级 fb110f72
...@@ -30,7 +30,7 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp ...@@ -30,7 +30,7 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp
cc_library(types_lite SRCS types.cc) cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite) cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite) cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite) cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)
add_subdirectory(mir) add_subdirectory(mir)
......
...@@ -50,5 +50,9 @@ cc_test(test_variable_place_infrence_pass SRCS variable_place_inference_pass_tes ...@@ -50,5 +50,9 @@ cc_test(test_variable_place_infrence_pass SRCS variable_place_inference_pass_tes
${test_variable_place_infrence_pass_DEPS}) ${test_variable_place_infrence_pass_DEPS})
cc_library(pattern_matcher_lite SRCS pattern_matcher.cc DEPS mir_node mir_ssa_graph op_lite) cc_library(pattern_matcher_lite SRCS pattern_matcher.cc DEPS mir_node mir_ssa_graph op_lite)
cc_test(test_pattern_matcher_lite SRCS pattern_matcher_tester.cc DEPS pattern_matcher_lite) cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern_matcher_lite)
cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite)
cc_test(test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS
pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite
mir_passes compatible_pb_lite program_lite ${ops_lite})
...@@ -16,10 +16,6 @@ ...@@ -16,10 +16,6 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace mir { namespace mir {} // namespace mir
PassManager::PassManager() {}
} // namespace mir
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -30,7 +30,7 @@ class PassManager { ...@@ -30,7 +30,7 @@ class PassManager {
return x; return x;
} }
PassManager(); PassManager() {}
void Run(const std::unique_ptr<SSAGraph>& graph) { void Run(const std::unique_ptr<SSAGraph>& graph) {
for (auto& pass : passes_) { for (auto& pass : passes_) {
......
...@@ -27,6 +27,30 @@ namespace mir { ...@@ -27,6 +27,30 @@ namespace mir {
size_t PMPattern::id_ = 0UL; size_t PMPattern::id_ = 0UL;
PMNode &PMNode::operator>>(PMNode &right) {
pattern_->AddEdge(this, &right);
// automatically add out op link relation.
if (right.IsOp()) {
CHECK(!right.op_type_.empty());
this->assert_is_op_input(right.op_type_);
}
return right;
}
PMNode &PMNode::operator>>(std::vector<PMNode *> &nodes) {
for (auto *node : nodes) {
*this >> *node;
}
return *this;
}
void operator>>(std::vector<PMNode *> &others, PMNode &me) {
for (auto *o : others) {
*o >> me;
}
}
PMNode *PMPattern::NewNode(const std::string &name) { PMNode *PMPattern::NewNode(const std::string &name) {
if (!name.empty()) { if (!name.empty()) {
CHECK_EQ(node_map_.count(name), 0UL) CHECK_EQ(node_map_.count(name), 0UL)
...@@ -122,9 +146,7 @@ void PatternMatcher::ValidateByNodeRole( ...@@ -122,9 +146,7 @@ void PatternMatcher::ValidateByNodeRole(
// Collect the inlinks and outlinks. // Collect the inlinks and outlinks.
std::unordered_set<Node *> ios; std::unordered_set<Node *> ios;
for (auto &item : subgraph) { for (auto &item : subgraph) {
if (!item.first->IsIntermediate()) { ios.insert(item.second);
ios.insert(item.second);
}
} }
for (auto &item : subgraph) { for (auto &item : subgraph) {
if (item.first->IsIntermediate()) { if (item.first->IsIntermediate()) {
...@@ -400,6 +422,30 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) { ...@@ -400,6 +422,30 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
return this; return this;
} }
void GraphSafeRemoveNodes(SSAGraph *graph,
const std::unordered_set<const Node *> &nodes) {
for (auto *node : nodes) {
graph->RemoveNode(node);
}
for (auto &node : graph->mutable_nodes()) {
for (auto it = node.inlinks.begin(); it != node.inlinks.end();) {
if (nodes.count(*it)) {
it = node.inlinks.erase(it);
} else {
it++;
}
}
for (auto it = node.outlinks.begin(); it != node.outlinks.end();) {
if (nodes.count(*it)) {
it = node.outlinks.erase(it);
} else {
it++;
}
}
}
}
} // namespace mir } // namespace mir
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -58,6 +58,15 @@ struct PMNode { ...@@ -58,6 +58,15 @@ struct PMNode {
PMNode& LinksTo(const std::vector<PMNode*>& others); PMNode& LinksTo(const std::vector<PMNode*>& others);
PMNode& LinksFrom(const std::vector<PMNode*>& others); PMNode& LinksFrom(const std::vector<PMNode*>& others);
// Link this to another node.
PMNode& operator>>(PMNode& right);
// Link many nodes to this node.
friend void operator>>(std::vector<PMNode*>& others, PMNode& me);
// Link this to many other nodes.
PMNode& operator>>(std::vector<PMNode*>& nodes);
bool Tell(const Node* node) const { bool Tell(const Node* node) const {
if (teller_) return teller_(node); if (teller_) return teller_(node);
...@@ -92,6 +101,20 @@ struct PMNode { ...@@ -92,6 +101,20 @@ struct PMNode {
return this; return this;
} }
PMNode* AsVar() {
type_ = Type::kVar;
assert_is_var();
return this;
}
PMNode* AsOp(const std::string& op_type) {
type_ = Type::kOp;
assert_is_op(op_type);
return this;
}
void set_op_type(const std::string& op_type) { op_type_ = op_type; }
bool IsIntermediate() const { return role_ == Role::kIntermediate; } bool IsIntermediate() const { return role_ == Role::kIntermediate; }
bool IsInput() const { return role_ == Role::kInput; } bool IsInput() const { return role_ == Role::kInput; }
bool IsOutput() const { return role_ == Role::kOutput; } bool IsOutput() const { return role_ == Role::kOutput; }
...@@ -141,6 +164,7 @@ struct PMNode { ...@@ -141,6 +164,7 @@ struct PMNode {
std::vector<teller_t> asserts_; std::vector<teller_t> asserts_;
PMPattern* pattern_; PMPattern* pattern_;
std::string name_; std::string name_;
std::string op_type_;
Type type_; Type type_;
Role role_{Role::kUnknown}; Role role_{Role::kUnknown};
}; };
...@@ -273,6 +297,10 @@ class PatternMatcher { ...@@ -273,6 +297,10 @@ class PatternMatcher {
std::unordered_map<const PMNode*, std::unordered_set<Node*>> pmnodes2nodes_; std::unordered_map<const PMNode*, std::unordered_set<Node*>> pmnodes2nodes_;
}; };
// Graph safely remove some nodes, will automatically clean up the edges.
void GraphSafeRemoveNodes(SSAGraph* graph,
const std::unordered_set<const Node*>& nodes);
// Some pre-defined patterns those can be reused in multiple passes. // Some pre-defined patterns those can be reused in multiple passes.
// The related Fluid Layer or Op should be one pattern here for better re-usage // The related Fluid Layer or Op should be one pattern here for better re-usage
// across different fusion. // across different fusion.
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
#include <glog/logging.h>
namespace paddle {
namespace lite {
namespace mir {
void FuseBase::PerformPatternMatcher(SSAGraph *graph) {
LOG(INFO) << "\n" << matcher_.pattern().DotString();
// Get subgraphs and record the mir::Node pointers for each PMNode.
auto handler = [&](const PatternMatcher::subgraph_t &subgraph, SSAGraph *g) {
// get all the reigistered nodes.
key2nodes_.emplace_back();
for (auto &item : nodes_) {
key2nodes_.back()[item.first] = subgraph.at(item.second);
}
};
matcher_(graph, handler);
}
void FuseBase::DeleteInterNodes(SSAGraph *graph) {
std::set<std::string> keys;
for (auto &node : nodes_) {
if (node.second->IsIntermediate()) {
keys.insert(node.first);
}
}
LOG(INFO) << "keys.size " << keys.size();
std::unordered_set<const Node *> nodes2rm;
for (auto &matched : key2nodes_) {
LOG(INFO) << "get matched " << matched.size();
for (const auto &key : keys) {
nodes2rm.insert(matched.at(key));
}
}
LOG(INFO) << "clean nodes " << nodes2rm.size();
GraphSafeRemoveNodes(graph, nodes2rm);
}
PMNode *FuseBase::GetOrCreateNode(const std::string &key) {
auto it = nodes_.find(key);
if (it != nodes_.end()) {
return it->second;
}
nodes_.emplace(key,
matcher_.mutable_pattern()->NewNode(patterns::UniqueKey(key)));
it = nodes_.find(key);
return it->second;
}
PMNode *FuseBase::OpNode(const std::string &key, const std::string &op_type) {
GetOrCreateNode(key)->set_op_type(op_type);
GetOrCreateNode(key)->AsOp(op_type);
return GetOrCreateNode(key);
}
PMNode *FuseBase::VarNode(const std::string &key) {
GetOrCreateNode(key)->AsVar();
return GetOrCreateNode(key);
}
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/mir/pattern_matcher.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace paddle {
namespace lite {
namespace mir {
class FuseBase {
public:
using key2nodes_t = std::map<std::string, Node*>;
virtual ~FuseBase() = default;
void operator()(SSAGraph* graph) {
BuildPattern();
PerformPatternMatcher(graph);
for (const auto& matched : key2nodes_) {
InsertNewNode(graph, matched);
}
DeleteInterNodes(graph);
}
// Build a PMPattern using PMNode.
virtual void BuildPattern() = 0;
// Generate an operator desc with a matched subgraph.
virtual cpp::OpDesc GenOpDesc(const key2nodes_t& matched) = 0;
PMNode* OpNode(const std::string& key, const std::string& op_type);
PMNode* VarNode(const std::string& key);
protected:
virtual void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) = 0;
private:
void PerformPatternMatcher(SSAGraph* graph);
// Delete nodes that are marked as Intermediate
void DeleteInterNodes(SSAGraph* graph);
private:
PMNode* GetOrCreateNode(const std::string& key);
protected:
PatternMatcher matcher_;
std::map<std::string, PMNode*> nodes_;
std::vector<key2nodes_t> key2nodes_;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
#include <gtest/gtest.h>
#include <memory>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/program.h"
namespace paddle {
namespace lite {
namespace mir {
// An demo.
class FcFuser : public FuseBase {
public:
void BuildPattern() override {
// create nodes.
auto* x = VarNode("x");
auto* W = VarNode("W");
auto* b = VarNode("b");
auto* mul = OpNode("mul", "mul");
auto* mul_out = VarNode("mul_out");
auto* add = OpNode("add", "elementwise_add");
auto* Out = VarNode("Out");
// create topology.
// std::vector<PMNode*>({W, x}) >> *mul >> *mul_out;
// std::vector<PMNode*>({mul_out, b}) >> *add >> *Out;
*W >> *mul;
*x >> *mul >> *mul_out;
*b >> *add;
*mul_out >> *add >> *Out;
// Some op specialities.
mul_out->AsIntermediate();
mul->AsIntermediate();
add->AsIntermediate();
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
auto op_desc = GenOpDesc(matched);
auto fc_op = LiteOpRegistry::Global().Create("fc");
auto mul = matched.at("mul")->stmt()->op;
auto* scope = mul->scope();
auto& valid_places = mul->valid_places();
fc_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places);
IR_NODE_LINK_TO(matched.at("W"), new_op_node);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(matched.at("b"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
}
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("fc");
op_desc.SetInput("Input", {matched.at("x")->arg()->name});
op_desc.SetInput("W", {matched.at("W")->arg()->name});
op_desc.SetInput("Bias", {matched.at("b")->arg()->name});
op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
op_desc.SetAttr("in_num_col_dims", 1);
return op_desc;
}
};
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
auto* main_block = program_desc->MutableBlock(0);
auto* mul = main_block->AppendOp();
auto* add = main_block->AppendOp();
main_block->Var("x");
main_block->Var("b");
main_block->Var("mul_out");
main_block->Var("w");
main_block->Var("out");
main_block->Var("out1");
scope->Var("w")->GetMutable<lite::Tensor>();
scope->Var("b")->GetMutable<lite::Tensor>();
scope->Var("mul_out")->GetMutable<lite::Tensor>();
scope->Var("w")->GetMutable<lite::Tensor>();
scope->Var("out")->GetMutable<lite::Tensor>();
scope->Var("out1")->GetMutable<lite::Tensor>();
mul->SetInput("X", {"x"});
mul->SetInput("Y", {"w"});
mul->SetOutput("Out", {"mul_out"});
mul->SetType("mul");
mul->SetAttr("x_num_col_dims", 1);
mul->SetAttr("y_num_col_dims", 1);
add->SetInput("X", {"mul_out"});
add->SetInput("Y", {"b"});
add->SetOutput("Out", {"out"});
add->SetType("elementwise_add");
add->SetAttr("axis", 1);
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
return graph;
}
TEST(pattern_matcher2, graph_test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
ASSERT_EQ(graph->nodes().size(),
8UL /*real nodes*/ + 2UL /*feed op + fetch op*/);
Visualize(graph.get());
}
TEST(pattern_matcher2, test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
const int num_nodes = graph->nodes().size();
FcFuser fuser;
fuser(graph.get());
ASSERT_EQ(graph->nodes().size(),
num_nodes - 3UL /*nodes removed */ + 1UL /* fused fc node*/);
}
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(fc);
USE_LITE_OP(mul);
USE_LITE_OP(elementwise_add);
...@@ -119,8 +119,7 @@ void SSAGraph::GraphCreateWeightVarNodes(const Program &program) { ...@@ -119,8 +119,7 @@ void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
} }
Node *SSAGraph::GraphCreateInstructNode( Node *SSAGraph::GraphCreateInstructNode(
const Program &program, const std::shared_ptr<OpLite> &op, const std::shared_ptr<OpLite> &op, const std::vector<Place> &valid_places) {
const std::vector<Place> &valid_places) {
node_storage_.emplace_back(); node_storage_.emplace_back();
// TODO(Superjomn) remove one valid_places here. // TODO(Superjomn) remove one valid_places here.
op->SetValidPlaces(valid_places); op->SetValidPlaces(valid_places);
...@@ -141,7 +140,7 @@ void SSAGraph::Build(const Program &program, ...@@ -141,7 +140,7 @@ void SSAGraph::Build(const Program &program,
CHECK(CheckNodesRoleSet()); CHECK(CheckNodesRoleSet());
for (auto &op : program.ops()) { for (auto &op : program.ops()) {
auto *op_node = GraphCreateInstructNode(program, op, valid_places); auto *op_node = GraphCreateInstructNode(op, valid_places);
for (const std::string &name : op->op_info()->input_names()) { for (const std::string &name : op->op_info()->input_names()) {
auto *arg = Argument(name); auto *arg = Argument(name);
CHECK(arg->IsRoleSet()); CHECK(arg->IsRoleSet());
...@@ -162,6 +161,13 @@ void SSAGraph::Build(const Program &program, ...@@ -162,6 +161,13 @@ void SSAGraph::Build(const Program &program,
CheckValid(); CheckValid();
} }
void SSAGraph::RemoveNode(const mir::Node *node) {
auto pos = std::find_if(node_storage_.begin(), node_storage_.end(),
[&node](mir::Node &n) { return &n == node; });
CHECK(pos != node_storage_.end());
node_storage_.erase(pos);
}
mir::Node *SSAGraph::Argument(const std::string &name) { mir::Node *SSAGraph::Argument(const std::string &name) {
auto it = arguments_.find(name); auto it = arguments_.find(name);
CHECK(it != arguments_.end()) << "no argument called " << name; CHECK(it != arguments_.end()) << "no argument called " << name;
......
...@@ -38,6 +38,7 @@ class SSAGraph : GraphBase { ...@@ -38,6 +38,7 @@ class SSAGraph : GraphBase {
// @param program: the op program // @param program: the op program
// @param valid_places: the valid places user set for the system. // @param valid_places: the valid places user set for the system.
void Build(const Program &program, const std::vector<Place> &valid_places); void Build(const Program &program, const std::vector<Place> &valid_places);
void RemoveNode(const mir::Node *node);
mir::Node *Argument(const std::string &name); mir::Node *Argument(const std::string &name);
...@@ -63,12 +64,12 @@ class SSAGraph : GraphBase { ...@@ -63,12 +64,12 @@ class SSAGraph : GraphBase {
CHECK(CheckLinksRoleSet()); CHECK(CheckLinksRoleSet());
} }
Node *GraphCreateInstructNode(const std::shared_ptr<OpLite> &op,
const std::vector<Place> &valid_places);
private: private:
void GraphCreateTmpVarNodes(const Program &program); void GraphCreateTmpVarNodes(const Program &program);
void GraphCreateWeightVarNodes(const Program &program); void GraphCreateWeightVarNodes(const Program &program);
Node *GraphCreateInstructNode(const Program &program,
const std::shared_ptr<OpLite> &op,
const std::vector<Place> &valid_places);
// Check the bidirectional connection. // Check the bidirectional connection.
bool CheckBidirectionalConnection(); bool CheckBidirectionalConnection();
......
...@@ -61,7 +61,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels( ...@@ -61,7 +61,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
targets.insert(place.target); targets.insert(place.target);
} }
CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_; // CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;
VLOG(2) << "op " << op_type_ << " get " << kernels.size() << " kernels"; VLOG(2) << "op " << op_type_ << " get " << kernels.size() << " kernels";
return kernels; return kernels;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册