未验证 提交 55aea350 编写于 作者: D Double_V 提交者: GitHub

add fusepass Reshape2MatmulFusePass AdaptivePool2dConvertGlobalPass (#33555)

* add transpose transpose opdef, test=develop

* add line, test=develop

* fix wrong name, test=develop

* add pass, test=develop

* fix bug, test=develop

* fix bug, test=develop

* delete limite about alpha, test=develop

* add mul to reshape2MatmulFusePass, test=develop

* add limit about alpha, test=develop

* fix bug,test=develop

* set adaptive as false and global_pooling as True, test=develop

* set x_num_col_dims as 1, test=develop

* fix reshape, add attr limit, test=develop

* fix conflict,test=develop

* fix comment, test=develop

* fix comment,test=develop

* fix comment,test=develop

* ,test=develop

* add IsType, test=develop

* add IsType, test=develop
上级 83284c8c
......@@ -24,6 +24,46 @@ namespace paddle {
namespace framework {
namespace ir {
AdaptivePool2dConvertGlobalPass::AdaptivePool2dConvertGlobalPass() {
AddOpCompat(OpCompat("pool2d"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("pooling_type")
.IsStringIn({"max", "avg"})
.End()
.AddAttr("ksize")
.IsType<std::vector<int>>()
.End()
.AddAttr("global_pooling")
.IsBoolEQ(true)
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("exclusive")
.IsType<bool>()
.End()
.AddAttr("adaptive")
.IsBoolEQ(false)
.End()
.AddAttr("ceil_mode")
.IsType<bool>()
.End()
.AddAttr("data_format")
.IsStringIn({"NHWC", "NCHW"})
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End();
}
void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const {
std::string name_scope = "adaptive_pool2d_convert_global_pass";
FusePassBase::Init(name_scope, graph);
......
......@@ -31,6 +31,7 @@ class Graph;
*/
class AdaptivePool2dConvertGlobalPass : public FusePassBase {
public:
AdaptivePool2dConvertGlobalPass();
virtual ~AdaptivePool2dConvertGlobalPass() {}
protected:
......
......@@ -267,6 +267,68 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}
Reshape2MatmulFusePass::Reshape2MatmulFusePass() {
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape") // ints
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGT(0.99999f)
.IsNumLT(1.00001f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ("False")
.End()
.AddAttr("transpose_Y")
.IsBoolEQ("False")
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumEQ(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}
void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
......@@ -280,6 +342,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "fuse reshape2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern);
......@@ -326,6 +392,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
}
if (!IsCompat(desc)) {
LOG(WARNING) << "reshape2 matmul pass in out mul op compat failed.";
return;
}
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(reshape2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
......
......@@ -96,6 +96,7 @@ class Squeeze2MatmulFusePass : public FusePassBase {
class Reshape2MatmulFusePass : public FusePassBase {
public:
Reshape2MatmulFusePass();
virtual ~Reshape2MatmulFusePass() {}
protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册