未验证 提交 db7d129e 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle-Inference] rebuild matmul pass: trt and gpu_cpu (#39369)

* rebuild matmul pass: trt and gpu_cpu

* rebuild matmul pass: trt and gpu_cpu

* rebuild matmul pass: trt and gpu_cpu

* rebuild matmul pass: trt and gpu_cpu
上级 772be4f5
......@@ -62,7 +62,6 @@ pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base DEPS string_helper)
pass_library(fc_fuse_pass inference)
pass_library(map_matmul_to_mul_pass inference)
pass_library(attention_lstm_fuse_pass inference)
pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference)
......@@ -98,8 +97,14 @@ pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(layer_norm_fuse_pass inference)
pass_library(add_support_int8_pass inference)
pass_library(matmul_scale_fuse_pass inference)
pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)
if(WITH_TENSORRT)
pass_library(trt_map_matmul_to_mul_pass inference)
endif()
if(WITH_GPU OR WITH_ROCM)
pass_library(cudnn_placement_pass base DEPS placement_pass_base)
pass_library(embedding_eltwise_layernorm_fuse_pass inference)
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/map_matmul_to_mul_pass.h"
#include "paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.h"
#include <cmath>
#include <string>
......@@ -28,7 +28,7 @@ namespace ir {
class Node;
MapMatmul2MulPass::MapMatmul2MulPass() {
GpuCpuMapMatmul2MulPass::GpuCpuMapMatmul2MulPass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
......@@ -68,7 +68,7 @@ MapMatmul2MulPass::MapMatmul2MulPass() {
.End();
}
MapMatmulV2ToMulPass::MapMatmulV2ToMulPass() {
GpuCpuMapMatmulV2ToMulPass::GpuCpuMapMatmulV2ToMulPass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
......@@ -104,7 +104,7 @@ MapMatmulV2ToMulPass::MapMatmulV2ToMulPass() {
.End();
}
MapMatmulV2ToMatmulPass::MapMatmulV2ToMatmulPass() {
GpuCpuMapMatmulV2ToMatmulPass::GpuCpuMapMatmulV2ToMatmulPass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
......@@ -143,7 +143,7 @@ MapMatmulV2ToMatmulPass::MapMatmulV2ToMatmulPass() {
.End();
}
Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
GpuCpuFlatten2MatmulFusePass::GpuCpuFlatten2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
......@@ -197,7 +197,7 @@ Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
.End();
}
Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() {
GpuCpuSqueeze2MatmulFusePass::GpuCpuSqueeze2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
......@@ -251,10 +251,10 @@ Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() {
.End();
}
void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
void GpuCpuMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "map_matmul_to_mul_pass";
std::string name_scope = "gpu_cpu_map_matmul_to_mul_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
......@@ -264,7 +264,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "map matmul to mul";
VLOG(4) << "gpu_cpu map matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern);
......@@ -286,7 +286,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "MapMatmul2MulPass in op compat failed.";
LOG(WARNING) << "GpuCpuMapMatmul2MulPass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
......@@ -311,7 +311,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmul2MulPass in out mul op compat failed.";
LOG(WARNING) << "GpuCpuMapMatmul2MulPass in out mul op compat failed.";
return;
}
}
......@@ -321,10 +321,10 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}
void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
void GpuCpuMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "map_matmul_v2_to_mul_pass";
std::string name_scope = "gpu_cpu_map_matmul_v2_to_mul_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
......@@ -335,7 +335,7 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(3) << "map matmul_v2 to mul";
VLOG(3) << "gpu_cpu map matmul_v2 to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
matmul_v2_weight_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
......@@ -360,7 +360,7 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "MapMatmulV2ToMulPass in op compat failed.";
LOG(WARNING) << "GpuCpuMapMatmulV2ToMulPass in op compat failed.";
return;
}
OpDesc desc(matmul_v2_op->Op()->Block());
......@@ -386,7 +386,8 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmulV2ToMulPass in out mul op compat failed.";
LOG(WARNING)
<< "GpuCpuMapMatmulV2ToMulPass in out mul op compat failed.";
return;
}
}
......@@ -396,10 +397,10 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}
void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
void GpuCpuMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "map_matmul_v2_to_matmul_pass";
std::string name_scope = "gpu_cpu_map_matmul_v2_to_matmul_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
......@@ -409,7 +410,7 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "map matmul_v2 to matmul";
VLOG(4) << "gpu_cpu map matmul_v2 to matmul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
......@@ -417,7 +418,7 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op, matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_out, matmul_v2_out, matmul_v2_pattern);
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "MapMatmulV2ToMatmulPass in op compat failed.";
LOG(WARNING) << "GpuCpuMapMatmulV2ToMatmulPass in op compat failed.";
return;
}
......@@ -463,7 +464,8 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmulV2ToMatmulPass in out matmul op compat failed.";
LOG(WARNING)
<< "GpuCpuMapMatmulV2ToMatmulPass in out matmul op compat failed.";
return;
}
};
......@@ -472,10 +474,10 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}
void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
void GpuCpuSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "squeeze2_matmul_fuse_pass";
std::string name_scope = "gpu_cpu_squeeze2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
......@@ -485,7 +487,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse squeeze2+matmul to mul";
VLOG(4) << "gpu_cpu fuse squeeze2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
......@@ -518,7 +520,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Squeeze2MatmulFusePass in op compat failed.";
LOG(WARNING) << "GpuCpuSqueeze2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
......@@ -542,7 +544,8 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
GraphSafeRemoveNodes(graph, {squeeze2_op, matmul_in_x, matmul_op});
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "Squeeze2MatmulFusePass in out mul op compat failed.";
LOG(WARNING)
<< "GpuCpuSqueeze2MatmulFusePass in out mul op compat failed.";
return;
}
}
......@@ -552,7 +555,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}
Reshape2MatmulFusePass::Reshape2MatmulFusePass() {
GpuCpuReshape2MatmulFusePass::GpuCpuReshape2MatmulFusePass() {
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
......@@ -614,10 +617,10 @@ Reshape2MatmulFusePass::Reshape2MatmulFusePass() {
.End();
}
void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
void GpuCpuReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "reshape2_matmul_fuse_pass";
std::string name_scope = "gpu_cpu_reshape2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
......@@ -627,7 +630,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse reshape2+matmul to mul";
VLOG(4) << "gpu_cpu fuse reshape2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
......@@ -662,7 +665,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Reshape2MatmulFusePass in op compat failed.";
LOG(WARNING) << "GpuCpuReshape2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
......@@ -680,7 +683,8 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
matmul_op->Op()->GetAttr("out_threshold"));
}
if (!IsCompat(desc)) {
LOG(WARNING) << "Reshape2MatmulFusePass in out mul op compat failed.";
LOG(WARNING)
<< "GpuCpuReshape2MatmulFusePass in out mul op compat failed.";
return;
}
auto mul_node = g->CreateOpNode(&desc);
......@@ -696,10 +700,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}
void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
void GpuCpuFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "flatten2_matmul_fuse_pass";
std::string name_scope = "gpu_cpu_flatten2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
......@@ -709,7 +713,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse flatten2+matmul to mul";
VLOG(4) << "gpu_cpu fuse flatten2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
......@@ -749,7 +753,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
if (pattern_found) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Flatten2MatmulFusePass in op compat failed.";
LOG(WARNING) << "GpuCpuFlatten2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
......@@ -774,7 +778,8 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "Flatten2MatmulFusePass in out mul op compat failed.";
LOG(WARNING)
<< "GpuCpuFlatten2MatmulFusePass in out mul op compat failed.";
return;
}
}
......@@ -788,50 +793,51 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
} // namespace framework
} // namespace paddle
REGISTER_PASS(map_matmul_to_mul_pass, paddle::framework::ir::MapMatmul2MulPass);
REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass)
REGISTER_PASS(gpu_cpu_map_matmul_to_mul_pass,
paddle::framework::ir::GpuCpuMapMatmul2MulPass);
REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_to_mul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("mul", 0));
REGISTER_PASS(map_matmul_v2_to_mul_pass,
paddle::framework::ir::MapMatmulV2ToMulPass);
REGISTER_PASS_CAPABILITY(map_matmul_v2_to_mul_pass)
REGISTER_PASS(gpu_cpu_map_matmul_v2_to_mul_pass,
paddle::framework::ir::GpuCpuMapMatmulV2ToMulPass);
REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_v2_to_mul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("mul", 0));
REGISTER_PASS(map_matmul_v2_to_matmul_pass,
paddle::framework::ir::MapMatmulV2ToMatmulPass);
REGISTER_PASS_CAPABILITY(map_matmul_v2_to_matmul_pass)
REGISTER_PASS(gpu_cpu_map_matmul_v2_to_matmul_pass,
paddle::framework::ir::GpuCpuMapMatmulV2ToMatmulPass);
REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_v2_to_matmul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.LE("matmul", 1));
REGISTER_PASS(squeeze2_matmul_fuse_pass,
paddle::framework::ir::Squeeze2MatmulFusePass);
REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass)
REGISTER_PASS(gpu_cpu_squeeze2_matmul_fuse_pass,
paddle::framework::ir::GpuCpuSqueeze2MatmulFusePass);
REGISTER_PASS_CAPABILITY(gpu_cpu_squeeze2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("squeeze2", 0)
.EQ("mul", 0));
REGISTER_PASS(reshape2_matmul_fuse_pass,
paddle::framework::ir::Reshape2MatmulFusePass);
REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass)
REGISTER_PASS(gpu_cpu_reshape2_matmul_fuse_pass,
paddle::framework::ir::GpuCpuReshape2MatmulFusePass);
REGISTER_PASS_CAPABILITY(gpu_cpu_reshape2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("reshape2", 0)
.EQ("mul", 0));
REGISTER_PASS(flatten2_matmul_fuse_pass,
paddle::framework::ir::Flatten2MatmulFusePass);
REGISTER_PASS_CAPABILITY(flatten2_matmul_fuse_pass)
REGISTER_PASS(gpu_cpu_flatten2_matmul_fuse_pass,
paddle::framework::ir::GpuCpuFlatten2MatmulFusePass);
REGISTER_PASS_CAPABILITY(gpu_cpu_flatten2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
......
......@@ -37,22 +37,22 @@ namespace ir {
*/
class Graph;
class MapMatmul2MulPass : public FusePassBase {
class GpuCpuMapMatmul2MulPass : public FusePassBase {
public:
MapMatmul2MulPass();
virtual ~MapMatmul2MulPass() {}
GpuCpuMapMatmul2MulPass();
virtual ~GpuCpuMapMatmul2MulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Map matmul_v2 to mul, the same as MapMatmul2MulPass.
* Map matmul_v2 to mul, the same as GpuCpuMapMatmul2MulPass.
*/
class MapMatmulV2ToMulPass : public FusePassBase {
class GpuCpuMapMatmulV2ToMulPass : public FusePassBase {
public:
MapMatmulV2ToMulPass();
virtual ~MapMatmulV2ToMulPass() {}
GpuCpuMapMatmulV2ToMulPass();
virtual ~GpuCpuMapMatmulV2ToMulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
......@@ -61,10 +61,10 @@ class MapMatmulV2ToMulPass : public FusePassBase {
/*
* Map matmul_v2 to matmul, not supoort broadcast.
*/
class MapMatmulV2ToMatmulPass : public FusePassBase {
class GpuCpuMapMatmulV2ToMatmulPass : public FusePassBase {
public:
MapMatmulV2ToMatmulPass();
virtual ~MapMatmulV2ToMatmulPass() {}
GpuCpuMapMatmulV2ToMatmulPass();
virtual ~GpuCpuMapMatmulV2ToMatmulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
......@@ -89,10 +89,10 @@ class MapMatmulV2ToMatmulPass : public FusePassBase {
* the above passes to reduce the impact on other models.
*/
class Squeeze2MatmulFusePass : public FusePassBase {
class GpuCpuSqueeze2MatmulFusePass : public FusePassBase {
public:
Squeeze2MatmulFusePass();
virtual ~Squeeze2MatmulFusePass() {}
GpuCpuSqueeze2MatmulFusePass();
virtual ~GpuCpuSqueeze2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
......@@ -119,19 +119,19 @@ class Squeeze2MatmulFusePass : public FusePassBase {
* the above passes to reduce the impact on other models.
*/
class Reshape2MatmulFusePass : public FusePassBase {
class GpuCpuReshape2MatmulFusePass : public FusePassBase {
public:
Reshape2MatmulFusePass();
virtual ~Reshape2MatmulFusePass() {}
GpuCpuReshape2MatmulFusePass();
virtual ~GpuCpuReshape2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
class Flatten2MatmulFusePass : public FusePassBase {
class GpuCpuFlatten2MatmulFusePass : public FusePassBase {
public:
Flatten2MatmulFusePass();
virtual ~Flatten2MatmulFusePass() {}
GpuCpuFlatten2MatmulFusePass();
virtual ~GpuCpuFlatten2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.h"
#include <cmath>
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
TrtMapMatmul2MulPass::TrtMapMatmul2MulPass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}
TrtMapMatmulV2ToMulPass::TrtMapMatmulV2ToMulPass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsBoolEQ(false)
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}
TrtMapMatmulV2ToMatmulPass::TrtMapMatmulV2ToMatmulPass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumEQ(1.0f)
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("transpose_X")
.IsType<bool>()
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
}
TrtFlatten2MatmulFusePass::TrtFlatten2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("flatten2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}
TrtSqueeze2MatmulFusePass::TrtSqueeze2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("squeeze2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axes")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumEQ(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}
void TrtMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "trt_map_matmul_to_mul_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::Matmul matmul_pattern(gpd.mutable_pattern(), name_scope);
matmul_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "trt map matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);
bool flag = true;
bool transpose_X =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X"));
float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha"));
flag = flag && !transpose_X && std::abs(alpha - 1.0) < 1e-5;
std::vector<int64_t> x_shape = matmul_in_x->Var()->GetShape();
std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
flag = flag && x_rank >= 2 && y_rank == 2;
if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "TrtMapMatmul2MulPass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
desc.SetType("mul");
desc.SetInput("X", {matmul_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", static_cast<int>(x_rank - 1));
desc.SetAttr("y_num_col_dims", 1);
desc.SetAttr("transpose_Y", matmul_op->Op()->GetAttr("transpose_Y"));
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
}
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {matmul_op});
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "TrtMapMatmul2MulPass in out mul op compat failed.";
return;
}
}
};
gpd(graph, handler);
AddStatis(found_count);
}
void TrtMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "trt_map_matmul_v2_to_mul_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::MatmulV2Weight matmul_v2_weight_pattern(gpd.mutable_pattern(),
name_scope);
matmul_v2_weight_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(3) << "trt map matmul_v2 to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
matmul_v2_weight_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
matmul_v2_weight_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op,
matmul_v2_weight_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_out, matmul_v2_out,
matmul_v2_weight_pattern);
bool flag = true;
bool trans_x =
BOOST_GET_CONST(bool, matmul_v2_op->Op()->GetAttr("trans_x"));
flag = flag && !trans_x;
std::vector<int64_t> x_shape = matmul_v2_in_x->Var()->GetShape();
std::vector<int64_t> y_shape = matmul_v2_in_y->Var()->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
flag = flag && x_rank >= 2 && y_rank == 2;
if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "TrtMapMatmulV2ToMulPass in op compat failed.";
return;
}
OpDesc desc(matmul_v2_op->Op()->Block());
desc.SetType("mul");
desc.SetInput("X", {matmul_v2_in_x->Name()});
desc.SetInput("Y", {matmul_v2_in_y->Name()});
desc.SetOutput("Out", {matmul_v2_out->Name()});
desc.SetAttr("x_num_col_dims", static_cast<int>(x_rank - 1));
desc.SetAttr("y_num_col_dims", 1);
desc.SetAttr("transpose_Y", matmul_v2_op->Op()->GetAttr("trans_y"));
if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale",
matmul_v2_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("out_threshold",
matmul_v2_op->Op()->GetAttr("out_threshold"));
}
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_v2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_v2_in_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_v2_out);
GraphSafeRemoveNodes(graph, {matmul_v2_op});
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "TrtMapMatmulV2ToMulPass in out mul op compat failed.";
return;
}
}
};
gpd(graph, handler);
AddStatis(found_count);
}
void TrtMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "trt_map_matmul_v2_to_matmul_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::MatmulV2 matmul_v2_pattern(gpd.mutable_pattern(), name_scope);
matmul_v2_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "trt map matmul_v2 to matmul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op, matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_out, matmul_v2_out, matmul_v2_pattern);
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "TrtMapMatmulV2ToMatmulPass in op compat failed.";
return;
}
std::vector<int64_t> x_shape = matmul_v2_in_x->Var()->GetShape();
std::vector<int64_t> y_shape = matmul_v2_in_y->Var()->GetShape();
if (x_shape.size() != y_shape.size()) {
LOG(WARNING)
<< "matmul op not support broadcast, please check inputs'shape. ";
return;
}
uint64_t dims = 2;
for (size_t i = 0; i < x_shape.size() - dims; ++i) {
if (x_shape[i] != y_shape[i] && (x_shape[i] == 1 || y_shape[i] == 1)) {
LOG(WARNING) << "matmul op not support broadcast, please check "
"inputs'shape[i]. ";
return;
}
}
OpDesc desc(matmul_v2_op->Op()->Block());
desc.SetType("matmul");
desc.SetInput("X", {matmul_v2_in_x->Name()});
desc.SetInput("Y", {matmul_v2_in_y->Name()});
desc.SetOutput("Out", {matmul_v2_out->Name()});
desc.SetAttr("transpose_X", matmul_v2_op->Op()->GetAttr("trans_x"));
desc.SetAttr("transpose_Y", matmul_v2_op->Op()->GetAttr("trans_y"));
desc.SetAttr("alpha", 1.0f);
if (matmul_v2_op->Op()->HasAttr("use_mkldnn")) {
desc.SetAttr("use_mkldnn", matmul_v2_op->Op()->GetAttr("use_mkldnn"));
}
if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("out_threshold",
matmul_v2_op->Op()->GetAttr("out_threshold"));
}
auto matmul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_v2_in_x, matmul_node);
IR_NODE_LINK_TO(matmul_v2_in_y, matmul_node);
IR_NODE_LINK_TO(matmul_node, matmul_v2_out);
GraphSafeRemoveNodes(graph, {matmul_v2_op});
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING)
<< "TrtMapMatmulV2ToMatmulPass in out matmul op compat failed.";
return;
}
};
gpd(graph, handler);
AddStatis(found_count);
}
void TrtSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "trt_squeeze2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::Squeeze2Matmul fuse_pattern(gpd.mutable_pattern(), name_scope);
fuse_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "trt fuse squeeze2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, fuse_pattern);
bool flag = true;
size_t squeeze2_in_x_rank = (squeeze2_in_x->Var()->GetShape()).size();
std::vector<int> squeeze2_op_axes =
BOOST_GET_CONST(std::vector<int>, squeeze2_op->Op()->GetAttr("axes"));
flag = flag && squeeze2_in_x_rank == 4 &&
squeeze2_op_axes == std::vector<int>{2, 3} &&
(matmul_in_x->outputs).size() == 1;
bool transpose_X =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X"));
bool transpose_Y =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y"));
float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha"));
size_t matmul_in_x_rank = (matmul_in_x->Var()->GetShape()).size();
size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size();
flag = flag && !transpose_X && !transpose_Y &&
std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 &&
matmul_in_y_rank == 2;
std::vector<Node*>& next_ops = matmul_out->outputs;
flag = flag && next_ops.size() == 1 &&
next_ops[0]->Name() == "elementwise_add";
if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "TrtSqueeze2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
desc.SetType("mul");
desc.SetInput("X", {squeeze2_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", 1);
desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
}
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(squeeze2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {squeeze2_op, matmul_in_x, matmul_op});
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING)
<< "TrtSqueeze2MatmulFusePass in out mul op compat failed.";
return;
}
}
};
gpd(graph, handler);
AddStatis(found_count);
}
TrtReshape2MatmulFusePass::TrtReshape2MatmulFusePass() {
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape") // ints
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGT(0.99999f)
.IsNumLT(1.00001f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumEQ(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}
void TrtReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "trt_reshape2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::Reshape2Matmul fuse_pattern(gpd.mutable_pattern(), name_scope);
fuse_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "trt fuse reshape2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, fuse_pattern);
bool flag = true;
size_t reshape2_in_nums = reshape2_op->inputs.size();
auto reshape2_in_x_shape = reshape2_in_x->Var()->GetShape();
size_t reshape2_in_x_rank = reshape2_in_x_shape.size();
std::vector<int> reshape2_op_shape =
BOOST_GET_CONST(std::vector<int>, reshape2_op->Op()->GetAttr("shape"));
flag = flag && reshape2_in_nums == 1 && reshape2_in_x_rank == 4 &&
reshape2_in_x_shape[2] == 1 && reshape2_in_x_shape[3] == 1 &&
reshape2_op_shape.size() == 2 && (matmul_in_x->outputs).size() == 1;
bool transpose_X =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X"));
bool transpose_Y =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y"));
float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha"));
size_t matmul_in_x_rank = (matmul_in_x->Var()->GetShape()).size();
size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size();
flag = flag && !transpose_X && !transpose_Y &&
std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 &&
matmul_in_y_rank == 2;
std::vector<Node*>& next_ops = matmul_out->outputs;
flag = flag && next_ops.size() == 1 &&
next_ops[0]->Name() == "elementwise_add";
if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "TrtReshape2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
desc.SetType("mul");
desc.SetInput("X", {reshape2_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", 1);
desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
}
if (!IsCompat(desc)) {
LOG(WARNING)
<< "TrtReshape2MatmulFusePass in out mul op compat failed.";
return;
}
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(reshape2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {reshape2_op, matmul_in_x, matmul_op});
++found_count;
}
};
gpd(graph, handler);
AddStatis(found_count);
}
void TrtFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "trt_flatten2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::Flatten2Matmul fuse_pattern(gpd.mutable_pattern(), name_scope);
fuse_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "trt fuse flatten2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, fuse_pattern);
bool pattern_found = true;
size_t flatten2_in_nums = flatten2_op->inputs.size();
auto flatten2_in_x_shape = flatten2_in_x->Var()->GetShape();
size_t flatten2_in_x_rank = flatten2_in_x_shape.size();
int flatten2_axis =
BOOST_GET_CONST(int, flatten2_op->Op()->GetAttr("axis"));
// only convert matmul to mul when the flatten2 has a single input
// and the rank of input is 4 and the size of the output of matmul
// is 1.
pattern_found = pattern_found && flatten2_in_nums == 1 &&
flatten2_in_x_rank == 4 &&
(matmul_in_x->outputs).size() == 1;
bool transpose_X =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X"));
bool transpose_Y =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y"));
float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha"));
size_t matmul_in_x_rank = (matmul_in_x->Var()->GetShape()).size();
size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size();
pattern_found = pattern_found && !transpose_X && !transpose_Y &&
std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 &&
matmul_in_y_rank == 2;
std::vector<Node*>& next_ops = matmul_out->outputs;
// we further require the matmul op is followed by one elementwise
// add op.
pattern_found = pattern_found && next_ops.size() == 1 &&
next_ops[0]->Name() == "elementwise_add";
if (pattern_found) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "TrtFlatten2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
desc.SetType("mul");
desc.SetInput("X", {flatten2_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", flatten2_axis);
desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
}
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(flatten2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {flatten2_op, matmul_in_x, matmul_op});
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING)
<< "TrtFlatten2MatmulFusePass in out mul op compat failed.";
return;
}
}
};
gpd(graph, handler);
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(trt_map_matmul_to_mul_pass,
paddle::framework::ir::TrtMapMatmul2MulPass);
REGISTER_PASS_CAPABILITY(trt_map_matmul_to_mul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("mul", 0));
REGISTER_PASS(trt_map_matmul_v2_to_mul_pass,
paddle::framework::ir::TrtMapMatmulV2ToMulPass);
REGISTER_PASS_CAPABILITY(trt_map_matmul_v2_to_mul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("mul", 0));
REGISTER_PASS(trt_map_matmul_v2_to_matmul_pass,
paddle::framework::ir::TrtMapMatmulV2ToMatmulPass);
REGISTER_PASS_CAPABILITY(trt_map_matmul_v2_to_matmul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.LE("matmul", 1));
REGISTER_PASS(trt_squeeze2_matmul_fuse_pass,
paddle::framework::ir::TrtSqueeze2MatmulFusePass);
REGISTER_PASS_CAPABILITY(trt_squeeze2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("squeeze2", 0)
.EQ("mul", 0));
REGISTER_PASS(trt_reshape2_matmul_fuse_pass,
paddle::framework::ir::TrtReshape2MatmulFusePass);
REGISTER_PASS_CAPABILITY(trt_reshape2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("reshape2", 0)
.EQ("mul", 0));
REGISTER_PASS(trt_flatten2_matmul_fuse_pass,
paddle::framework::ir::TrtFlatten2MatmulFusePass);
REGISTER_PASS_CAPABILITY(trt_flatten2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("flatten2", 0)
.EQ("mul", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
class TrtMapMatmul2MulPass : public FusePassBase {
public:
TrtMapMatmul2MulPass();
virtual ~TrtMapMatmul2MulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Map matmul_v2 to mul, the same as TrtMapMatmul2MulPass.
*/
class TrtMapMatmulV2ToMulPass : public FusePassBase {
public:
TrtMapMatmulV2ToMulPass();
virtual ~TrtMapMatmulV2ToMulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Map matmul_v2 to matmul, not supoort broadcast.
*/
class TrtMapMatmulV2ToMatmulPass : public FusePassBase {
public:
TrtMapMatmulV2ToMatmulPass();
virtual ~TrtMapMatmulV2ToMatmulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Fuse squeeze2+matmul to mul, so the optimization can use fc_fuse_pass.
* The squeeze2 op must satisfy the following conditions:
* 1. the rank of input X is 4
* 2. the axis attr is [2, 3]
* 3. the next op is only matmul
*
* The matmul op must satisfy the following conditions:
* 1. the transpose_X and transpose_Y attrs are false
* 2. the alpha attr is 1.0
* 3. the rank of input X and Y is 2
* 4. the next op of matmul is only elementwise_add
*
* Notice:
* the rank of input activation is obtained from var_desc,
* it maybe change in runtime. Therefore, the pass considers
* the above passes to reduce the impact on other models.
*/
class TrtSqueeze2MatmulFusePass : public FusePassBase {
public:
TrtSqueeze2MatmulFusePass();
virtual ~TrtSqueeze2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Fuse reshape2+matmul to mul, so the optimization can use fc_fuse_pass.
* The reshape2 op must satisfy the following conditions:
* 1. reshape2 has one input node, which means it don't
* have Shape or ShapeTensor input
* 2. the rank of input X is 4 and the last two dims of input X is 1
* 3. the rank of shape attr is 2
* 4. the next op is only matmul
*
* The matmul op must satisfy the following conditions:
* 1. the transpose_X and transpose_Y attrs are false
* 2. the alpha attr is 1.0
* 3. the rank of input X and Y is 2
* 4. the next op of matmul is only elementwise_add
*
* Notice:
* the shape and rank of input activation is obtained from var_desc,
* they maybe change in runtime. Therefore, the pass considers
* the above passes to reduce the impact on other models.
*/
class TrtReshape2MatmulFusePass : public FusePassBase {
public:
TrtReshape2MatmulFusePass();
virtual ~TrtReshape2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
class TrtFlatten2MatmulFusePass : public FusePassBase {
public:
TrtFlatten2MatmulFusePass();
virtual ~TrtFlatten2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -90,12 +90,12 @@ const std::vector<std::string> kTRTSubgraphPasses({
"skip_layernorm_fuse_pass", //
"conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", //
"squeeze2_matmul_fuse_pass", //
"reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", //
"map_matmul_v2_to_mul_pass", //
"map_matmul_v2_to_matmul_pass", //
"map_matmul_to_mul_pass", //
"trt_squeeze2_matmul_fuse_pass", //
"trt_reshape2_matmul_fuse_pass", //
"trt_flatten2_matmul_fuse_pass", //
"trt_map_matmul_v2_to_mul_pass", //
"trt_map_matmul_v2_to_matmul_pass", //
"trt_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
"add_support_int8_pass",
......@@ -140,12 +140,12 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_eltwiseadd_bn_fuse_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", //
"squeeze2_matmul_fuse_pass", //
"reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", //
"map_matmul_v2_to_mul_pass", //
"map_matmul_v2_to_matmul_pass", //
"map_matmul_to_mul_pass", //
"gpu_cpu_squeeze2_matmul_fuse_pass", //
"gpu_cpu_reshape2_matmul_fuse_pass", //
"gpu_cpu_flatten2_matmul_fuse_pass", //
"gpu_cpu_map_matmul_v2_to_mul_pass", //
"gpu_cpu_map_matmul_v2_to_matmul_pass", //
"gpu_cpu_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"fc_elementwise_layernorm_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
......@@ -202,14 +202,14 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"fc_gru_fuse_pass", //
"mul_gru_fuse_pass", //
"seq_concat_fc_fuse_pass", //
"squeeze2_matmul_fuse_pass", //
"reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", //
"gpu_cpu_squeeze2_matmul_fuse_pass", //
"gpu_cpu_reshape2_matmul_fuse_pass", //
"gpu_cpu_flatten2_matmul_fuse_pass", //
"matmul_v2_scale_fuse_pass", //
"map_matmul_v2_to_mul_pass", //
"map_matmul_v2_to_matmul_pass", //
"gpu_cpu_map_matmul_v2_to_mul_pass", //
"gpu_cpu_map_matmul_v2_to_matmul_pass", //
"matmul_scale_fuse_pass", //
"map_matmul_to_mul_pass", //
"gpu_cpu_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"repeated_fc_relu_fuse_pass", //
"squared_mat_sub_fuse_pass", //
......
......@@ -67,13 +67,7 @@ class FcOpConverter : public OpConverter {
nvinfer1::Dims x_dim, int x_num_col_dims) {
// add shuffle after fc
nvinfer1::Dims reshape_after_fc_dim;
if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 1) {
// If use tensorrt'oss, the x_dim and x_num_col_dims need change
reshape_after_fc_dim.nbDims = 4;
} else {
reshape_after_fc_dim.nbDims = x_num_col_dims + 1;
}
for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) {
reshape_after_fc_dim.d[i] = 0;
}
......@@ -141,7 +135,6 @@ class FcOpConverter : public OpConverter {
"The fc's weight should be a matrix with 2 dims, but "
"it's %d-dimensional.",
Y_t->dims().size())); // a matrix
size_t n_output = Y_t->dims()[1];
int m = Y_t->dims()[0];
int n = Y_t->dims()[1];
auto tranpose_weight = [](const float* src, float* dst, int m, int n) {
......@@ -175,9 +168,10 @@ class FcOpConverter : public OpConverter {
fc_layer_int8->getOutput(0), x_dim, x_num_col_dims);
if (activation_type == "relu") {
fc_after_reshape_int8->setName(
("fc_op_int8_reshape_after_fc: Shuffle (Output: " + output_name +
")")
("int8_reshape_after_fc: Shuffle (Output: " + output_name + ")")
.c_str());
engine_->SetTensorDynamicRange(fc_after_reshape_int8->getOutput(0),
out_scale);
nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *(fc_after_reshape_int8->getOutput(0)),
nvinfer1::ActivationType::kRELU);
......@@ -200,8 +194,7 @@ class FcOpConverter : public OpConverter {
fc_layer_float->getOutput(0), x_dim, x_num_col_dims);
if (activation_type == "relu") {
fc_after_reshape_float->setName(
("fc_op_float_reshape_after_fc: Shuffle (Output: " + output_name +
")")
("float_reshape_after_fc: Shuffle (Output: " + output_name + ")")
.c_str());
nvinfer1::IActivationLayer* relu_layer_float = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *(fc_after_reshape_float->getOutput(0)),
......@@ -215,14 +208,28 @@ class FcOpConverter : public OpConverter {
}
};
bool transpose_y = false;
if (op_desc.HasAttr("transpose_Y")) {
transpose_y = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_Y"));
}
int weight_w, weight_h;
if (!transpose_y) {
std::vector<float> weight_data_tmp;
weight_data_tmp.reserve(Y_t->numel());
memcpy(weight_data_tmp.data(), weight_data, Y_t->numel() * sizeof(float));
tranpose_weight(weight_data_tmp.data(), weight_data, m, n);
weight_w = n;
weight_h = m;
} else {
weight_w = m;
weight_h = n;
}
size_t n_output = weight_w;
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<size_t>(Y_t->numel())};
weight.dims.assign({n, m});
weight.dims.assign({weight_w, weight_h});
float* bias_data = nullptr;
int bias_num = 0;
if (with_bias) {
......@@ -240,11 +247,57 @@ class FcOpConverter : public OpConverter {
if (!engine_->with_dynamic_shape()) {
x_num_col_dims--;
}
// If use tensorrt'oss, the x_dim and x_num_col_dims need change
// If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can
// not add Shuffle layer in ernie's multihead.
if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 2) {
x_num_col_dims = 1;
x_dim.d[3] == 1 && x_num_col_dims == 2) {
if (enable_int8) {
// add conv1x1 layer
nvinfer1::DimsHW nv_ksize(1, 1);
auto* fc_layer_int8 =
TRT_ENGINE_ADD_LAYER(engine_, Convolution, *X, n_output, nv_ksize,
weight.get(), bias.get());
if (activation_type == "relu") {
fc_layer_int8->setName(
("ernie_fc_op_int8: Convolution (Output: " + output_name + ")")
.c_str());
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("out_threshold"), true,
platform::errors::InvalidArgument(
"must have out threshold in fc layers in int8 mode"));
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0),
out_scale);
nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *(fc_layer_int8->getOutput(0)),
nvinfer1::ActivationType::kRELU);
RreplenishLayerAndOutput(relu_layer_int8, "relu_after_ernie_fc_int8",
{output_name}, test_mode);
} else {
RreplenishLayerAndOutput(fc_layer_int8,
"ernie_fc_op_int8: Convolution",
{output_name}, test_mode);
}
} else {
// add fc layer
auto* fc_layer_float = TRT_ENGINE_ADD_LAYER(
engine_, FullyConnected, *X, n_output, weight.get(), bias.get());
if (activation_type == "relu") {
fc_layer_float->setName(
("ernie_fc_op_float: (Output: " + output_name + ")").c_str());
nvinfer1::IActivationLayer* relu_layer_float = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *(fc_layer_float->getOutput(0)),
nvinfer1::ActivationType::kRELU);
RreplenishLayerAndOutput(relu_layer_float,
"relu_after_ernie_fc_float", {output_name},
test_mode);
} else {
RreplenishLayerAndOutput(fc_layer_float, "ernie_fc_op_float",
{output_name}, test_mode);
}
}
} else { // need reshape input before and after fc
PADDLE_ENFORCE_GT(
x_dim.nbDims, x_num_col_dims,
platform::errors::InvalidArgument(
......@@ -260,6 +313,7 @@ class FcOpConverter : public OpConverter {
}
regist_fc(reshape_itensor, n_output, weight, bias);
}
}
};
} // namespace tensorrt
......
......@@ -410,16 +410,16 @@ class Quant2Int8MkldnnPass(object):
graph = self._apply_pass(graph, 'multi_gru_fuse_pass')
graph = self._apply_pass(graph, 'multi_gru_seq_fuse_pass')
graph = self._apply_pass(graph, 'seq_concat_fc_fuse_pass')
graph = self._apply_pass(graph, 'squeeze2_matmul_fuse_pass')
graph = self._apply_pass(graph, 'reshape2_matmul_fuse_pass')
graph = self._apply_pass(graph, 'flatten2_matmul_fuse_pass')
graph = self._apply_pass(graph, 'gpu_cpu_squeeze2_matmul_fuse_pass')
graph = self._apply_pass(graph, 'gpu_cpu_reshape2_matmul_fuse_pass')
graph = self._apply_pass(graph, 'gpu_cpu_flatten2_matmul_fuse_pass')
graph = self._apply_pass(graph, 'matmul_v2_scale_fuse_pass')
graph = self._apply_pass(graph, 'squared_mat_sub_fuse_pass')
graph = self._apply_pass(graph, 'is_test_pass')
graph = self._apply_pass(graph, 'map_matmul_v2_to_mul_pass')
graph = self._apply_pass(graph, 'map_matmul_v2_to_matmul_pass')
graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_v2_to_mul_pass')
graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_v2_to_matmul_pass')
graph = self._apply_pass(graph, 'matmul_scale_fuse_pass')
graph = self._apply_pass(graph, 'map_matmul_to_mul_pass')
graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_to_mul_pass')
graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
graph = self._apply_pass(graph, 'mkldnn_placement_pass',
['mkldnn_enabled_op_types'], [set()])
......
......@@ -174,7 +174,7 @@ class TestFlatten2MatmulFusePass(PassAutoScanTest):
quant=False,
max_examples=50,
max_duration=1000,
passes=["flatten2_matmul_fuse_pass"])
passes=["gpu_cpu_flatten2_matmul_fuse_pass"])
if __name__ == "__main__":
......
......@@ -116,7 +116,7 @@ class TestMapMatmulToMulPass(PassAutoScanTest):
self.run_and_statis(
quant=False,
max_examples=100,
passes=["map_matmul_to_mul_pass"],
passes=["gpu_cpu_map_matmul_to_mul_pass"],
max_duration=180)
......
......@@ -127,7 +127,7 @@ class TestMapMatmulToMulPass(PassAutoScanTest):
self.run_and_statis(
quant=False,
max_examples=100,
passes=["map_matmul_v2_to_matmul_pass"])
passes=["gpu_cpu_map_matmul_v2_to_matmul_pass"])
if __name__ == "__main__":
......
......@@ -110,8 +110,9 @@ class TestMapMatmulToMulPass(PassAutoScanTest):
def test(self):
self.run_and_statis(
quant=False, max_examples=100,
passes=["map_matmul_v2_to_mul_pass"])
quant=False,
max_examples=100,
passes=["gpu_cpu_map_matmul_v2_to_mul_pass"])
if __name__ == "__main__":
......
......@@ -132,7 +132,7 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest):
return program_config
def sample_predictor_configs(self, program_config):
# map_matmul_v2_to_matmul_pass will affect the type of final fused op
# gpu_cpu_map_matmul_v2_to_matmul_pass will affect the type of final fused op
fused_op = "matmul_v2"
input1_dim1 = program_config.inputs["input_data1"].shape[0]
input2_dim1 = program_config.inputs["input_data2"].shape[0]
......
......@@ -172,7 +172,7 @@ class TestReshape2MatmulFusePass(PassAutoScanTest):
quant=False,
max_examples=50,
max_duration=1000,
passes=["reshape2_matmul_fuse_pass"])
passes=["gpu_cpu_reshape2_matmul_fuse_pass"])
if __name__ == "__main__":
......
......@@ -180,7 +180,7 @@ class TestSqueeze2MatmulFusePass(PassAutoScanTest):
quant=False,
max_examples=50,
max_duration=1000,
passes=["squeeze2_matmul_fuse_pass"])
passes=["gpu_cpu_squeeze2_matmul_fuse_pass"])
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册