未验证 提交 f6db9806 编写于 作者: W wuhuanzhou 提交者: GitHub

GeneratePass for Python Pass (#35708)

#### 背景

#35602 提供Python侧开发子图替换类Pass的方式:

- 利用Paddle Python API或者辅助类型定义子图program用来匹配/替换图;
- Python侧注册Pass时,将注册函数最终转换为protobuf定义的PassDesc数据形式,供C++侧进行解析完成Pass实例注册。

本PR即为根据PassDesc规则描述解析生成Pass实例。

#### 方案设计

##### Pass规则验证

在以往的Pass开发中,会存在随着算子迭代引发的匹配失效或者错误匹配的问题,该问题可以通过扫描算子支持的参数设置及参数类型等来判断是否应该使用该Pass或者给出提示需要修改Pass代码。

当前Pass开发中提供了算子兼容性OpCompatSensiblePass用于解决上述问题。但同时还存在不足:由于以往Pass开发在运行时才能获取到pattern信息,所以需要在执行Pass时才可以判断。

使用PassDesc表示的Pass可以在执行Pass前验证上述问题,这个过程在VerifyDesc中完成。

##### 根据匹配子图构造pattern

GeneratePass对于图匹配和替换使用GraphPatternDecetor完成,构造匹配pattern实际上就是将对应对象成员PDPattern中添加PDNode和边关系。该过程在函数`InitGeneratePattern`中完成,该函数没有作为GeneratePass的成员方法,主要出于后续可能开发新的Decetor考虑,GeneratePass与Decetor的操作是没有关联的。

初始化pattern主要通过遍历匹配子图program的全部算子实现:

1. 添加当前算子对应PDNode及限制条件(算子类型、属性限制等);
2. 遍历当前算子对应输入并从pattern中尝试获取PDNode:
   - 在pattern中获取到PDNode且为输出节点:表示属于匹配子图的中间节点,将该PDNode设置为中间节点;
   - 在pattern中没有获取到PDNode:添加该输入PDNode并设置作为输入节点;
   - 设置输入到算子的边关系;
3. 遍历当前算子对应输出:
   - 在pattern中获取到PDNode且为输入节点:表示属于匹配子图的中间节点,将该PDNode设置为中间节点;
   - 在pattern中没有获取到PDNode:添加该输入PDNode并设置作为输出节点;
   - 设置算子到输出的边关系;

##### 根据替换子图操作graph

替换子图操作的过程在`GetGenerateRewrite`函数中完成,与`InitGeneratePattern`类似没有作为GeneratePass的成员方法。

生成替换子图操作过程如下:

1. 判断冗余替换子图;
2. 遍历替换子图program的全部算子添加替换子图Node:
   1. 添加当前算子的Node及属性设置;
   2. 遍历当前算子对应输入,添加中间variable节点;
   3. 遍历当前算子对应输出,添加中间variable节点;
   4. 添加输入/输出节点与算子节点的边关系;
3. 删除匹配图中属于中间节点的Node;

##### 优化子图验证

对于替换子图或者替换后的计算图是否可以正确运行等,可以在执行Pass时验证,从而防止在后续执行计算图时出现异常。

当前Pass执行直接修改计算图,验证失败时无法很好的完成还原操作,目前子图验证暂时默认成功,留到后续改进。
上级 e64fed86
......@@ -28,6 +28,7 @@ add_subdirectory(io)
add_subdirectory(new_executor)
#ddim lib
proto_library(framework_proto SRCS framework.proto)
proto_library(pass_desc_proto SRCS pass_desc.proto DEPS framework_proto)
proto_library(op_def_proto SRCS op_def.proto DEPS framework_proto)
cc_library(op_def_api SRCS op_def_api.cc DEPS op_def_proto boost)
......
......@@ -95,6 +95,8 @@ pass_library(multihead_matmul_fuse_pass inference)
pass_library(adaptive_pool2d_convert_global_pass inference)
pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(layer_norm_fuse_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)
if(WITH_GPU OR WITH_ROCM)
pass_library(cudnn_placement_pass base DEPS placement_pass_base)
pass_library(embedding_eltwise_layernorm_fuse_pass inference)
......@@ -156,6 +158,7 @@ cc_test(test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_
cc_test(test_adaptive_pool2d_convert_global_pass SRCS adaptive_pool2d_convert_global_pass_tester.cc DEPS adaptive_pool2d_convert_global_pass)
cc_test(test_unsqueeze2_eltwise_fuse_pass SRCS unsqueeze2_eltwise_fuse_pass_tester.cc DEPS unsqueeze2_eltwise_fuse_pass)
cc_test(test_layer_norm_fuse_pass_cc SRCS layer_norm_fuse_pass_tester.cc DEPS layer_norm_fuse_pass pass_test_util naive_executor)
cc_test(test_generate_pass_cc SRCS generate_pass_tester.cc DEPS generate_pass pass_desc_proto)
if(WITH_GPU OR WITH_ROCM)
cc_test(test_embedding_eltwise_layernorm_fuse_pass SRCS embedding_eltwise_layernorm_fuse_pass_tester.cc DEPS embedding_eltwise_layernorm_fuse_pass)
cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/generate_pass.h"
namespace paddle {
namespace framework {
namespace ir {
void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
const proto::BlockDesc& block = pass_desc.pattern().blocks(0);
// Traverse all operators to create subgraph.
for (int index = 0; index < block.ops_size(); ++index) {
const proto::OpDesc& op = block.ops(index);
// Create a PDNode for current operator. Use the index as name to avoid
// multiple operators with same type. Get a PDNode from pattern subgraph
// through index in rewrite phase.
PDNode* op_pdnode =
pattern->NewNode(std::to_string(index))->assert_is_op(op.type());
// Create PDNodes for inputs of current operator.
for (const proto::OpDesc::Var& var : op.inputs()) {
for (const std::string& argument : var.arguments()) {
// The input may be the output of other operator.
PDNode* var_pdnode = pattern->RetrieveNode(argument);
if (nullptr == var_pdnode) {
var_pdnode = pattern->NewNode(argument)->AsInput();
} else if (var_pdnode->IsOutput()) {
var_pdnode->AsIntermediate();
}
var_pdnode->assert_is_op_input(op.type());
pattern->AddEdge(var_pdnode, op_pdnode);
}
}
// Create PDNodes for outputs of current operator.
for (const proto::OpDesc::Var& var : op.outputs()) {
for (const std::string& argument : var.arguments()) {
// The output may be the input of other operator.
PDNode* var_pdnode = pattern->RetrieveNode(argument);
if (nullptr == var_pdnode) {
var_pdnode = pattern->NewNode(argument)->AsOutput();
} else if (var_pdnode->IsInput()) {
var_pdnode->AsIntermediate();
}
var_pdnode->assert_is_op_output(op.type());
pattern->AddEdge(op_pdnode, var_pdnode);
}
}
// Set attribute condition for current operator.
for (const proto::OpDesc::Attr& attr : op.attrs()) {
op_pdnode->assert_more([&](Node* x) {
if (x && x->IsOp()) {
OpDesc* op_desc = x->Op();
if (op_desc->HasAttr(attr.name())) {
return GetAttrValue(attr) == op_desc->GetAttr(attr.name());
}
return false;
}
return false;
});
}
}
}
GraphPatternDetector::handle_t GetGenerateRewrite(
const PDPattern& pattern, const proto::PassDesc& pass_desc) {
GraphPatternDetector::handle_t handler = [&](
const GraphPatternDetector::subgraph_t subgraph, Graph* graph) {
// There are some duplicate patterns.
for (auto iter : subgraph) {
if (nullptr == graph->RetrieveNode(iter.second->id())) {
VLOG(3) << "Node [" << iter.second->Name()
<< "] of subgraph has been removed. So skip this optimize.";
return;
}
}
const proto::BlockDesc& block = pass_desc.replace().blocks(0);
// `var_node_maps` record the mapping of variable to the pattern subgraph.
std::map<std::string, Node*> var_node_maps;
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) {
Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var()));
var_node_maps.insert({var_map.replace_var(), node});
}
// Traverse all operators to create subgraph.
for (const proto::OpDesc& op : block.ops()) {
OpDesc op_desc;
std::vector<Node *> in_nodes, out_nodes;
op_desc.SetType(op.type());
// Create Nodes for inputs of current operator.
for (const proto::OpDesc::Var& var : op.inputs()) {
std::vector<std::string> arguments;
for (const std::string& argument : var.arguments()) {
// The input may be mapped on the operator of pattern subgraph.
Node* node = nullptr;
auto iter = var_node_maps.find(argument);
if (var_node_maps.end() == iter) {
VarDesc var_desc(patterns::UniqueKey(argument));
node = graph->CreateVarNode(&var_desc);
var_node_maps.insert({argument, node});
} else {
node = iter->second;
}
in_nodes.push_back(node);
arguments.push_back(node->Name());
}
op_desc.SetInput(var.parameter(), arguments);
}
// Create Nodes for outputs of current operator.
for (const proto::OpDesc::Var& var : op.outputs()) {
std::vector<std::string> arguments;
for (const std::string& argument : var.arguments()) {
// The output may be mapped on the operator of pattern subgraph.
Node* node = nullptr;
auto iter = var_node_maps.find(argument);
if (var_node_maps.end() == iter) {
VarDesc var_desc(patterns::UniqueKey(argument));
node = graph->CreateVarNode(&var_desc);
var_node_maps.insert({argument, node});
} else {
node = iter->second;
}
out_nodes.push_back(node);
arguments.push_back(node->Name());
}
op_desc.SetOutput(var.parameter(), arguments);
}
// Set attribute for current operator.
for (const proto::OpDesc::Attr& attr : op.attrs()) {
op_desc.SetAttr(attr.name(), GetAttrValue(attr));
}
// Create a Node for current operator.
Node* op_node = graph->CreateOpNode(&op_desc);
for (Node* node : in_nodes) {
IR_NODE_LINK_TO(node, op_node);
}
for (Node* node : out_nodes) {
IR_NODE_LINK_TO(op_node, node);
}
}
// Remove nodes that are intermediate.
std::unordered_set<const Node*> remove_nodes;
for (const std::unique_ptr<PDNode>& pdnode : pattern.nodes()) {
remove_nodes.emplace(subgraph.at(pdnode.get()));
}
for (auto iter : var_node_maps) {
remove_nodes.erase(iter.second);
}
GraphSafeRemoveNodes(graph, remove_nodes);
};
return handler;
}
GeneratePass::GeneratePass(const std::string& binary_str) {
multi_pass_desc_.ParseFromString(binary_str);
VerifyDesc();
}
GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc)
: multi_pass_desc_(multi_pass_desc) {
VerifyDesc();
}
void GeneratePass::ApplyImpl(Graph* graph) const {
for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) {
GraphPatternDetector detector;
InitGeneratePattern(pass_desc, detector.mutable_pattern());
detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc));
// The rewrited graph needs to be verified. Current Pass should be skipped
// if validation failed. Rewrite based on the original graph cannot
// implement rollback operation.
VerifyGraph(*graph);
}
}
void GeneratePass::VerifyDesc() const {
PADDLE_ENFORCE_NE(multi_pass_desc_.pass_descs_size(), 0,
platform::errors::InvalidArgument(
"Size of PassDesc should not be empty."));
for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) {
// Check inputs/outputs of subgraph should in `var_maps`.
std::set<std::string> pattern_var_sets, replace_var_sets;
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) {
pattern_var_sets.emplace(var_map.pattern_var());
replace_var_sets.emplace(var_map.replace_var());
}
auto check_vars = [=](std::set<std::string>* var_sets,
const proto::BlockDesc& block) {
for (const proto::OpDesc& op : block.ops()) {
for (const proto::OpDesc::Var& var : op.outputs()) {
for (const std::string& argument : var.arguments()) {
var_sets->emplace(argument);
}
}
}
for (const proto::OpDesc& op : block.ops()) {
for (const proto::OpDesc::Var& var : op.inputs()) {
for (const std::string& argument : var.arguments()) {
PADDLE_ENFORCE_NE(
var_sets->find(argument), var_sets->end(),
platform::errors::InvalidArgument(
"Subgraph of PassDesc has argument [%s] not in `var_maps`.",
argument));
}
}
}
};
check_vars(&pattern_var_sets, pass_desc.pattern().blocks(0));
check_vars(&replace_var_sets, pass_desc.replace().blocks(0));
}
}
bool GeneratePass::VerifyGraph(const Graph& graph) {
// Return true temporarily.
return true;
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/pass_desc.pb.h"
namespace paddle {
namespace framework {
namespace ir {
// Generate a substitute pass from protobuf.
class GeneratePass : public Pass {
public:
// from binary_str
explicit GeneratePass(const std::string& binary_str);
// from PassDesc/MultiPassDesc
explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc);
protected:
void ApplyImpl(Graph* graph) const override;
private:
GeneratePass() = delete;
DISABLE_COPY_AND_ASSIGN(GeneratePass);
// Verify desc
void VerifyDesc() const;
// Verify graph
static bool VerifyGraph(const Graph& graph);
proto::MultiPassDesc multi_pass_desc_;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/generate_pass.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
template <proto::MultiPassDesc (*Functor)(void)>
class CXXGeneratePass : public GeneratePass {
public:
CXXGeneratePass() : GeneratePass(Functor()) {}
};
#define REGISTER_GENERATE_PASS(pass_type, function) \
REGISTER_PASS(pass_type, ::paddle::framework::ir::CXXGeneratePass<&function>)
proto::MultiPassDesc generate_fc_fuse() {
proto::MultiPassDesc multi_pass_desc;
for (bool with_relu : {true, false}) {
proto::PassDesc* pass_desc = multi_pass_desc.add_pass_descs();
proto::BlockDesc* pattern = pass_desc->mutable_pattern()->add_blocks();
pattern->set_idx(0);
pattern->set_parent_idx(0);
proto::OpDesc* mul = pattern->add_ops();
mul->set_type("mul");
proto::OpDesc::Var* mul_x = mul->add_inputs();
mul_x->set_parameter("X");
mul_x->add_arguments()->assign("x");
proto::OpDesc::Var* mul_y = mul->add_inputs();
mul_y->set_parameter("Y");
mul_y->add_arguments()->assign("w");
proto::OpDesc::Var* mul_out = mul->add_outputs();
mul_out->set_parameter("Out");
mul_out->add_arguments()->assign("mul_out");
proto::OpDesc* ewadd = pattern->add_ops();
ewadd->set_type("elementwise_add");
proto::OpDesc::Var* ewadd_x = ewadd->add_inputs();
ewadd_x->set_parameter("X");
ewadd_x->add_arguments()->assign("mul_out");
proto::OpDesc::Var* ewadd_y = ewadd->add_inputs();
ewadd_y->set_parameter("Y");
ewadd_y->add_arguments()->assign("b");
proto::OpDesc::Var* ewadd_out = ewadd->add_outputs();
ewadd_out->set_parameter("Out");
ewadd_out->add_arguments()->assign("ewadd_out");
proto::OpDesc* relu = nullptr;
proto::BlockDesc* replace = pass_desc->mutable_replace()->add_blocks();
replace->set_idx(0);
replace->set_parent_idx(0);
proto::OpDesc* fc = replace->add_ops();
fc->set_type("fc");
proto::OpDesc::Var* fc_x = fc->add_inputs();
fc_x->set_parameter("Input");
fc_x->add_arguments()->assign("x");
proto::OpDesc::Var* fc_w = fc->add_inputs();
fc_w->set_parameter("W");
fc_w->add_arguments()->assign("w");
proto::OpDesc::Var* fc_b = fc->add_inputs();
fc_b->set_parameter("Bias");
fc_b->add_arguments()->assign("b");
proto::OpDesc::Var* fc_out = fc->add_outputs();
fc_out->set_parameter("Out");
fc_out->add_arguments()->assign("fc_out");
for (const char* var : {"x", "w", "b", "fc_out"}) {
proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps();
var_map->set_pattern_var(var);
var_map->set_replace_var(var);
}
proto::PassDesc::AttrMap* attr_map = pass_desc->add_attr_maps();
attr_map->set_pattern_op_idx(0);
attr_map->set_pattern_name("x_num_col_dims");
attr_map->set_replace_op_idx(0);
attr_map->set_replace_name("in_num_col_dims");
if (with_relu) {
relu = pattern->add_ops();
relu->set_type("relu");
proto::OpDesc::Var* relu_x = relu->add_inputs();
relu_x->set_parameter("X");
relu_x->add_arguments()->assign("ewadd_out");
proto::OpDesc::Var* relu_out = relu->add_outputs();
relu_out->set_parameter("Out");
relu_out->add_arguments()->assign("relu_out");
pass_desc->mutable_var_maps(3)->set_pattern_var("relu_out");
proto::OpDesc::Attr* attr = fc->add_attrs();
attr->set_name("activation_type");
attr->set_type(proto::AttrType::STRING);
attr->set_s("relu");
} else {
pass_desc->mutable_var_maps(3)->set_pattern_var("ewadd_out");
}
}
return multi_pass_desc;
}
proto::MultiPassDesc generate_multi_add_to_addn() {
proto::MultiPassDesc multi_pass_desc;
proto::PassDesc* pass_desc = multi_pass_desc.add_pass_descs();
proto::BlockDesc* pattern = pass_desc->mutable_pattern()->add_blocks();
proto::OpDesc* ewadd_0 = pattern->add_ops();
ewadd_0->set_type("elementwise_add");
proto::OpDesc::Var* ewadd_0_x = ewadd_0->add_inputs();
ewadd_0_x->set_parameter("X");
ewadd_0_x->add_arguments()->assign("a");
proto::OpDesc::Var* ewadd_0_y = ewadd_0->add_inputs();
ewadd_0_y->set_parameter("Y");
ewadd_0_y->add_arguments()->assign("b");
proto::OpDesc::Var* ewadd_0_out = ewadd_0->add_outputs();
ewadd_0_out->set_parameter("Out");
ewadd_0_out->add_arguments()->assign("ewadd_out_0");
proto::OpDesc* ewadd_1 = pattern->add_ops();
ewadd_1->set_type("elementwise_add");
proto::OpDesc::Var* ewadd_1_x = ewadd_1->add_inputs();
ewadd_1_x->set_parameter("X");
ewadd_1_x->add_arguments()->assign("ewadd_out_0");
proto::OpDesc::Var* ewadd_1_y = ewadd_1->add_inputs();
ewadd_1_y->set_parameter("Y");
ewadd_1_y->add_arguments()->assign("c");
proto::OpDesc::Var* ewadd_1_out = ewadd_1->add_outputs();
ewadd_1_out->set_parameter("Out");
ewadd_1_out->add_arguments()->assign("ewadd_out_1");
proto::BlockDesc* replace = pass_desc->mutable_replace()->add_blocks();
proto::OpDesc* addn = replace->add_ops();
addn->set_type("add_n");
proto::OpDesc::Var* addn_x = addn->add_inputs();
addn_x->set_parameter("X");
addn_x->add_arguments()->assign("a");
addn_x->add_arguments()->assign("b");
addn_x->add_arguments()->assign("c");
proto::OpDesc::Var* addn_out = addn->add_outputs();
addn_out->set_parameter("Out");
addn_out->add_arguments()->assign("addn_out");
for (const char* var : {"a", "b", "c", "ewadd_out_1"}) {
proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps();
var_map->set_pattern_var(var);
var_map->set_replace_var(var);
}
pass_desc->mutable_var_maps(3)->set_replace_var("addn_out");
return multi_pass_desc;
}
proto::MultiPassDesc generate_combine_matmul() {
proto::MultiPassDesc multi_pass_desc;
proto::PassDesc* pass_desc = multi_pass_desc.add_pass_descs();
proto::BlockDesc* pattern = pass_desc->mutable_pattern()->add_blocks();
proto::OpDesc* matmul_0 = pattern->add_ops();
matmul_0->set_type("matmul");
proto::OpDesc::Var* matmul_0_x = matmul_0->add_inputs();
matmul_0_x->set_parameter("X");
matmul_0_x->add_arguments()->assign("a");
proto::OpDesc::Var* matmul_0_y = matmul_0->add_inputs();
matmul_0_y->set_parameter("Y");
matmul_0_y->add_arguments()->assign("b");
proto::OpDesc::Var* matmul_0_out = matmul_0->add_outputs();
matmul_0_out->set_parameter("Out");
matmul_0_out->add_arguments()->assign("matmul_out_0");
proto::OpDesc* matmul_1 = pattern->add_ops();
matmul_1->set_type("matmul");
proto::OpDesc::Var* matmul_1_x = matmul_1->add_inputs();
matmul_1_x->set_parameter("X");
matmul_1_x->add_arguments()->assign("a");
proto::OpDesc::Var* matmul_1_y = matmul_1->add_inputs();
matmul_1_y->set_parameter("Y");
matmul_1_y->add_arguments()->assign("c");
proto::OpDesc::Var* matmul_1_out = matmul_1->add_outputs();
matmul_1_out->set_parameter("Out");
matmul_1_out->add_arguments()->assign("matmul_out_1");
proto::BlockDesc* replace = pass_desc->mutable_replace()->add_blocks();
proto::OpDesc* concat = replace->add_ops();
concat->set_type("concat");
proto::OpDesc::Var* concat_x = concat->add_inputs();
concat_x->set_parameter("X");
concat_x->add_arguments()->assign("b");
concat_x->add_arguments()->assign("c");
proto::OpDesc::Var* concat_out = concat->add_outputs();
concat_out->set_parameter("Out");
concat_out->add_arguments()->assign("concat_out");
proto::OpDesc* matmul = replace->add_ops();
matmul->set_type("matmul");
proto::OpDesc::Var* matmul_x = matmul->add_inputs();
matmul_x->set_parameter("X");
matmul_x->add_arguments()->assign("a");
proto::OpDesc::Var* matmul_y = matmul->add_inputs();
matmul_y->set_parameter("Y");
matmul_y->add_arguments()->assign("concat_out");
proto::OpDesc::Var* matmul_out = matmul->add_outputs();
matmul_out->set_parameter("Out");
matmul_out->add_arguments()->assign("matmul_out");
proto::OpDesc* slice_0 = replace->add_ops();
slice_0->set_type("slice");
proto::OpDesc::Var* slice_0_x = slice_0->add_inputs();
slice_0_x->set_parameter("X");
slice_0_x->add_arguments()->assign("matmul_out");
proto::OpDesc::Var* slice_0_out = slice_0->add_outputs();
slice_0_out->set_parameter("Out");
slice_0_out->add_arguments()->assign("slice_out_0");
proto::OpDesc* slice_1 = replace->add_ops();
slice_1->set_type("slice");
proto::OpDesc::Var* slice_1_x = slice_1->add_inputs();
slice_1_x->set_parameter("X");
slice_1_x->add_arguments()->assign("matmul_out");
proto::OpDesc::Var* slice_1_out = slice_1->add_outputs();
slice_1_out->set_parameter("Out");
slice_1_out->add_arguments()->assign("slice_out_1");
for (const char* var : {"a", "b", "c", "matmul_out_0", "matmul_out_1"}) {
proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps();
var_map->set_pattern_var(var);
var_map->set_replace_var(var);
}
pass_desc->mutable_var_maps(3)->set_replace_var("slice_out_0");
pass_desc->mutable_var_maps(4)->set_replace_var("slice_out_1");
return multi_pass_desc;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_GENERATE_PASS(generate_fc_fuse,
paddle::framework::ir::generate_fc_fuse);
REGISTER_GENERATE_PASS(generate_multi_add_to_addn,
paddle::framework::ir::generate_multi_add_to_addn);
REGISTER_GENERATE_PASS(generate_combine_matmul,
paddle::framework::ir::generate_combine_matmul);
namespace paddle {
namespace framework {
namespace ir {
TEST(GeneratePass, construct_with_string) {
std::string binary_str;
generate_fc_fuse().SerializeToString(&binary_str);
GeneratePass generate_pass(binary_str);
}
TEST(GeneratePass, generate_fc_fuse) {
// inputs operator output
// --------------------------------------------------------
// (a, filters_0 bias_0) conv2d -> conv2d_out
// conv2d_out relu -> relu_out_0
// (relu_out_0, weights_0) mul -> mul_out_0
// (mul_out_0, bias_1) elementwise_add -> add_out_0
// add_out_0 relu -> relu_out_1
// (relu_out_1, weights_1) mul -> mul_out_1
// (mul_out_1, bias_2) elementwise_add -> add_out_1
Layers layers;
auto* a = layers.data("a");
auto* filters_0 = layers.data("conv2d_filters_0", {}, true);
auto* bias_0 = layers.data("conv2d_bias_0", {}, true);
auto* conv2d_out = layers.conv2d(a, filters_0, bias_0, false);
auto* relu_out_0 = layers.relu(conv2d_out);
auto* weights_0 = layers.data("weights_0", {}, true);
auto* mul_out_0 = layers.mul(relu_out_0, weights_0);
auto* bias_1 = layers.data("bias_1", {}, true);
auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1, nullptr, 1);
auto* relu_out_1 = layers.relu(add_out_0);
auto* weights_1 = layers.data("weights_1", {}, true);
auto* mul_out_1 = layers.mul(relu_out_1, weights_1);
auto* bias_2 = layers.data("bias_2", {}, true);
auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2, nullptr, 1);
VLOG(4) << add_out_1;
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("generate_fc_fuse");
int num_nodes_before = graph->Nodes().size();
int num_mul_nodes_before = GetNumOpNodes(graph, "mul");
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
int num_fc_nodes_after = GetNumOpNodes(graph, "fc");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6,
platform::errors::InvalidArgument(
"num_nodes_before=%d, num_nodes_after=%d.",
num_nodes_before, num_nodes_after));
PADDLE_ENFORCE_EQ(num_fc_nodes_after, 2,
platform::errors::InvalidArgument("num_fc_nodes_after=%d.",
num_fc_nodes_after));
PADDLE_ENFORCE_EQ(num_mul_nodes_before, num_fc_nodes_after,
platform::errors::InvalidArgument(
"num_mul_nodes_before=%d, num_fc_nodes_after=%d.",
num_mul_nodes_before, num_fc_nodes_after));
}
TEST(GeneratePass, generate_multi_add_to_addn) {
// inputs operator output
// --------------------------------------------------------
// (a, b) elementwise_add -> add_out_0
// (add_out_0, c) elementwise_add -> add_out_1
Layers layers;
auto* a = layers.data("a");
auto* b = layers.data("b");
auto* c = layers.data("c");
auto* add_out_0 = layers.elementwise_add(a, b);
layers.elementwise_add(add_out_0, c);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("generate_multi_add_to_addn");
int num_nodes_before = graph->Nodes().size();
int num_add_nodes_before = GetNumOpNodes(graph, "elementwise_add");
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
int num_addn_nodes_after = GetNumOpNodes(graph, "add_n");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 2,
platform::errors::InvalidArgument(
"num_nodes_before=%d, num_nodes_after=%d.",
num_nodes_before, num_nodes_after));
PADDLE_ENFORCE_EQ(num_addn_nodes_after, 1,
platform::errors::InvalidArgument(
"num_addn_nodes_after=%d.", num_addn_nodes_after));
PADDLE_ENFORCE_EQ(num_add_nodes_before, num_addn_nodes_after + 1,
platform::errors::InvalidArgument(
"num_add_nodes_before=%d, num_addn_nodes_after=%d.",
num_add_nodes_before, num_addn_nodes_after));
}
TEST(GeneratePass, generate_combine_matmul) {
// inputs operator output
// --------------------------------------------------------
// (a, b) matmul -> matmul_out_0
// (a, c) matmul -> matmul_out_1
Layers layers;
auto* a = layers.data("a");
auto* b = layers.data("b");
auto* c = layers.data("c");
layers.matmul(a, b);
layers.matmul(a, c);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("generate_combine_matmul");
int num_nodes_before = graph->Nodes().size();
int num_matmul_nodes_before = GetNumOpNodes(graph, "matmul");
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
int num_matmul_nodes_after = GetNumOpNodes(graph, "matmul");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after - 4,
platform::errors::InvalidArgument(
"num_nodes_before=%d, num_nodes_after=%d.",
num_nodes_before, num_nodes_after));
PADDLE_ENFORCE_EQ(num_matmul_nodes_after, 1,
platform::errors::InvalidArgument(
"num_matmul_nodes_after=%d.", num_matmul_nodes_after));
PADDLE_ENFORCE_EQ(
num_matmul_nodes_before, num_matmul_nodes_after + 1,
platform::errors::InvalidArgument(
"num_matmul_nodes_before=%d, num_matmul_nodes_after=%d.",
num_matmul_nodes_before, num_matmul_nodes_after));
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -4,7 +4,7 @@ include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform)
include_directories(${PADDLE_SOURCE_DIR}/paddle/utils)
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune
feed_fetch_method pass pass_builder parallel_executor profiler layer tracer engine scope_pool
feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator)
......
......@@ -38,6 +38,7 @@ limitations under the License. */
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/ir/coalesce_grad_tensor_pass.h"
#include "paddle/fluid/framework/ir/generate_pass.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -2356,6 +2357,21 @@ All parameter, weight, gradient are variables in Paddle.
m.def("disable_profiler", platform::DisableProfiler);
m.def("is_profiler_enabled", platform::IsProfileEnabled);
m.def("reset_profiler", platform::ResetProfiler);
m.def("register_pass", [](const std::string &pass_type,
const py::object &callable) {
PADDLE_ENFORCE_EQ(
framework::ir::PassRegistry::Instance().Has(pass_type), false,
platform::errors::AlreadyExists(
"Pass '%s' is registered more than once. Please use another name.",
pass_type));
framework::ir::PassRegistry::Instance().Insert(pass_type, [pass_type,
callable]() {
py::gil_scoped_acquire guard;
std::unique_ptr<framework::ir::Pass> pass(
new framework::ir::GeneratePass(py::cast<std::string>(callable())));
return pass;
});
});
m.def("get_pass", [](const std::string &pass_type) {
auto pass = framework::ir::PassRegistry::Instance().Get(pass_type);
return std::shared_ptr<framework::ir::Pass>(std::move(pass));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册