From 93c591e200be77887b69488f600c84d3dfabeb0b Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Tue, 26 Oct 2021 10:45:09 +0800 Subject: [PATCH] [Paddle-Inference]Add MatmulV2ToMatmul convert Pass, fix (matmul_v2, matmul, mul) convert pass, fix (matmul, mul) op_teller (#36652) * new_Matmul2ToMatmulToMul * new_Matmul2ToMatmulToMul * fix paddle_pass_builder * fix paddle_pass_builder * fix paddle_pass_builder * tem * tem * Add MatmulV2ToMatmul convert Pass; MatmulV2ToMul convert Pass * Add MatmulV2ToMatmul convert Pass; MatmulV2ToMul convert Pass * add matmul_broadcast_unitest * fix op_teller --- .../ir/delete_quant_dequant_filter_op_pass.cc | 5 +- .../framework/ir/graph_pattern_detector.cc | 51 ++-- .../framework/ir/graph_pattern_detector.h | 23 +- .../framework/ir/map_matmul_to_mul_pass.cc | 221 ++++++++++++++---- .../framework/ir/map_matmul_to_mul_pass.h | 18 +- .../ir/multihead_matmul_fuse_pass.cc | 19 +- .../inference/api/paddle_pass_builder.cc | 23 +- paddle/fluid/inference/tensorrt/op_teller.cc | 66 +++++- .../analyzer_seq_pool1_fuse_statis_tester.cc | 4 +- .../inference/tests/infer_ut/test_LeViT.cc | 6 +- .../unittests/ir/inference/test_trt_matmul.py | 38 +++ 11 files changed, 388 insertions(+), 86 deletions(-) diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc index b9cc337df8..2fc133edb7 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc @@ -181,7 +181,7 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { "Weight scale should be nonzero, but get zero.")); weight_scale[i] = weight_scale[i] / range; } - } else { + } else if (dequant_type == "fake_quantize_dequantize_abs_max") { // Implement quantize_dequantize_abs_max quantization algorithm float abs_max_weight = 0.; for (int j = 0; j < weight_tensor->numel(); j++) { @@ -192,6 +192,9 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { platform::errors::InvalidArgument( "Weight scale should be nonzero, but get zero")); weight_scale.push_back(abs_max_weight / range); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported quantize_dequantize op type: %s", dequant_type)); } nodes2rm.insert(quant_dequant_op_outscale); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 71b30d854c..6830a1f85e 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1606,6 +1606,7 @@ PDNode *patterns::Matmul::operator()() { ->assert_is_op_input("matmul", "X"); auto matmul_in_y = pattern->NewNode(matmul_in_y_repr()) ->AsInput() + ->assert_is_persistable_var() ->assert_is_op_input("matmul", "Y"); auto matmul_out = pattern->NewNode(matmul_out_repr()) ->AsOutput() @@ -1615,23 +1616,45 @@ PDNode *patterns::Matmul::operator()() { return matmul_out; } +// MatmulV2: tensor * weight +PDNode *patterns::MatmulV2Weight::operator()() { + auto matmul_v2_op = + pattern->NewNode(matmul_v2_op_repr())->assert_is_op("matmul_v2"); + + auto matmul_v2_in_x = pattern->NewNode(matmul_v2_in_x_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "X"); + auto matmul_v2_in_y = pattern->NewNode(matmul_v2_in_y_repr()) + ->AsInput() + ->assert_is_persistable_var() // Y is weight + ->assert_is_op_input("matmul_v2", "Y"); + auto matmul_v2_out = pattern->NewNode(matmul_v2_out_repr()) + ->AsOutput() + ->assert_is_op_output("matmul_v2", "Out"); + + matmul_v2_op->LinksFrom({matmul_v2_in_x, matmul_v2_in_y}) + .LinksTo({matmul_v2_out}); + return matmul_v2_out; +} + +// MatmulV2: tensor * tensor or tensor * weight PDNode *patterns::MatmulV2::operator()() { - auto matmul_op = - pattern->NewNode(matmul_op_repr())->assert_is_op("matmul_v2"); + auto matmul_v2_op = + pattern->NewNode(matmul_v2_op_repr())->assert_is_op("matmul_v2"); - auto matmul_in_x = pattern->NewNode(matmul_in_x_repr()) - ->AsInput() - ->assert_is_op_input("matmul_v2", "X"); - auto matmul_in_y = pattern->NewNode(matmul_in_y_repr()) - ->assert_is_persistable_var() - ->AsInput() - ->assert_is_op_input("matmul_v2", "Y"); - auto matmul_out = pattern->NewNode(matmul_out_repr()) - ->AsOutput() - ->assert_is_op_output("matmul_v2", "Out"); + auto matmul_v2_in_x = pattern->NewNode(matmul_v2_in_x_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "X"); + auto matmul_v2_in_y = pattern->NewNode(matmul_v2_in_y_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto matmul_v2_out = pattern->NewNode(matmul_v2_out_repr()) + ->AsOutput() + ->assert_is_op_output("matmul_v2", "Out"); - matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out}); - return matmul_out; + matmul_v2_op->LinksFrom({matmul_v2_in_x, matmul_v2_in_y}) + .LinksTo({matmul_v2_out}); + return matmul_v2_out; } PDNode *patterns::Squeeze2Matmul::operator()() { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index cc9d1c76ab..6657ab5a6a 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -976,17 +976,28 @@ struct Matmul : public PatternBase { PATTERN_DECL_NODE(matmul_out); }; -// Matmul_v2 op -// Forward pass for matmul_v2. +// MatmulV2: tensor * weight +struct MatmulV2Weight : public PatternBase { + MatmulV2Weight(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "matmul_v2_weight") {} + + PDNode* operator()(); + PATTERN_DECL_NODE(matmul_v2_in_x); + PATTERN_DECL_NODE(matmul_v2_in_y); + PATTERN_DECL_NODE(matmul_v2_op); + PATTERN_DECL_NODE(matmul_v2_out); +}; + +// MatmulV2: tensor * tensor or tensor * weight struct MatmulV2 : public PatternBase { MatmulV2(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "matmul_v2") {} PDNode* operator()(); - PATTERN_DECL_NODE(matmul_in_x); - PATTERN_DECL_NODE(matmul_in_y); - PATTERN_DECL_NODE(matmul_op); - PATTERN_DECL_NODE(matmul_out); + PATTERN_DECL_NODE(matmul_v2_in_x); + PATTERN_DECL_NODE(matmul_v2_in_y); + PATTERN_DECL_NODE(matmul_v2_op); + PATTERN_DECL_NODE(matmul_v2_out); }; // Squeeze2 + Matmul diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc index cdec49260f..865b556f30 100644 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc @@ -68,7 +68,7 @@ MapMatmul2MulPass::MapMatmul2MulPass() { .End(); } -MapMatmulv2ToMulPass::MapMatmulv2ToMulPass() { +MapMatmulV2ToMulPass::MapMatmulV2ToMulPass() { AddOpCompat(OpCompat("matmul_v2")) .AddInput("X") .IsTensor() @@ -104,6 +104,45 @@ MapMatmulv2ToMulPass::MapMatmulv2ToMulPass() { .End(); } +MapMatmulV2ToMatmulPass::MapMatmulV2ToMatmulPass() { + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsType() + .End() + .AddAttr("trans_y") + .IsType() + .End(); + + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumEQ(1.0f) + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("transpose_X") + .IsType() + .End() + .AddAttr("transpose_Y") + .IsType() + .End(); +} + Flatten2MatmulFusePass::Flatten2MatmulFusePass() { AddOpCompat(OpCompat("matmul")) .AddInput("X") @@ -246,15 +285,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { std::vector y_shape = matmul_in_y->Var()->GetShape(); size_t x_rank = x_shape.size(); size_t y_rank = y_shape.size(); - flag = flag && (x_rank == 2 || x_rank == 3) && y_rank == 2; - - std::vector& next_ops = matmul_out->outputs; - flag = flag && next_ops.size() == 1 && - next_ops[0]->Name() == "elementwise_add"; + flag = flag && x_rank >= 2 && y_rank == 2; if (flag) { if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "Pass in op compat failed."; + LOG(WARNING) << "MapMatmul2MulPass in op compat failed."; return; } OpDesc desc(matmul_op->Op()->Block()); @@ -268,6 +303,8 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("out_threshold", + matmul_op->Op()->GetAttr("out_threshold")); } auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(matmul_in_x, mul_node); @@ -287,66 +324,72 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_count); } -void MapMatmulv2ToMulPass::ApplyImpl(ir::Graph* graph) const { +void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); std::string name_scope = "map_matmul_v2_to_mul_pass"; FusePassBase::Init(name_scope, graph); GraphPatternDetector gpd; - patterns::MatmulV2 matmul_pattern(gpd.mutable_pattern(), name_scope); - matmul_pattern(); + patterns::MatmulV2Weight matmul_v2_weight_pattern(gpd.mutable_pattern(), + name_scope); + matmul_v2_weight_pattern(); int found_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - VLOG(4) << "map matmul_v2 to mul"; - GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern); - GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern); - GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern); - GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern); - bool flag = true; + VLOG(3) << "map matmul_v2 to mul"; + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x, + matmul_v2_weight_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y, + matmul_v2_weight_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op, + matmul_v2_weight_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_out, matmul_v2_out, + matmul_v2_weight_pattern); - bool trans_x = BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("trans_x")); - bool trans_y = BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("trans_y")); + bool flag = true; + bool trans_x = + BOOST_GET_CONST(bool, matmul_v2_op->Op()->GetAttr("trans_x")); + bool trans_y = + BOOST_GET_CONST(bool, matmul_v2_op->Op()->GetAttr("trans_y")); flag = flag && !trans_x && !trans_y; - std::vector x_shape = matmul_in_x->Var()->GetShape(); - std::vector y_shape = matmul_in_y->Var()->GetShape(); + std::vector x_shape = matmul_v2_in_x->Var()->GetShape(); + std::vector y_shape = matmul_v2_in_y->Var()->GetShape(); size_t x_rank = x_shape.size(); size_t y_rank = y_shape.size(); - flag = flag && (x_rank == 2 || x_rank == 3) && y_rank == 2; - - std::vector& next_ops = matmul_out->outputs; - flag = flag && next_ops.size() == 1 && - next_ops[0]->Name() == "elementwise_add"; + flag = flag && x_rank >= 2 && y_rank == 2; if (flag) { if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "Pass in op compat failed."; + LOG(WARNING) << "MapMatmulV2ToMulPass in op compat failed."; return; } - OpDesc desc(matmul_op->Op()->Block()); + OpDesc desc(matmul_v2_op->Op()->Block()); desc.SetType("mul"); - desc.SetInput("X", {matmul_in_x->Name()}); - desc.SetInput("Y", {matmul_in_y->Name()}); - desc.SetOutput("Out", {matmul_out->Name()}); + desc.SetInput("X", {matmul_v2_in_x->Name()}); + desc.SetInput("Y", {matmul_v2_in_y->Name()}); + desc.SetOutput("Out", {matmul_v2_out->Name()}); desc.SetAttr("x_num_col_dims", static_cast(x_rank - 1)); desc.SetAttr("y_num_col_dims", 1); - if (matmul_op->Op()->HasAttr("enable_int8")) { - desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + if (matmul_v2_op->Op()->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8")); + desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale")); + desc.SetAttr("weight_scale", + matmul_v2_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("out_threshold", + matmul_v2_op->Op()->GetAttr("out_threshold")); } auto mul_node = g->CreateOpNode(&desc); - IR_NODE_LINK_TO(matmul_in_x, mul_node); - IR_NODE_LINK_TO(matmul_in_y, mul_node); - IR_NODE_LINK_TO(mul_node, matmul_out); - GraphSafeRemoveNodes(graph, {matmul_op}); + IR_NODE_LINK_TO(matmul_v2_in_x, mul_node); + IR_NODE_LINK_TO(matmul_v2_in_y, mul_node); + IR_NODE_LINK_TO(mul_node, matmul_v2_out); + GraphSafeRemoveNodes(graph, {matmul_v2_op}); ++found_count; if (!IsCompat(desc)) { - LOG(WARNING) << "MapMatmulv2ToMulPass in out mul op compat failed."; + LOG(WARNING) << "MapMatmulV2ToMulPass in out mul op compat failed."; return; } } @@ -356,6 +399,82 @@ void MapMatmulv2ToMulPass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_count); } +void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + std::string name_scope = "map_matmul_v2_to_matmul_pass"; + FusePassBase::Init(name_scope, graph); + + GraphPatternDetector gpd; + patterns::MatmulV2 matmul_v2_pattern(gpd.mutable_pattern(), name_scope); + matmul_v2_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "map matmul_v2 to matmul"; + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x, + matmul_v2_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y, + matmul_v2_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op, matmul_v2_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_out, matmul_v2_out, matmul_v2_pattern); + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "MapMatmulV2ToMatmulPass in op compat failed."; + return; + } + + std::vector x_shape = matmul_v2_in_x->Var()->GetShape(); + std::vector y_shape = matmul_v2_in_y->Var()->GetShape(); + if (x_shape.size() != y_shape.size()) { + LOG(WARNING) + << "matmul op not support broadcast, please check inputs'shape. "; + return; + } + uint64_t dims = 2; + for (size_t i = 0; i < x_shape.size() - dims; ++i) { + if (x_shape[i] != y_shape[i] && (x_shape[i] == 1 || y_shape[i] == 1)) { + LOG(WARNING) << "matmul op not support broadcast, please check " + "inputs'shape[i]. "; + return; + } + } + + OpDesc desc(matmul_v2_op->Op()->Block()); + desc.SetType("matmul"); + desc.SetInput("X", {matmul_v2_in_x->Name()}); + desc.SetInput("Y", {matmul_v2_in_y->Name()}); + desc.SetOutput("Out", {matmul_v2_out->Name()}); + desc.SetAttr("transpose_X", matmul_v2_op->Op()->GetAttr("trans_x")); + desc.SetAttr("transpose_Y", matmul_v2_op->Op()->GetAttr("trans_y")); + desc.SetAttr("alpha", 1.0f); + if (matmul_v2_op->Op()->HasAttr("use_mkldnn")) { + desc.SetAttr("use_mkldnn", matmul_v2_op->Op()->GetAttr("use_mkldnn")); + } + if (matmul_v2_op->Op()->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8")); + desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale")); + desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("out_threshold", + matmul_v2_op->Op()->GetAttr("out_threshold")); + } + auto matmul_node = g->CreateOpNode(&desc); + IR_NODE_LINK_TO(matmul_v2_in_x, matmul_node); + IR_NODE_LINK_TO(matmul_v2_in_y, matmul_node); + IR_NODE_LINK_TO(matmul_node, matmul_v2_out); + GraphSafeRemoveNodes(graph, {matmul_v2_op}); + ++found_count; + + if (!IsCompat(desc)) { + LOG(WARNING) << "MapMatmulV2ToMatmulPass in out matmul op compat failed."; + return; + } + }; + + gpd(graph, handler); + AddStatis(found_count); +} + void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); @@ -402,7 +521,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { if (flag) { if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "Pass in op compat failed."; + LOG(WARNING) << "Squeeze2MatmulFusePass in op compat failed."; return; } OpDesc desc(matmul_op->Op()->Block()); @@ -416,6 +535,8 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("out_threshold", + matmul_op->Op()->GetAttr("out_threshold")); } auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(squeeze2_in_x, mul_node); @@ -544,7 +665,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { if (flag) { if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "Pass in op compat failed."; + LOG(WARNING) << "Reshape2MatmulFusePass in op compat failed."; return; } OpDesc desc(matmul_op->Op()->Block()); @@ -558,9 +679,11 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("out_threshold", + matmul_op->Op()->GetAttr("out_threshold")); } if (!IsCompat(desc)) { - LOG(WARNING) << "reshape2 matmul pass in out mul op compat failed."; + LOG(WARNING) << "Reshape2MatmulFusePass in out mul op compat failed."; return; } auto mul_node = g->CreateOpNode(&desc); @@ -629,7 +752,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { if (pattern_found) { if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "Pass in op compat failed."; + LOG(WARNING) << "Flatten2MatmulFusePass in op compat failed."; return; } OpDesc desc(matmul_op->Op()->Block()); @@ -643,6 +766,8 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("out_threshold", + matmul_op->Op()->GetAttr("out_threshold")); } auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(flatten2_in_x, mul_node); @@ -674,13 +799,21 @@ REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass) .EQ("mul", 0)); REGISTER_PASS(map_matmul_v2_to_mul_pass, - paddle::framework::ir::MapMatmulv2ToMulPass); + paddle::framework::ir::MapMatmulV2ToMulPass); REGISTER_PASS_CAPABILITY(map_matmul_v2_to_mul_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .EQ("matmul_v2", 0) .EQ("mul", 0)); +REGISTER_PASS(map_matmul_v2_to_matmul_pass, + paddle::framework::ir::MapMatmulV2ToMatmulPass); +REGISTER_PASS_CAPABILITY(map_matmul_v2_to_matmul_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("matmul_v2", 0) + .LE("matmul", 1)); + REGISTER_PASS(squeeze2_matmul_fuse_pass, paddle::framework::ir::Squeeze2MatmulFusePass); REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass) diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h index 8f462810fc..a924cd8ddf 100644 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h @@ -49,10 +49,22 @@ class MapMatmul2MulPass : public FusePassBase { /* * Map matmul_v2 to mul, the same as MapMatmul2MulPass. */ -class MapMatmulv2ToMulPass : public FusePassBase { +class MapMatmulV2ToMulPass : public FusePassBase { public: - MapMatmulv2ToMulPass(); - virtual ~MapMatmulv2ToMulPass() {} + MapMatmulV2ToMulPass(); + virtual ~MapMatmulV2ToMulPass() {} + + protected: + void ApplyImpl(Graph* graph) const override; +}; + +/* + * Map matmul_v2 to matmul, not supoort broadcast. + */ +class MapMatmulV2ToMatmulPass : public FusePassBase { + public: + MapMatmulV2ToMatmulPass(); + virtual ~MapMatmulV2ToMatmulPass() {} protected: void ApplyImpl(Graph* graph) const override; diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index 4c0b28fd42..8bbe6a12d8 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -461,7 +461,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() { pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) ->assert_is_op_output("transpose2"); - transpose2_0_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops); + transpose2_0_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops, "X"); auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops); @@ -1174,6 +1174,23 @@ MultiHeadMatmulV3FusePass::MultiHeadMatmulV3FusePass() { .IsType() .End(); + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsBoolEQ(false) + .End() + .AddAttr("trans_y") // QK(true) QKV(false) + .IsType() + .End(); + AddOpCompat(OpCompat("softmax")) .AddInput("X") .IsTensor() diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 9eccf0a614..8a54b04f4d 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -93,8 +93,9 @@ const std::vector kTRTSubgraphPasses({ "squeeze2_matmul_fuse_pass", // "reshape2_matmul_fuse_pass", // "flatten2_matmul_fuse_pass", // - "map_matmul_to_mul_pass", // "map_matmul_v2_to_mul_pass", // + "map_matmul_v2_to_matmul_pass", // + "map_matmul_to_mul_pass", // "fc_fuse_pass", // "conv_elementwise_add_fuse_pass", // "add_support_int8_pass", @@ -142,8 +143,9 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "squeeze2_matmul_fuse_pass", // "reshape2_matmul_fuse_pass", // "flatten2_matmul_fuse_pass", // - "map_matmul_to_mul_pass", // "map_matmul_v2_to_mul_pass", // + "map_matmul_v2_to_matmul_pass", // + "map_matmul_to_mul_pass", // "fc_fuse_pass", // "fc_elementwise_layernorm_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be @@ -196,15 +198,16 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { // "embedding_fc_lstm_fuse_pass", // // TODO(wilber): fix correctness problem. // "fc_lstm_fuse_pass", // - "mul_lstm_fuse_pass", // - "fc_gru_fuse_pass", // - "mul_gru_fuse_pass", // - "seq_concat_fc_fuse_pass", // - "squeeze2_matmul_fuse_pass", // - "reshape2_matmul_fuse_pass", // - "flatten2_matmul_fuse_pass", // + "mul_lstm_fuse_pass", // + "fc_gru_fuse_pass", // + "mul_gru_fuse_pass", // + "seq_concat_fc_fuse_pass", // + "squeeze2_matmul_fuse_pass", // + "reshape2_matmul_fuse_pass", // + "flatten2_matmul_fuse_pass", // + "map_matmul_v2_to_mul_pass", // + // "map_matmul_v2_to_matmul_pass", // "map_matmul_to_mul_pass", // - "map_matmul_v2_to_mul_pass", // "fc_fuse_pass", // "repeated_fc_relu_fuse_pass", // "squared_mat_sub_fuse_pass", // diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 7049df4b30..93ecde789c 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -340,6 +340,26 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, "the pass."; return false; } + + // not support broadcast + auto* x_var_desc = block->FindVar(desc.Input("X")[0]); + auto* y_var_desc = block->FindVar(desc.Input("Y")[0]); + const auto x_shape = x_var_desc->GetShape(); + const auto y_shape = y_var_desc->GetShape(); + if (x_shape.size() != y_shape.size()) { + VLOG(3) + << "matmul op not support broadcast, please check inputs'shape. "; + return false; + } + uint64_t dims = 2; + for (size_t i = 0; i < x_shape.size() - dims; ++i) { + if (x_shape[i] != y_shape[i] && (x_shape[i] == 1 || y_shape[i] == 1)) { + VLOG(3) << "matmul op not support broadcast, please check " + "inputs'shape[i]. "; + return false; + } + } + for (auto& param_name : desc.Inputs()) { for (auto& var_name : param_name.second) { auto* var_desc = block->FindVar(var_name); @@ -1330,6 +1350,47 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } if (op_type == "fc") { + auto* block = desc.Block(); + if (block == nullptr) { + VLOG(3) << "The block desc is nullptr, we can't continue to analyze. " + "Developers need to check whether block_desc is passed in " + "the pass."; + return false; + } + + // y'shapes == 2 + auto fc_inputs = desc.Inputs(); + std::string fc_y = ""; + if (fc_inputs.find("Y") != fc_inputs.end()) { + fc_y = "Y"; + } else if (fc_inputs.find("W") != fc_inputs.end()) { + fc_y = "W"; + } else { + VLOG(3) << " input_y(fc_op) must be Y or W "; + return false; + } + + // There is currently no input: Y(weight) more than two dimensions + /* + auto* y_var_desc = block->FindVar(desc.Input(fc_y)[0]); + const auto y_shape = y_var_desc->GetShape(); + if (y_shape.size() != 2) { + VLOG(3) + << " input_y(fc_op)'shapes must be 2, but input_y(fc_op)'shapes = " + << y_shape.size(); + return false; + } + // y_num_col_dims ==1 + if (desc.HasAttr("y_num_col_dims")) { + int y_num_col_dims = + BOOST_GET_CONST(int, desc.GetAttr("y_num_col_dims")); + if (y_num_col_dims != 1) { + VLOG(3) << " fc_op'y_num_col_dims must be 1, but y_num_col_dims = " + << y_num_col_dims; + return false; + } + } + */ int x_num_col_dims = desc.HasAttr("x_num_col_dims") ? BOOST_GET_CONST(int, desc.GetAttr("x_num_col_dims")) @@ -1337,8 +1398,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ? BOOST_GET_CONST(int, desc.GetAttr("in_num_col_dims")) : 1); if (x_num_col_dims < 1) { - VLOG(3) << "converter expects x_num_col_dims >= 1, " - "but x_num_col_dims = %d."; + VLOG(3) << "fc_op expects x_num_col_dims >= 1, " + "but x_num_col_dims = " + << x_num_col_dims; return false; } } diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_fuse_statis_tester.cc b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_fuse_statis_tester.cc index b8ccb8cee5..d33b11c389 100644 --- a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_fuse_statis_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_fuse_statis_tester.cc @@ -36,10 +36,10 @@ TEST(Analyzer_seq_pool1_fuse_statis, fuse_statis) { ASSERT_TRUE(fuse_statis.count("repeated_fc_relu_fuse")); ASSERT_EQ(fuse_statis.at("fc_fuse"), 10); EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2); - EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 2); + EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 0); EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2); LOG(INFO) << "num_ops: " << num_ops; - EXPECT_EQ(num_ops, 171); + EXPECT_EQ(num_ops, 185); } } // namespace seq_pool1_tester diff --git a/paddle/fluid/inference/tests/infer_ut/test_LeViT.cc b/paddle/fluid/inference/tests/infer_ut/test_LeViT.cc index 2fe9b6c144..b74d1189b8 100644 --- a/paddle/fluid/inference/tests/infer_ut/test_LeViT.cc +++ b/paddle/fluid/inference/tests/infer_ut/test_LeViT.cc @@ -77,7 +77,7 @@ TEST(tensorrt_tester_LeViT, trt_fp32_bz2) { FLAGS_modeldir + "/inference.pdiparams"); config.EnableUseGpu(100, 0); config.EnableTensorRtEngine( - 1 << 20, 2, 6, paddle_infer::PrecisionType::kFloat32, false, false); + 1 << 20, 2, 50, paddle_infer::PrecisionType::kFloat32, false, false); // get groudtruth by disbale ir paddle_infer::services::PredictorPool pred_pool_no_ir(config_no_ir, 1); SingleThreadPrediction(pred_pool_no_ir.Retrive(0), &my_input_data_map, @@ -103,7 +103,7 @@ TEST(tensorrt_tester_LeViT, serial_diff_batch_trt_fp32) { config.SetModel(FLAGS_modeldir + "/inference.pdmodel", FLAGS_modeldir + "/inference.pdiparams"); config.EnableUseGpu(100, 0); - config.EnableTensorRtEngine(1 << 20, max_batch_size, 6, + config.EnableTensorRtEngine(1 << 20, max_batch_size, 50, paddle_infer::PrecisionType::kFloat32, false, false); paddle_infer::services::PredictorPool pred_pool(config, 1); @@ -145,7 +145,7 @@ TEST(tensorrt_tester_LeViT, multi_thread4_trt_fp32_bz2) { FLAGS_modeldir + "/inference.pdiparams"); config.EnableUseGpu(100, 0); config.EnableTensorRtEngine( - 1 << 20, 2, 6, paddle_infer::PrecisionType::kFloat32, false, false); + 1 << 20, 2, 50, paddle_infer::PrecisionType::kFloat32, false, false); // get groudtruth by disbale ir paddle_infer::services::PredictorPool pred_pool_no_ir(config_no_ir, 1); SingleThreadPrediction(pred_pool_no_ir.Retrive(0), &my_input_data_map, diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py index 080d1ccc90..99e99a8387 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py @@ -107,5 +107,43 @@ class TensorRTMatMulScaleTest(TensorRTMatMulTest): self.alpha = 2.0 +class TensorRTMatMulBroadcastTest(InferencePassTest): + def setUp(self): + self.set_params() + place = fluid.CPUPlace() + with fluid.program_guard(self.main_program, self.startup_program): + data_x = fluid.data( + name="data_x", shape=[-1, 6, 24], dtype="float32") + data_y = fluid.data(name="data_y", shape=[24, 16], dtype="float32") + matmul_out = fluid.layers.matmul( + x=data_x, + y=data_y, + transpose_x=self.transpose_x, + transpose_y=self.transpose_y, + alpha=self.alpha) + out = fluid.layers.batch_norm(matmul_out, is_test=True) + + self.feeds = { + "data_x": np.ones([2, 6, 24]).astype("float32"), + "data_y": np.ones([24, 16]).astype("float32") + } + self.enable_trt = True + self.trt_parameters = TensorRTMatMulBroadcastTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def set_params(self): + self.transpose_x = False + self.transpose_y = False + self.alpha = 1.0 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + if __name__ == "__main__": unittest.main() -- GitLab