未验证 提交 30ce925c 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle-Inference]Add MatmulV2ToMatmul convert Pass, fix (matmul_v2, matmul,...

[Paddle-Inference]Add MatmulV2ToMatmul convert Pass, fix (matmul_v2, matmul, mul) convert pass, fix (matmul, mul) op_teller (#36652) (#36737)
上级 edff5b79
...@@ -181,7 +181,7 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -181,7 +181,7 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
"Weight scale should be nonzero, but get zero.")); "Weight scale should be nonzero, but get zero."));
weight_scale[i] = weight_scale[i] / range; weight_scale[i] = weight_scale[i] / range;
} }
} else { } else if (dequant_type == "fake_quantize_dequantize_abs_max") {
// Implement quantize_dequantize_abs_max quantization algorithm // Implement quantize_dequantize_abs_max quantization algorithm
float abs_max_weight = 0.; float abs_max_weight = 0.;
for (int j = 0; j < weight_tensor->numel(); j++) { for (int j = 0; j < weight_tensor->numel(); j++) {
...@@ -192,6 +192,9 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -192,6 +192,9 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero")); "Weight scale should be nonzero, but get zero"));
weight_scale.push_back(abs_max_weight / range); weight_scale.push_back(abs_max_weight / range);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantize_dequantize op type: %s", dequant_type));
} }
nodes2rm.insert(quant_dequant_op_outscale); nodes2rm.insert(quant_dequant_op_outscale);
......
...@@ -1606,6 +1606,7 @@ PDNode *patterns::Matmul::operator()() { ...@@ -1606,6 +1606,7 @@ PDNode *patterns::Matmul::operator()() {
->assert_is_op_input("matmul", "X"); ->assert_is_op_input("matmul", "X");
auto matmul_in_y = pattern->NewNode(matmul_in_y_repr()) auto matmul_in_y = pattern->NewNode(matmul_in_y_repr())
->AsInput() ->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("matmul", "Y"); ->assert_is_op_input("matmul", "Y");
auto matmul_out = pattern->NewNode(matmul_out_repr()) auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput() ->AsOutput()
...@@ -1615,23 +1616,45 @@ PDNode *patterns::Matmul::operator()() { ...@@ -1615,23 +1616,45 @@ PDNode *patterns::Matmul::operator()() {
return matmul_out; return matmul_out;
} }
// MatmulV2: tensor * weight
PDNode *patterns::MatmulV2Weight::operator()() {
auto matmul_v2_op =
pattern->NewNode(matmul_v2_op_repr())->assert_is_op("matmul_v2");
auto matmul_v2_in_x = pattern->NewNode(matmul_v2_in_x_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "X");
auto matmul_v2_in_y = pattern->NewNode(matmul_v2_in_y_repr())
->AsInput()
->assert_is_persistable_var() // Y is weight
->assert_is_op_input("matmul_v2", "Y");
auto matmul_v2_out = pattern->NewNode(matmul_v2_out_repr())
->AsOutput()
->assert_is_op_output("matmul_v2", "Out");
matmul_v2_op->LinksFrom({matmul_v2_in_x, matmul_v2_in_y})
.LinksTo({matmul_v2_out});
return matmul_v2_out;
}
// MatmulV2: tensor * tensor or tensor * weight
PDNode *patterns::MatmulV2::operator()() { PDNode *patterns::MatmulV2::operator()() {
auto matmul_op = auto matmul_v2_op =
pattern->NewNode(matmul_op_repr())->assert_is_op("matmul_v2"); pattern->NewNode(matmul_v2_op_repr())->assert_is_op("matmul_v2");
auto matmul_in_x = pattern->NewNode(matmul_in_x_repr()) auto matmul_v2_in_x = pattern->NewNode(matmul_v2_in_x_repr())
->AsInput() ->AsInput()
->assert_is_op_input("matmul_v2", "X"); ->assert_is_op_input("matmul_v2", "X");
auto matmul_in_y = pattern->NewNode(matmul_in_y_repr()) auto matmul_v2_in_y = pattern->NewNode(matmul_v2_in_y_repr())
->assert_is_persistable_var() ->AsInput()
->AsInput() ->assert_is_op_input("matmul_v2", "Y");
->assert_is_op_input("matmul_v2", "Y"); auto matmul_v2_out = pattern->NewNode(matmul_v2_out_repr())
auto matmul_out = pattern->NewNode(matmul_out_repr()) ->AsOutput()
->AsOutput() ->assert_is_op_output("matmul_v2", "Out");
->assert_is_op_output("matmul_v2", "Out");
matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out}); matmul_v2_op->LinksFrom({matmul_v2_in_x, matmul_v2_in_y})
return matmul_out; .LinksTo({matmul_v2_out});
return matmul_v2_out;
} }
PDNode *patterns::Squeeze2Matmul::operator()() { PDNode *patterns::Squeeze2Matmul::operator()() {
......
...@@ -976,17 +976,28 @@ struct Matmul : public PatternBase { ...@@ -976,17 +976,28 @@ struct Matmul : public PatternBase {
PATTERN_DECL_NODE(matmul_out); PATTERN_DECL_NODE(matmul_out);
}; };
// Matmul_v2 op // MatmulV2: tensor * weight
// Forward pass for matmul_v2. struct MatmulV2Weight : public PatternBase {
MatmulV2Weight(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_v2_weight") {}
PDNode* operator()();
PATTERN_DECL_NODE(matmul_v2_in_x);
PATTERN_DECL_NODE(matmul_v2_in_y);
PATTERN_DECL_NODE(matmul_v2_op);
PATTERN_DECL_NODE(matmul_v2_out);
};
// MatmulV2: tensor * tensor or tensor * weight
struct MatmulV2 : public PatternBase { struct MatmulV2 : public PatternBase {
MatmulV2(PDPattern* pattern, const std::string& name_scope) MatmulV2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_v2") {} : PatternBase(pattern, name_scope, "matmul_v2") {}
PDNode* operator()(); PDNode* operator()();
PATTERN_DECL_NODE(matmul_in_x); PATTERN_DECL_NODE(matmul_v2_in_x);
PATTERN_DECL_NODE(matmul_in_y); PATTERN_DECL_NODE(matmul_v2_in_y);
PATTERN_DECL_NODE(matmul_op); PATTERN_DECL_NODE(matmul_v2_op);
PATTERN_DECL_NODE(matmul_out); PATTERN_DECL_NODE(matmul_v2_out);
}; };
// Squeeze2 + Matmul // Squeeze2 + Matmul
......
...@@ -68,7 +68,7 @@ MapMatmul2MulPass::MapMatmul2MulPass() { ...@@ -68,7 +68,7 @@ MapMatmul2MulPass::MapMatmul2MulPass() {
.End(); .End();
} }
MapMatmulv2ToMulPass::MapMatmulv2ToMulPass() { MapMatmulV2ToMulPass::MapMatmulV2ToMulPass() {
AddOpCompat(OpCompat("matmul_v2")) AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") .AddInput("X")
.IsTensor() .IsTensor()
...@@ -104,6 +104,45 @@ MapMatmulv2ToMulPass::MapMatmulv2ToMulPass() { ...@@ -104,6 +104,45 @@ MapMatmulv2ToMulPass::MapMatmulv2ToMulPass() {
.End(); .End();
} }
MapMatmulV2ToMatmulPass::MapMatmulV2ToMatmulPass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumEQ(1.0f)
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("transpose_X")
.IsType<bool>()
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
}
Flatten2MatmulFusePass::Flatten2MatmulFusePass() { Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
AddOpCompat(OpCompat("matmul")) AddOpCompat(OpCompat("matmul"))
.AddInput("X") .AddInput("X")
...@@ -246,15 +285,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { ...@@ -246,15 +285,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape(); std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape();
size_t x_rank = x_shape.size(); size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size(); size_t y_rank = y_shape.size();
flag = flag && (x_rank == 2 || x_rank == 3) && y_rank == 2; flag = flag && x_rank >= 2 && y_rank == 2;
std::vector<Node*>& next_ops = matmul_out->outputs;
flag = flag && next_ops.size() == 1 &&
next_ops[0]->Name() == "elementwise_add";
if (flag) { if (flag) {
if (!IsCompat(subgraph, g)) { if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed."; LOG(WARNING) << "MapMatmul2MulPass in op compat failed.";
return; return;
} }
OpDesc desc(matmul_op->Op()->Block()); OpDesc desc(matmul_op->Op()->Block());
...@@ -268,6 +303,8 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { ...@@ -268,6 +303,8 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
} }
auto mul_node = g->CreateOpNode(&desc); auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_in_x, mul_node); IR_NODE_LINK_TO(matmul_in_x, mul_node);
...@@ -287,66 +324,72 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { ...@@ -287,66 +324,72 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count); AddStatis(found_count);
} }
void MapMatmulv2ToMulPass::ApplyImpl(ir::Graph* graph) const { void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "map_matmul_v2_to_mul_pass"; std::string name_scope = "map_matmul_v2_to_mul_pass";
FusePassBase::Init(name_scope, graph); FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::MatmulV2 matmul_pattern(gpd.mutable_pattern(), name_scope); patterns::MatmulV2Weight matmul_v2_weight_pattern(gpd.mutable_pattern(),
matmul_pattern(); name_scope);
matmul_v2_weight_pattern();
int found_count = 0; int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "map matmul_v2 to mul"; VLOG(3) << "map matmul_v2 to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern); matmul_v2_weight_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern); matmul_v2_weight_pattern);
bool flag = true; GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op,
matmul_v2_weight_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_out, matmul_v2_out,
matmul_v2_weight_pattern);
bool trans_x = BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("trans_x")); bool flag = true;
bool trans_y = BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("trans_y")); bool trans_x =
BOOST_GET_CONST(bool, matmul_v2_op->Op()->GetAttr("trans_x"));
bool trans_y =
BOOST_GET_CONST(bool, matmul_v2_op->Op()->GetAttr("trans_y"));
flag = flag && !trans_x && !trans_y; flag = flag && !trans_x && !trans_y;
std::vector<int64_t> x_shape = matmul_in_x->Var()->GetShape(); std::vector<int64_t> x_shape = matmul_v2_in_x->Var()->GetShape();
std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape(); std::vector<int64_t> y_shape = matmul_v2_in_y->Var()->GetShape();
size_t x_rank = x_shape.size(); size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size(); size_t y_rank = y_shape.size();
flag = flag && (x_rank == 2 || x_rank == 3) && y_rank == 2; flag = flag && x_rank >= 2 && y_rank == 2;
std::vector<Node*>& next_ops = matmul_out->outputs;
flag = flag && next_ops.size() == 1 &&
next_ops[0]->Name() == "elementwise_add";
if (flag) { if (flag) {
if (!IsCompat(subgraph, g)) { if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed."; LOG(WARNING) << "MapMatmulV2ToMulPass in op compat failed.";
return; return;
} }
OpDesc desc(matmul_op->Op()->Block()); OpDesc desc(matmul_v2_op->Op()->Block());
desc.SetType("mul"); desc.SetType("mul");
desc.SetInput("X", {matmul_in_x->Name()}); desc.SetInput("X", {matmul_v2_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()}); desc.SetInput("Y", {matmul_v2_in_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()}); desc.SetOutput("Out", {matmul_v2_out->Name()});
desc.SetAttr("x_num_col_dims", static_cast<int>(x_rank - 1)); desc.SetAttr("x_num_col_dims", static_cast<int>(x_rank - 1));
desc.SetAttr("y_num_col_dims", 1); desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) { if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); desc.SetAttr("weight_scale",
matmul_v2_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("out_threshold",
matmul_v2_op->Op()->GetAttr("out_threshold"));
} }
auto mul_node = g->CreateOpNode(&desc); auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_in_x, mul_node); IR_NODE_LINK_TO(matmul_v2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node); IR_NODE_LINK_TO(matmul_v2_in_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_out); IR_NODE_LINK_TO(mul_node, matmul_v2_out);
GraphSafeRemoveNodes(graph, {matmul_op}); GraphSafeRemoveNodes(graph, {matmul_v2_op});
++found_count; ++found_count;
if (!IsCompat(desc)) { if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmulv2ToMulPass in out mul op compat failed."; LOG(WARNING) << "MapMatmulV2ToMulPass in out mul op compat failed.";
return; return;
} }
} }
...@@ -356,6 +399,82 @@ void MapMatmulv2ToMulPass::ApplyImpl(ir::Graph* graph) const { ...@@ -356,6 +399,82 @@ void MapMatmulv2ToMulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count); AddStatis(found_count);
} }
void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "map_matmul_v2_to_matmul_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::MatmulV2 matmul_v2_pattern(gpd.mutable_pattern(), name_scope);
matmul_v2_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "map matmul_v2 to matmul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op, matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_out, matmul_v2_out, matmul_v2_pattern);
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "MapMatmulV2ToMatmulPass in op compat failed.";
return;
}
std::vector<int64_t> x_shape = matmul_v2_in_x->Var()->GetShape();
std::vector<int64_t> y_shape = matmul_v2_in_y->Var()->GetShape();
if (x_shape.size() != y_shape.size()) {
LOG(WARNING)
<< "matmul op not support broadcast, please check inputs'shape. ";
return;
}
uint64_t dims = 2;
for (size_t i = 0; i < x_shape.size() - dims; ++i) {
if (x_shape[i] != y_shape[i] && (x_shape[i] == 1 || y_shape[i] == 1)) {
LOG(WARNING) << "matmul op not support broadcast, please check "
"inputs'shape[i]. ";
return;
}
}
OpDesc desc(matmul_v2_op->Op()->Block());
desc.SetType("matmul");
desc.SetInput("X", {matmul_v2_in_x->Name()});
desc.SetInput("Y", {matmul_v2_in_y->Name()});
desc.SetOutput("Out", {matmul_v2_out->Name()});
desc.SetAttr("transpose_X", matmul_v2_op->Op()->GetAttr("trans_x"));
desc.SetAttr("transpose_Y", matmul_v2_op->Op()->GetAttr("trans_y"));
desc.SetAttr("alpha", 1.0f);
if (matmul_v2_op->Op()->HasAttr("use_mkldnn")) {
desc.SetAttr("use_mkldnn", matmul_v2_op->Op()->GetAttr("use_mkldnn"));
}
if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("out_threshold",
matmul_v2_op->Op()->GetAttr("out_threshold"));
}
auto matmul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_v2_in_x, matmul_node);
IR_NODE_LINK_TO(matmul_v2_in_y, matmul_node);
IR_NODE_LINK_TO(matmul_node, matmul_v2_out);
GraphSafeRemoveNodes(graph, {matmul_v2_op});
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmulV2ToMatmulPass in out matmul op compat failed.";
return;
}
};
gpd(graph, handler);
AddStatis(found_count);
}
void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
...@@ -402,7 +521,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -402,7 +521,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
if (flag) { if (flag) {
if (!IsCompat(subgraph, g)) { if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed."; LOG(WARNING) << "Squeeze2MatmulFusePass in op compat failed.";
return; return;
} }
OpDesc desc(matmul_op->Op()->Block()); OpDesc desc(matmul_op->Op()->Block());
...@@ -416,6 +535,8 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -416,6 +535,8 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
} }
auto mul_node = g->CreateOpNode(&desc); auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(squeeze2_in_x, mul_node); IR_NODE_LINK_TO(squeeze2_in_x, mul_node);
...@@ -544,7 +665,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -544,7 +665,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
if (flag) { if (flag) {
if (!IsCompat(subgraph, g)) { if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed."; LOG(WARNING) << "Reshape2MatmulFusePass in op compat failed.";
return; return;
} }
OpDesc desc(matmul_op->Op()->Block()); OpDesc desc(matmul_op->Op()->Block());
...@@ -558,9 +679,11 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -558,9 +679,11 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
} }
if (!IsCompat(desc)) { if (!IsCompat(desc)) {
LOG(WARNING) << "reshape2 matmul pass in out mul op compat failed."; LOG(WARNING) << "Reshape2MatmulFusePass in out mul op compat failed.";
return; return;
} }
auto mul_node = g->CreateOpNode(&desc); auto mul_node = g->CreateOpNode(&desc);
...@@ -629,7 +752,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -629,7 +752,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
if (pattern_found) { if (pattern_found) {
if (!IsCompat(subgraph, g)) { if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed."; LOG(WARNING) << "Flatten2MatmulFusePass in op compat failed.";
return; return;
} }
OpDesc desc(matmul_op->Op()->Block()); OpDesc desc(matmul_op->Op()->Block());
...@@ -643,6 +766,8 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -643,6 +766,8 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
} }
auto mul_node = g->CreateOpNode(&desc); auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(flatten2_in_x, mul_node); IR_NODE_LINK_TO(flatten2_in_x, mul_node);
...@@ -674,13 +799,21 @@ REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass) ...@@ -674,13 +799,21 @@ REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass)
.EQ("mul", 0)); .EQ("mul", 0));
REGISTER_PASS(map_matmul_v2_to_mul_pass, REGISTER_PASS(map_matmul_v2_to_mul_pass,
paddle::framework::ir::MapMatmulv2ToMulPass); paddle::framework::ir::MapMatmulV2ToMulPass);
REGISTER_PASS_CAPABILITY(map_matmul_v2_to_mul_pass) REGISTER_PASS_CAPABILITY(map_matmul_v2_to_mul_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0) .EQ("matmul_v2", 0)
.EQ("mul", 0)); .EQ("mul", 0));
REGISTER_PASS(map_matmul_v2_to_matmul_pass,
paddle::framework::ir::MapMatmulV2ToMatmulPass);
REGISTER_PASS_CAPABILITY(map_matmul_v2_to_matmul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.LE("matmul", 1));
REGISTER_PASS(squeeze2_matmul_fuse_pass, REGISTER_PASS(squeeze2_matmul_fuse_pass,
paddle::framework::ir::Squeeze2MatmulFusePass); paddle::framework::ir::Squeeze2MatmulFusePass);
REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass) REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass)
......
...@@ -49,10 +49,22 @@ class MapMatmul2MulPass : public FusePassBase { ...@@ -49,10 +49,22 @@ class MapMatmul2MulPass : public FusePassBase {
/* /*
* Map matmul_v2 to mul, the same as MapMatmul2MulPass. * Map matmul_v2 to mul, the same as MapMatmul2MulPass.
*/ */
class MapMatmulv2ToMulPass : public FusePassBase { class MapMatmulV2ToMulPass : public FusePassBase {
public: public:
MapMatmulv2ToMulPass(); MapMatmulV2ToMulPass();
virtual ~MapMatmulv2ToMulPass() {} virtual ~MapMatmulV2ToMulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Map matmul_v2 to matmul, not supoort broadcast.
*/
class MapMatmulV2ToMatmulPass : public FusePassBase {
public:
MapMatmulV2ToMatmulPass();
virtual ~MapMatmulV2ToMatmulPass() {}
protected: protected:
void ApplyImpl(Graph* graph) const override; void ApplyImpl(Graph* graph) const override;
......
...@@ -461,7 +461,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() { ...@@ -461,7 +461,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2"); ->assert_is_op_output("transpose2");
transpose2_0_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops); transpose2_0_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops, "X");
auto* matmul_qk = auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops); pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops);
...@@ -1174,6 +1174,23 @@ MultiHeadMatmulV3FusePass::MultiHeadMatmulV3FusePass() { ...@@ -1174,6 +1174,23 @@ MultiHeadMatmulV3FusePass::MultiHeadMatmulV3FusePass() {
.IsType<bool>() .IsType<bool>()
.End(); .End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsBoolEQ(false)
.End()
.AddAttr("trans_y") // QK(true) QKV(false)
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax")) AddOpCompat(OpCompat("softmax"))
.AddInput("X") .AddInput("X")
.IsTensor() .IsTensor()
......
...@@ -93,8 +93,9 @@ const std::vector<std::string> kTRTSubgraphPasses({ ...@@ -93,8 +93,9 @@ const std::vector<std::string> kTRTSubgraphPasses({
"squeeze2_matmul_fuse_pass", // "squeeze2_matmul_fuse_pass", //
"reshape2_matmul_fuse_pass", // "reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", // "flatten2_matmul_fuse_pass", //
"map_matmul_to_mul_pass", //
"map_matmul_v2_to_mul_pass", // "map_matmul_v2_to_mul_pass", //
"map_matmul_v2_to_matmul_pass", //
"map_matmul_to_mul_pass", //
"fc_fuse_pass", // "fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", // "conv_elementwise_add_fuse_pass", //
"tensorrt_subgraph_pass", // "tensorrt_subgraph_pass", //
...@@ -141,8 +142,9 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { ...@@ -141,8 +142,9 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"squeeze2_matmul_fuse_pass", // "squeeze2_matmul_fuse_pass", //
"reshape2_matmul_fuse_pass", // "reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", // "flatten2_matmul_fuse_pass", //
"map_matmul_to_mul_pass", //
"map_matmul_v2_to_mul_pass", // "map_matmul_v2_to_mul_pass", //
"map_matmul_v2_to_matmul_pass", //
"map_matmul_to_mul_pass", //
"fc_fuse_pass", // "fc_fuse_pass", //
"fc_elementwise_layernorm_fuse_pass", // "fc_elementwise_layernorm_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
...@@ -195,15 +197,16 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { ...@@ -195,15 +197,16 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
// "embedding_fc_lstm_fuse_pass", // // "embedding_fc_lstm_fuse_pass", //
// TODO(wilber): fix correctness problem. // TODO(wilber): fix correctness problem.
// "fc_lstm_fuse_pass", // // "fc_lstm_fuse_pass", //
"mul_lstm_fuse_pass", // "mul_lstm_fuse_pass", //
"fc_gru_fuse_pass", // "fc_gru_fuse_pass", //
"mul_gru_fuse_pass", // "mul_gru_fuse_pass", //
"seq_concat_fc_fuse_pass", // "seq_concat_fc_fuse_pass", //
"squeeze2_matmul_fuse_pass", // "squeeze2_matmul_fuse_pass", //
"reshape2_matmul_fuse_pass", // "reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", // "flatten2_matmul_fuse_pass", //
"map_matmul_v2_to_mul_pass", //
// "map_matmul_v2_to_matmul_pass", //
"map_matmul_to_mul_pass", // "map_matmul_to_mul_pass", //
"map_matmul_v2_to_mul_pass", //
"fc_fuse_pass", // "fc_fuse_pass", //
"repeated_fc_relu_fuse_pass", // "repeated_fc_relu_fuse_pass", //
"squared_mat_sub_fuse_pass", // "squared_mat_sub_fuse_pass", //
......
...@@ -339,6 +339,26 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -339,6 +339,26 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
"the pass."; "the pass.";
return false; return false;
} }
// not support broadcast
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
auto* y_var_desc = block->FindVar(desc.Input("Y")[0]);
const auto x_shape = x_var_desc->GetShape();
const auto y_shape = y_var_desc->GetShape();
if (x_shape.size() != y_shape.size()) {
VLOG(3)
<< "matmul op not support broadcast, please check inputs'shape. ";
return false;
}
uint64_t dims = 2;
for (size_t i = 0; i < x_shape.size() - dims; ++i) {
if (x_shape[i] != y_shape[i] && (x_shape[i] == 1 || y_shape[i] == 1)) {
VLOG(3) << "matmul op not support broadcast, please check "
"inputs'shape[i]. ";
return false;
}
}
for (auto& param_name : desc.Inputs()) { for (auto& param_name : desc.Inputs()) {
for (auto& var_name : param_name.second) { for (auto& var_name : param_name.second) {
auto* var_desc = block->FindVar(var_name); auto* var_desc = block->FindVar(var_name);
...@@ -1228,6 +1248,47 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -1228,6 +1248,47 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
} }
if (op_type == "fc") { if (op_type == "fc") {
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
// y'shapes == 2
auto fc_inputs = desc.Inputs();
std::string fc_y = "";
if (fc_inputs.find("Y") != fc_inputs.end()) {
fc_y = "Y";
} else if (fc_inputs.find("W") != fc_inputs.end()) {
fc_y = "W";
} else {
VLOG(3) << " input_y(fc_op) must be Y or W ";
return false;
}
// There is currently no input: Y(weight) more than two dimensions
/*
auto* y_var_desc = block->FindVar(desc.Input(fc_y)[0]);
const auto y_shape = y_var_desc->GetShape();
if (y_shape.size() != 2) {
VLOG(3)
<< " input_y(fc_op)'shapes must be 2, but input_y(fc_op)'shapes = "
<< y_shape.size();
return false;
}
// y_num_col_dims ==1
if (desc.HasAttr("y_num_col_dims")) {
int y_num_col_dims =
BOOST_GET_CONST(int, desc.GetAttr("y_num_col_dims"));
if (y_num_col_dims != 1) {
VLOG(3) << " fc_op'y_num_col_dims must be 1, but y_num_col_dims = "
<< y_num_col_dims;
return false;
}
}
*/
int x_num_col_dims = int x_num_col_dims =
desc.HasAttr("x_num_col_dims") desc.HasAttr("x_num_col_dims")
? BOOST_GET_CONST(int, desc.GetAttr("x_num_col_dims")) ? BOOST_GET_CONST(int, desc.GetAttr("x_num_col_dims"))
...@@ -1235,8 +1296,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -1235,8 +1296,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
? BOOST_GET_CONST(int, desc.GetAttr("in_num_col_dims")) ? BOOST_GET_CONST(int, desc.GetAttr("in_num_col_dims"))
: 1); : 1);
if (x_num_col_dims < 1) { if (x_num_col_dims < 1) {
VLOG(3) << "converter expects x_num_col_dims >= 1, " VLOG(3) << "fc_op expects x_num_col_dims >= 1, "
"but x_num_col_dims = %d."; "but x_num_col_dims = "
<< x_num_col_dims;
return false; return false;
} }
} }
......
...@@ -36,10 +36,10 @@ TEST(Analyzer_seq_pool1_fuse_statis, fuse_statis) { ...@@ -36,10 +36,10 @@ TEST(Analyzer_seq_pool1_fuse_statis, fuse_statis) {
ASSERT_TRUE(fuse_statis.count("repeated_fc_relu_fuse")); ASSERT_TRUE(fuse_statis.count("repeated_fc_relu_fuse"));
ASSERT_EQ(fuse_statis.at("fc_fuse"), 10); ASSERT_EQ(fuse_statis.at("fc_fuse"), 10);
EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2); EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2);
EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 2); EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 0);
EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2); EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2);
LOG(INFO) << "num_ops: " << num_ops; LOG(INFO) << "num_ops: " << num_ops;
EXPECT_EQ(num_ops, 171); EXPECT_EQ(num_ops, 185);
} }
} // namespace seq_pool1_tester } // namespace seq_pool1_tester
......
...@@ -77,7 +77,7 @@ TEST(tensorrt_tester_LeViT, trt_fp32_bz2) { ...@@ -77,7 +77,7 @@ TEST(tensorrt_tester_LeViT, trt_fp32_bz2) {
FLAGS_modeldir + "/inference.pdiparams"); FLAGS_modeldir + "/inference.pdiparams");
config.EnableUseGpu(100, 0); config.EnableUseGpu(100, 0);
config.EnableTensorRtEngine( config.EnableTensorRtEngine(
1 << 20, 2, 6, paddle_infer::PrecisionType::kFloat32, false, false); 1 << 20, 2, 50, paddle_infer::PrecisionType::kFloat32, false, false);
// get groudtruth by disbale ir // get groudtruth by disbale ir
paddle_infer::services::PredictorPool pred_pool_no_ir(config_no_ir, 1); paddle_infer::services::PredictorPool pred_pool_no_ir(config_no_ir, 1);
SingleThreadPrediction(pred_pool_no_ir.Retrive(0), &my_input_data_map, SingleThreadPrediction(pred_pool_no_ir.Retrive(0), &my_input_data_map,
...@@ -103,7 +103,7 @@ TEST(tensorrt_tester_LeViT, serial_diff_batch_trt_fp32) { ...@@ -103,7 +103,7 @@ TEST(tensorrt_tester_LeViT, serial_diff_batch_trt_fp32) {
config.SetModel(FLAGS_modeldir + "/inference.pdmodel", config.SetModel(FLAGS_modeldir + "/inference.pdmodel",
FLAGS_modeldir + "/inference.pdiparams"); FLAGS_modeldir + "/inference.pdiparams");
config.EnableUseGpu(100, 0); config.EnableUseGpu(100, 0);
config.EnableTensorRtEngine(1 << 20, max_batch_size, 6, config.EnableTensorRtEngine(1 << 20, max_batch_size, 50,
paddle_infer::PrecisionType::kFloat32, false, paddle_infer::PrecisionType::kFloat32, false,
false); false);
paddle_infer::services::PredictorPool pred_pool(config, 1); paddle_infer::services::PredictorPool pred_pool(config, 1);
...@@ -145,7 +145,7 @@ TEST(tensorrt_tester_LeViT, multi_thread4_trt_fp32_bz2) { ...@@ -145,7 +145,7 @@ TEST(tensorrt_tester_LeViT, multi_thread4_trt_fp32_bz2) {
FLAGS_modeldir + "/inference.pdiparams"); FLAGS_modeldir + "/inference.pdiparams");
config.EnableUseGpu(100, 0); config.EnableUseGpu(100, 0);
config.EnableTensorRtEngine( config.EnableTensorRtEngine(
1 << 20, 2, 6, paddle_infer::PrecisionType::kFloat32, false, false); 1 << 20, 2, 50, paddle_infer::PrecisionType::kFloat32, false, false);
// get groudtruth by disbale ir // get groudtruth by disbale ir
paddle_infer::services::PredictorPool pred_pool_no_ir(config_no_ir, 1); paddle_infer::services::PredictorPool pred_pool_no_ir(config_no_ir, 1);
SingleThreadPrediction(pred_pool_no_ir.Retrive(0), &my_input_data_map, SingleThreadPrediction(pred_pool_no_ir.Retrive(0), &my_input_data_map,
......
...@@ -107,5 +107,43 @@ class TensorRTMatMulScaleTest(TensorRTMatMulTest): ...@@ -107,5 +107,43 @@ class TensorRTMatMulScaleTest(TensorRTMatMulTest):
self.alpha = 2.0 self.alpha = 2.0
class TensorRTMatMulBroadcastTest(InferencePassTest):
def setUp(self):
self.set_params()
place = fluid.CPUPlace()
with fluid.program_guard(self.main_program, self.startup_program):
data_x = fluid.data(
name="data_x", shape=[-1, 6, 24], dtype="float32")
data_y = fluid.data(name="data_y", shape=[24, 16], dtype="float32")
matmul_out = fluid.layers.matmul(
x=data_x,
y=data_y,
transpose_x=self.transpose_x,
transpose_y=self.transpose_y,
alpha=self.alpha)
out = fluid.layers.batch_norm(matmul_out, is_test=True)
self.feeds = {
"data_x": np.ones([2, 6, 24]).astype("float32"),
"data_y": np.ones([24, 16]).astype("float32")
}
self.enable_trt = True
self.trt_parameters = TensorRTMatMulBroadcastTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def set_params(self):
self.transpose_x = False
self.transpose_y = False
self.alpha = 1.0
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册