Unverified commit db7d129e, authored by Wangzheee, committed by GitHub

[Paddle-Inference] rebuild matmul pass: trt and gpu_cpu (#39369)

* rebuild matmul pass: trt and gpu_cpu

* rebuild matmul pass: trt and gpu_cpu

* rebuild matmul pass: trt and gpu_cpu

* rebuild matmul pass: trt and gpu_cpu
Parent 772be4f5
@@ -62,7 +62,6 @@ pass_library(graph_to_program_pass base)
 pass_library(graph_viz_pass base)
 pass_library(lock_free_optimize_pass base DEPS string_helper)
 pass_library(fc_fuse_pass inference)
-pass_library(map_matmul_to_mul_pass inference)
 pass_library(attention_lstm_fuse_pass inference)
 pass_library(fc_lstm_fuse_pass inference)
 pass_library(embedding_fc_lstm_fuse_pass inference)
@@ -98,8 +97,14 @@ pass_library(unsqueeze2_eltwise_fuse_pass inference)
 pass_library(layer_norm_fuse_pass inference)
 pass_library(add_support_int8_pass inference)
 pass_library(matmul_scale_fuse_pass inference)
+pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
 pass_library(generate_pass DEPS pass_desc_proto)
 target_link_libraries(generate_pass pass_desc_proto)
+if(WITH_TENSORRT)
+  pass_library(trt_map_matmul_to_mul_pass inference)
+endif()
 if(WITH_GPU OR WITH_ROCM)
   pass_library(cudnn_placement_pass base DEPS placement_pass_base)
   pass_library(embedding_eltwise_layernorm_fuse_pass inference)
......
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/framework/ir/map_matmul_to_mul_pass.h"
+#include "paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.h"
 #include <cmath>
 #include <string>
@@ -28,7 +28,7 @@ namespace ir {
 class Node;
-MapMatmul2MulPass::MapMatmul2MulPass() {
+GpuCpuMapMatmul2MulPass::GpuCpuMapMatmul2MulPass() {
   AddOpCompat(OpCompat("matmul"))
       .AddInput("X")
       .IsTensor()
@@ -68,7 +68,7 @@ MapMatmul2MulPass::MapMatmul2MulPass() {
       .End();
 }
-MapMatmulV2ToMulPass::MapMatmulV2ToMulPass() {
+GpuCpuMapMatmulV2ToMulPass::GpuCpuMapMatmulV2ToMulPass() {
   AddOpCompat(OpCompat("matmul_v2"))
       .AddInput("X")
       .IsTensor()
@@ -104,7 +104,7 @@ MapMatmulV2ToMulPass::MapMatmulV2ToMulPass() {
       .End();
 }
-MapMatmulV2ToMatmulPass::MapMatmulV2ToMatmulPass() {
+GpuCpuMapMatmulV2ToMatmulPass::GpuCpuMapMatmulV2ToMatmulPass() {
   AddOpCompat(OpCompat("matmul_v2"))
       .AddInput("X")
       .IsTensor()
@@ -143,7 +143,7 @@ MapMatmulV2ToMatmulPass::MapMatmulV2ToMatmulPass() {
      .End();
 }
-Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
+GpuCpuFlatten2MatmulFusePass::GpuCpuFlatten2MatmulFusePass() {
   AddOpCompat(OpCompat("matmul"))
       .AddInput("X")
       .IsTensor()
@@ -197,7 +197,7 @@ Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
      .End();
 }
-Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() {
+GpuCpuSqueeze2MatmulFusePass::GpuCpuSqueeze2MatmulFusePass() {
   AddOpCompat(OpCompat("matmul"))
       .AddInput("X")
       .IsTensor()
@@ -251,10 +251,10 @@ Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() {
      .End();
 }
-void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  std::string name_scope = "map_matmul_to_mul_pass";
+  std::string name_scope = "gpu_cpu_map_matmul_to_mul_pass";
   FusePassBase::Init(name_scope, graph);
   GraphPatternDetector gpd;
@@ -264,7 +264,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
   int found_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(4) << "map matmul to mul";
+    VLOG(4) << "gpu_cpu map matmul to mul";
     GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern);
@@ -286,7 +286,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
     if (flag) {
       if (!IsCompat(subgraph, g)) {
-        LOG(WARNING) << "MapMatmul2MulPass in op compat failed.";
+        LOG(WARNING) << "GpuCpuMapMatmul2MulPass in op compat failed.";
         return;
       }
       OpDesc desc(matmul_op->Op()->Block());
@@ -311,7 +311,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
       ++found_count;
       if (!IsCompat(desc)) {
-        LOG(WARNING) << "MapMatmul2MulPass in out mul op compat failed.";
+        LOG(WARNING) << "GpuCpuMapMatmul2MulPass in out mul op compat failed.";
         return;
       }
     }
@@ -321,10 +321,10 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
   AddStatis(found_count);
 }
-void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  std::string name_scope = "map_matmul_v2_to_mul_pass";
+  std::string name_scope = "gpu_cpu_map_matmul_v2_to_mul_pass";
   FusePassBase::Init(name_scope, graph);
   GraphPatternDetector gpd;
@@ -335,7 +335,7 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
   int found_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(3) << "map matmul_v2 to mul";
+    VLOG(3) << "gpu_cpu map matmul_v2 to mul";
     GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
                               matmul_v2_weight_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
@@ -360,7 +360,7 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
     if (flag) {
       if (!IsCompat(subgraph, g)) {
-        LOG(WARNING) << "MapMatmulV2ToMulPass in op compat failed.";
+        LOG(WARNING) << "GpuCpuMapMatmulV2ToMulPass in op compat failed.";
         return;
       }
       OpDesc desc(matmul_v2_op->Op()->Block());
@@ -386,7 +386,8 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
       ++found_count;
       if (!IsCompat(desc)) {
-        LOG(WARNING) << "MapMatmulV2ToMulPass in out mul op compat failed.";
+        LOG(WARNING)
+            << "GpuCpuMapMatmulV2ToMulPass in out mul op compat failed.";
         return;
       }
     }
@@ -396,10 +397,10 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
   AddStatis(found_count);
 }
-void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  std::string name_scope = "map_matmul_v2_to_matmul_pass";
+  std::string name_scope = "gpu_cpu_map_matmul_v2_to_matmul_pass";
   FusePassBase::Init(name_scope, graph);
   GraphPatternDetector gpd;
@@ -409,7 +410,7 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
   int found_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(4) << "map matmul_v2 to matmul";
+    VLOG(4) << "gpu_cpu map matmul_v2 to matmul";
     GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
                               matmul_v2_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
@@ -417,7 +418,7 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op, matmul_v2_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_out, matmul_v2_out, matmul_v2_pattern);
     if (!IsCompat(subgraph, g)) {
-      LOG(WARNING) << "MapMatmulV2ToMatmulPass in op compat failed.";
+      LOG(WARNING) << "GpuCpuMapMatmulV2ToMatmulPass in op compat failed.";
       return;
     }
@@ -463,7 +464,8 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
     ++found_count;
     if (!IsCompat(desc)) {
-      LOG(WARNING) << "MapMatmulV2ToMatmulPass in out matmul op compat failed.";
+      LOG(WARNING)
+          << "GpuCpuMapMatmulV2ToMatmulPass in out matmul op compat failed.";
       return;
     }
   };
@@ -472,10 +474,10 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
   AddStatis(found_count);
 }
-void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  std::string name_scope = "squeeze2_matmul_fuse_pass";
+  std::string name_scope = "gpu_cpu_squeeze2_matmul_fuse_pass";
   FusePassBase::Init(name_scope, graph);
   GraphPatternDetector gpd;
@@ -485,7 +487,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   int found_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(4) << "fuse squeeze2+matmul to mul";
+    VLOG(4) << "gpu_cpu fuse squeeze2+matmul to mul";
     GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
@@ -518,7 +520,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
     if (flag) {
       if (!IsCompat(subgraph, g)) {
-        LOG(WARNING) << "Squeeze2MatmulFusePass in op compat failed.";
+        LOG(WARNING) << "GpuCpuSqueeze2MatmulFusePass in op compat failed.";
         return;
       }
       OpDesc desc(matmul_op->Op()->Block());
@@ -542,7 +544,8 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
       GraphSafeRemoveNodes(graph, {squeeze2_op, matmul_in_x, matmul_op});
       ++found_count;
       if (!IsCompat(desc)) {
-        LOG(WARNING) << "Squeeze2MatmulFusePass in out mul op compat failed.";
+        LOG(WARNING)
+            << "GpuCpuSqueeze2MatmulFusePass in out mul op compat failed.";
         return;
       }
     }
@@ -552,7 +555,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   AddStatis(found_count);
 }
-Reshape2MatmulFusePass::Reshape2MatmulFusePass() {
+GpuCpuReshape2MatmulFusePass::GpuCpuReshape2MatmulFusePass() {
   AddOpCompat(OpCompat("reshape2"))
       .AddInput("X")
       .IsTensor()
@@ -614,10 +617,10 @@ Reshape2MatmulFusePass::Reshape2MatmulFusePass() {
       .End();
 }
-void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  std::string name_scope = "reshape2_matmul_fuse_pass";
+  std::string name_scope = "gpu_cpu_reshape2_matmul_fuse_pass";
   FusePassBase::Init(name_scope, graph);
   GraphPatternDetector gpd;
@@ -627,7 +630,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   int found_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(4) << "fuse reshape2+matmul to mul";
+    VLOG(4) << "gpu_cpu fuse reshape2+matmul to mul";
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
@@ -662,7 +665,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
     if (flag) {
       if (!IsCompat(subgraph, g)) {
-        LOG(WARNING) << "Reshape2MatmulFusePass in op compat failed.";
+        LOG(WARNING) << "GpuCpuReshape2MatmulFusePass in op compat failed.";
         return;
       }
       OpDesc desc(matmul_op->Op()->Block());
@@ -680,7 +683,8 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
             matmul_op->Op()->GetAttr("out_threshold"));
       }
       if (!IsCompat(desc)) {
-        LOG(WARNING) << "Reshape2MatmulFusePass in out mul op compat failed.";
+        LOG(WARNING)
+            << "GpuCpuReshape2MatmulFusePass in out mul op compat failed.";
         return;
       }
       auto mul_node = g->CreateOpNode(&desc);
@@ -696,10 +700,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   AddStatis(found_count);
 }
-void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  std::string name_scope = "flatten2_matmul_fuse_pass";
+  std::string name_scope = "gpu_cpu_flatten2_matmul_fuse_pass";
   FusePassBase::Init(name_scope, graph);
   GraphPatternDetector gpd;
@@ -709,7 +713,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   int found_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(4) << "fuse flatten2+matmul to mul";
+    VLOG(4) << "gpu_cpu fuse flatten2+matmul to mul";
     GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
@@ -749,7 +753,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
     if (pattern_found) {
      if (!IsCompat(subgraph, g)) {
-        LOG(WARNING) << "Flatten2MatmulFusePass in op compat failed.";
+        LOG(WARNING) << "GpuCpuFlatten2MatmulFusePass in op compat failed.";
        return;
      }
      OpDesc desc(matmul_op->Op()->Block());
@@ -774,7 +778,8 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
      ++found_count;
      if (!IsCompat(desc)) {
-        LOG(WARNING) << "Flatten2MatmulFusePass in out mul op compat failed.";
+        LOG(WARNING)
+            << "GpuCpuFlatten2MatmulFusePass in out mul op compat failed.";
        return;
      }
    }
@@ -788,50 +793,51 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
 }  // namespace framework
 }  // namespace paddle
-REGISTER_PASS(map_matmul_to_mul_pass, paddle::framework::ir::MapMatmul2MulPass);
-REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass)
+REGISTER_PASS(gpu_cpu_map_matmul_to_mul_pass,
+              paddle::framework::ir::GpuCpuMapMatmul2MulPass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_to_mul_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("matmul", 1)
            .EQ("mul", 0));
-REGISTER_PASS(map_matmul_v2_to_mul_pass,
-              paddle::framework::ir::MapMatmulV2ToMulPass);
-REGISTER_PASS_CAPABILITY(map_matmul_v2_to_mul_pass)
+REGISTER_PASS(gpu_cpu_map_matmul_v2_to_mul_pass,
+              paddle::framework::ir::GpuCpuMapMatmulV2ToMulPass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_v2_to_mul_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("matmul_v2", 0)
            .EQ("mul", 0));
-REGISTER_PASS(map_matmul_v2_to_matmul_pass,
-              paddle::framework::ir::MapMatmulV2ToMatmulPass);
-REGISTER_PASS_CAPABILITY(map_matmul_v2_to_matmul_pass)
+REGISTER_PASS(gpu_cpu_map_matmul_v2_to_matmul_pass,
+              paddle::framework::ir::GpuCpuMapMatmulV2ToMatmulPass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_v2_to_matmul_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("matmul_v2", 0)
            .LE("matmul", 1));
-REGISTER_PASS(squeeze2_matmul_fuse_pass,
-              paddle::framework::ir::Squeeze2MatmulFusePass);
-REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass)
+REGISTER_PASS(gpu_cpu_squeeze2_matmul_fuse_pass,
+              paddle::framework::ir::GpuCpuSqueeze2MatmulFusePass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_squeeze2_matmul_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("matmul", 1)
            .EQ("squeeze2", 0)
            .EQ("mul", 0));
-REGISTER_PASS(reshape2_matmul_fuse_pass,
-              paddle::framework::ir::Reshape2MatmulFusePass);
-REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass)
+REGISTER_PASS(gpu_cpu_reshape2_matmul_fuse_pass,
+              paddle::framework::ir::GpuCpuReshape2MatmulFusePass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_reshape2_matmul_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("matmul", 1)
            .EQ("reshape2", 0)
            .EQ("mul", 0));
-REGISTER_PASS(flatten2_matmul_fuse_pass,
-              paddle::framework::ir::Flatten2MatmulFusePass);
-REGISTER_PASS_CAPABILITY(flatten2_matmul_fuse_pass)
+REGISTER_PASS(gpu_cpu_flatten2_matmul_fuse_pass,
+              paddle::framework::ir::GpuCpuFlatten2MatmulFusePass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_flatten2_matmul_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("matmul", 1)
......
@@ -37,22 +37,22 @@ namespace ir {
  */
 class Graph;
-class MapMatmul2MulPass : public FusePassBase {
+class GpuCpuMapMatmul2MulPass : public FusePassBase {
  public:
-  MapMatmul2MulPass();
-  virtual ~MapMatmul2MulPass() {}
+  GpuCpuMapMatmul2MulPass();
+  virtual ~GpuCpuMapMatmul2MulPass() {}
  protected:
  void ApplyImpl(Graph* graph) const override;
 };
 /*
- * Map matmul_v2 to mul, the same as MapMatmul2MulPass.
+ * Map matmul_v2 to mul, the same as GpuCpuMapMatmul2MulPass.
  */
-class MapMatmulV2ToMulPass : public FusePassBase {
+class GpuCpuMapMatmulV2ToMulPass : public FusePassBase {
  public:
-  MapMatmulV2ToMulPass();
-  virtual ~MapMatmulV2ToMulPass() {}
+  GpuCpuMapMatmulV2ToMulPass();
+  virtual ~GpuCpuMapMatmulV2ToMulPass() {}
  protected:
  void ApplyImpl(Graph* graph) const override;
@@ -61,10 +61,10 @@ class MapMatmulV2ToMulPass : public FusePassBase {
 /*
  * Map matmul_v2 to matmul, not supoort broadcast.
  */
-class MapMatmulV2ToMatmulPass : public FusePassBase {
+class GpuCpuMapMatmulV2ToMatmulPass : public FusePassBase {
  public:
-  MapMatmulV2ToMatmulPass();
-  virtual ~MapMatmulV2ToMatmulPass() {}
+  GpuCpuMapMatmulV2ToMatmulPass();
+  virtual ~GpuCpuMapMatmulV2ToMatmulPass() {}
  protected:
  void ApplyImpl(Graph* graph) const override;
@@ -89,10 +89,10 @@ class MapMatmulV2ToMatmulPass : public FusePassBase {
  * the above passes to reduce the impact on other models.
  */
-class Squeeze2MatmulFusePass : public FusePassBase {
+class GpuCpuSqueeze2MatmulFusePass : public FusePassBase {
  public:
-  Squeeze2MatmulFusePass();
-  virtual ~Squeeze2MatmulFusePass() {}
+  GpuCpuSqueeze2MatmulFusePass();
+  virtual ~GpuCpuSqueeze2MatmulFusePass() {}
  protected:
  void ApplyImpl(Graph* graph) const override;
@@ -119,19 +119,19 @@ class Squeeze2MatmulFusePass : public FusePassBase {
  * the above passes to reduce the impact on other models.
  */
-class Reshape2MatmulFusePass : public FusePassBase {
+class GpuCpuReshape2MatmulFusePass : public FusePassBase {
  public:
-  Reshape2MatmulFusePass();
-  virtual ~Reshape2MatmulFusePass() {}
+  GpuCpuReshape2MatmulFusePass();
+  virtual ~GpuCpuReshape2MatmulFusePass() {}
  protected:
  void ApplyImpl(Graph* graph) const override;
 };
-class Flatten2MatmulFusePass : public FusePassBase {
+class GpuCpuFlatten2MatmulFusePass : public FusePassBase {
  public:
-  Flatten2MatmulFusePass();
-  virtual ~Flatten2MatmulFusePass() {}
+  GpuCpuFlatten2MatmulFusePass();
+  virtual ~GpuCpuFlatten2MatmulFusePass() {}
  protected:
  void ApplyImpl(Graph* graph) const override;
......
This diff is collapsed.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
class TrtMapMatmul2MulPass : public FusePassBase {
public:
TrtMapMatmul2MulPass();
virtual ~TrtMapMatmul2MulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Map matmul_v2 to mul, the same as TrtMapMatmul2MulPass.
*/
class TrtMapMatmulV2ToMulPass : public FusePassBase {
public:
TrtMapMatmulV2ToMulPass();
virtual ~TrtMapMatmulV2ToMulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Map matmul_v2 to matmul, not supoort broadcast.
*/
class TrtMapMatmulV2ToMatmulPass : public FusePassBase {
public:
TrtMapMatmulV2ToMatmulPass();
virtual ~TrtMapMatmulV2ToMatmulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Fuse squeeze2+matmul to mul, so the optimization can use fc_fuse_pass.
* The squeeze2 op must satisfy the following conditions:
* 1. the rank of input X is 4
* 2. the axis attr is [2, 3]
* 3. the next op is only matmul
*
* The matmul op must satisfy the following conditions:
* 1. the transpose_X and transpose_Y attrs are false
* 2. the alpha attr is 1.0
* 3. the rank of input X and Y is 2
* 4. the next op of matmul is only elementwise_add
*
* Notice:
* the rank of input activation is obtained from var_desc,
* it maybe change in runtime. Therefore, the pass considers
* the above passes to reduce the impact on other models.
*/
class TrtSqueeze2MatmulFusePass : public FusePassBase {
public:
TrtSqueeze2MatmulFusePass();
virtual ~TrtSqueeze2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Fuse reshape2+matmul to mul, so the optimization can use fc_fuse_pass.
* The reshape2 op must satisfy the following conditions:
* 1. reshape2 has one input node, which means it don't
* have Shape or ShapeTensor input
* 2. the rank of input X is 4 and the last two dims of input X is 1
* 3. the rank of shape attr is 2
* 4. the next op is only matmul
*
* The matmul op must satisfy the following conditions:
* 1. the transpose_X and transpose_Y attrs are false
* 2. the alpha attr is 1.0
* 3. the rank of input X and Y is 2
* 4. the next op of matmul is only elementwise_add
*
* Notice:
* the shape and rank of input activation is obtained from var_desc,
* they maybe change in runtime. Therefore, the pass considers
* the above passes to reduce the impact on other models.
*/
class TrtReshape2MatmulFusePass : public FusePassBase {
public:
TrtReshape2MatmulFusePass();
virtual ~TrtReshape2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
class TrtFlatten2MatmulFusePass : public FusePassBase {
public:
TrtFlatten2MatmulFusePass();
virtual ~TrtFlatten2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
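For reference, the squeeze2+matmul pattern that these passes look for can be reproduced with a small static-graph program. The sketch below is illustrative only: the tensor names and shapes, and the use of the legacy fluid.layers.matmul API (which emits the v1 matmul op rather than matmul_v2), are assumptions and not part of this patch. The rank-4 input is squeezed on axes [2, 3] to a rank-2 matrix, the matmul keeps the default transpose_X/transpose_Y = false and alpha = 1.0, and the result feeds an elementwise_add, which is exactly the subgraph the pass rewrites into a single mul for fc_fuse_pass to pick up later.

import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # Rank-4 activation whose last two dims are 1, as the pass requires.
    x = paddle.static.data(name="x", shape=[-1, 128, 1, 1], dtype="float32")
    # squeeze2 on axes [2, 3] yields a rank-2 matrix.
    x2d = paddle.squeeze(x, axis=[2, 3])
    w = paddle.static.create_parameter(shape=[128, 64], dtype="float32", name="fc_w")
    b = paddle.static.create_parameter(shape=[64], dtype="float32", name="fc_b")
    # Legacy matmul (v1) with default attrs: transpose_x/y False, alpha 1.0.
    out = fluid.layers.matmul(x2d, w)
    # The op that follows the matmul must be elementwise_add.
    out = paddle.add(out, b)

# The op sequence that *_squeeze2_matmul_fuse_pass matches:
print([op.type for op in main_prog.global_block().ops])
# expected to be close to: ['squeeze2', 'matmul', 'elementwise_add']

The reshape2 and flatten2 variants documented above differ only in how the rank-2 matrix is produced; the matmul and elementwise_add constraints are the same.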
@@ -90,12 +90,12 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "skip_layernorm_fuse_pass",           //
       "conv_bn_fuse_pass",                  //
       "unsqueeze2_eltwise_fuse_pass",       //
-      "squeeze2_matmul_fuse_pass",          //
-      "reshape2_matmul_fuse_pass",          //
-      "flatten2_matmul_fuse_pass",          //
-      "map_matmul_v2_to_mul_pass",          //
-      "map_matmul_v2_to_matmul_pass",       //
-      "map_matmul_to_mul_pass",             //
+      "trt_squeeze2_matmul_fuse_pass",      //
+      "trt_reshape2_matmul_fuse_pass",      //
+      "trt_flatten2_matmul_fuse_pass",      //
+      "trt_map_matmul_v2_to_mul_pass",      //
+      "trt_map_matmul_v2_to_matmul_pass",   //
+      "trt_map_matmul_to_mul_pass",         //
       "fc_fuse_pass",                       //
       "conv_elementwise_add_fuse_pass",     //
       "add_support_int8_pass",
@@ -140,12 +140,12 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_eltwiseadd_bn_fuse_pass",             //
         "embedding_eltwise_layernorm_fuse_pass",    //
         "multihead_matmul_fuse_pass_v2",            //
-        "squeeze2_matmul_fuse_pass",                //
-        "reshape2_matmul_fuse_pass",                //
-        "flatten2_matmul_fuse_pass",                //
-        "map_matmul_v2_to_mul_pass",                //
-        "map_matmul_v2_to_matmul_pass",             //
-        "map_matmul_to_mul_pass",                   //
+        "gpu_cpu_squeeze2_matmul_fuse_pass",        //
+        "gpu_cpu_reshape2_matmul_fuse_pass",        //
+        "gpu_cpu_flatten2_matmul_fuse_pass",        //
+        "gpu_cpu_map_matmul_v2_to_mul_pass",        //
+        "gpu_cpu_map_matmul_v2_to_matmul_pass",     //
+        "gpu_cpu_map_matmul_to_mul_pass",           //
         "fc_fuse_pass",                             //
         "fc_elementwise_layernorm_fuse_pass",       //
 #if CUDNN_VERSION >= 7100  // To run conv_fusion, the version of cudnn must be
@@ -202,14 +202,14 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
         "fc_gru_fuse_pass",                       //
         "mul_gru_fuse_pass",                      //
         "seq_concat_fc_fuse_pass",                //
-        "squeeze2_matmul_fuse_pass",              //
-        "reshape2_matmul_fuse_pass",              //
-        "flatten2_matmul_fuse_pass",              //
+        "gpu_cpu_squeeze2_matmul_fuse_pass",      //
+        "gpu_cpu_reshape2_matmul_fuse_pass",      //
+        "gpu_cpu_flatten2_matmul_fuse_pass",      //
         "matmul_v2_scale_fuse_pass",              //
-        "map_matmul_v2_to_mul_pass",              //
-        "map_matmul_v2_to_matmul_pass",           //
+        "gpu_cpu_map_matmul_v2_to_mul_pass",      //
+        "gpu_cpu_map_matmul_v2_to_matmul_pass",   //
         "matmul_scale_fuse_pass",                 //
-        "map_matmul_to_mul_pass",                 //
+        "gpu_cpu_map_matmul_to_mul_pass",         //
         "fc_fuse_pass",                           //
         "repeated_fc_relu_fuse_pass",             //
         "squared_mat_sub_fuse_pass",              //
......
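Since kTRTSubgraphPasses, GpuPassStrategy, and CpuPassStrategy above only swap in the renamed passes, any deployment that enables or disables these passes by name has to switch to the trt_/gpu_cpu_ prefixed names. A minimal sketch with the Python inference API follows; the model paths are placeholders, and the availability of Config.delete_pass is assumed for the installed Paddle build.

import paddle.inference as paddle_infer

# Placeholder model files; point these at an exported inference model.
config = paddle_infer.Config("./model/inference.pdmodel",
                             "./model/inference.pdiparams")
config.enable_use_gpu(256, 0)

# The old name "map_matmul_to_mul_pass" no longer exists; the GPU/CPU
# strategy now registers the renamed pass instead.
config.delete_pass("gpu_cpu_map_matmul_to_mul_pass")

predictor = paddle_infer.create_predictor(config)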
@@ -67,13 +67,7 @@ class FcOpConverter : public OpConverter {
                                           nvinfer1::Dims x_dim, int x_num_col_dims) {
     // add shuffle after fc
     nvinfer1::Dims reshape_after_fc_dim;
-    if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
-        x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 1) {
-      // If use tensorrt'oss, the x_dim and x_num_col_dims need change
-      reshape_after_fc_dim.nbDims = 4;
-    } else {
-      reshape_after_fc_dim.nbDims = x_num_col_dims + 1;
-    }
+    reshape_after_fc_dim.nbDims = x_num_col_dims + 1;
     for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) {
       reshape_after_fc_dim.d[i] = 0;
     }
@@ -141,7 +135,6 @@ class FcOpConverter : public OpConverter {
             "The fc's weight should be a matrix with 2 dims, but "
             "it's %d-dimensional.",
             Y_t->dims().size()));  // a matrix
-    size_t n_output = Y_t->dims()[1];
     int m = Y_t->dims()[0];
     int n = Y_t->dims()[1];
     auto tranpose_weight = [](const float* src, float* dst, int m, int n) {
@@ -175,9 +168,10 @@ class FcOpConverter : public OpConverter {
           fc_layer_int8->getOutput(0), x_dim, x_num_col_dims);
       if (activation_type == "relu") {
         fc_after_reshape_int8->setName(
-            ("fc_op_int8_reshape_after_fc: Shuffle (Output: " + output_name +
-             ")")
+            ("int8_reshape_after_fc: Shuffle (Output: " + output_name + ")")
                 .c_str());
+        engine_->SetTensorDynamicRange(fc_after_reshape_int8->getOutput(0),
+                                       out_scale);
         nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER(
             engine_, Activation, *(fc_after_reshape_int8->getOutput(0)),
             nvinfer1::ActivationType::kRELU);
@@ -200,8 +194,7 @@ class FcOpConverter : public OpConverter {
           fc_layer_float->getOutput(0), x_dim, x_num_col_dims);
       if (activation_type == "relu") {
         fc_after_reshape_float->setName(
-            ("fc_op_float_reshape_after_fc: Shuffle (Output: " + output_name +
-             ")")
+            ("float_reshape_after_fc: Shuffle (Output: " + output_name + ")")
                 .c_str());
         nvinfer1::IActivationLayer* relu_layer_float = TRT_ENGINE_ADD_LAYER(
             engine_, Activation, *(fc_after_reshape_float->getOutput(0)),
@@ -215,14 +208,28 @@ class FcOpConverter : public OpConverter {
       }
     };
-    std::vector<float> weight_data_tmp;
-    weight_data_tmp.reserve(Y_t->numel());
-    memcpy(weight_data_tmp.data(), weight_data, Y_t->numel() * sizeof(float));
-    tranpose_weight(weight_data_tmp.data(), weight_data, m, n);
+    bool transpose_y = false;
+    if (op_desc.HasAttr("transpose_Y")) {
+      transpose_y = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_Y"));
+    }
+    int weight_w, weight_h;
+    if (!transpose_y) {
+      std::vector<float> weight_data_tmp;
+      weight_data_tmp.reserve(Y_t->numel());
+      memcpy(weight_data_tmp.data(), weight_data, Y_t->numel() * sizeof(float));
+      tranpose_weight(weight_data_tmp.data(), weight_data, m, n);
+      weight_w = n;
+      weight_h = m;
+    } else {
+      weight_w = m;
+      weight_h = n;
+    }
+    size_t n_output = weight_w;
     TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
                                   static_cast<void*>(weight_data),
                                   static_cast<size_t>(Y_t->numel())};
-    weight.dims.assign({n, m});
+    weight.dims.assign({weight_w, weight_h});
     float* bias_data = nullptr;
     int bias_num = 0;
     if (with_bias) {
@@ -240,25 +247,72 @@ class FcOpConverter : public OpConverter {
     if (!engine_->with_dynamic_shape()) {
       x_num_col_dims--;
     }
-    // If use tensorrt'oss, the x_dim and x_num_col_dims need change
+    // If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can
+    // not add Shuffle layer in ernie's multihead.
     if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
-        x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 2) {
-      x_num_col_dims = 1;
-    }
-    PADDLE_ENFORCE_GT(
-        x_dim.nbDims, x_num_col_dims,
-        platform::errors::InvalidArgument(
-            "Params and input dims mismatch. Paddle-TRT FC "
-            "converter expects x_dim.nbDims > x_num_col_dims, but "
-            "x_dim.nbDims : %d, x_num_col_dims : %d.",
-            x_dim.nbDims, x_num_col_dims));
-    auto* reshape_before_fc_layer =
-        reshape_before_fc(X, x_dim, x_num_col_dims, output_name);
-    auto* reshape_itensor = reshape_before_fc_layer->getOutput(0);
-    if (enable_int8) {
-      engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
+        x_dim.d[3] == 1 && x_num_col_dims == 2) {
+      if (enable_int8) {
+        // add conv1x1 layer
+        nvinfer1::DimsHW nv_ksize(1, 1);
+        auto* fc_layer_int8 =
+            TRT_ENGINE_ADD_LAYER(engine_, Convolution, *X, n_output, nv_ksize,
+                                 weight.get(), bias.get());
+        if (activation_type == "relu") {
+          fc_layer_int8->setName(
+              ("ernie_fc_op_int8: Convolution (Output: " + output_name + ")")
+                  .c_str());
+          PADDLE_ENFORCE_EQ(
+              op_desc.HasAttr("out_threshold"), true,
+              platform::errors::InvalidArgument(
+                  "must have out threshold in fc layers in int8 mode"));
+          float out_scale =
+              BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
+          engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0),
+                                         out_scale);
+          nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER(
+              engine_, Activation, *(fc_layer_int8->getOutput(0)),
+              nvinfer1::ActivationType::kRELU);
+          RreplenishLayerAndOutput(relu_layer_int8, "relu_after_ernie_fc_int8",
+                                   {output_name}, test_mode);
+        } else {
+          RreplenishLayerAndOutput(fc_layer_int8,
+                                   "ernie_fc_op_int8: Convolution",
+                                   {output_name}, test_mode);
+        }
+      } else {
+        // add fc layer
+        auto* fc_layer_float = TRT_ENGINE_ADD_LAYER(
+            engine_, FullyConnected, *X, n_output, weight.get(), bias.get());
+        if (activation_type == "relu") {
+          fc_layer_float->setName(
+              ("ernie_fc_op_float: (Output: " + output_name + ")").c_str());
+          nvinfer1::IActivationLayer* relu_layer_float = TRT_ENGINE_ADD_LAYER(
+              engine_, Activation, *(fc_layer_float->getOutput(0)),
+              nvinfer1::ActivationType::kRELU);
+          RreplenishLayerAndOutput(relu_layer_float,
+                                   "relu_after_ernie_fc_float", {output_name},
+                                   test_mode);
+        } else {
+          RreplenishLayerAndOutput(fc_layer_float, "ernie_fc_op_float",
+                                   {output_name}, test_mode);
+        }
+      }
+    } else {  // need reshape input before and after fc
+      PADDLE_ENFORCE_GT(
+          x_dim.nbDims, x_num_col_dims,
+          platform::errors::InvalidArgument(
+              "Params and input dims mismatch. Paddle-TRT FC "
+              "converter expects x_dim.nbDims > x_num_col_dims, but "
+              "x_dim.nbDims : %d, x_num_col_dims : %d.",
+              x_dim.nbDims, x_num_col_dims));
+      auto* reshape_before_fc_layer =
+          reshape_before_fc(X, x_dim, x_num_col_dims, output_name);
+      auto* reshape_itensor = reshape_before_fc_layer->getOutput(0);
+      if (enable_int8) {
+        engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
+      }
+      regist_fc(reshape_itensor, n_output, weight, bias);
     }
-    regist_fc(reshape_itensor, n_output, weight, bias);
   }
 };
......
@@ -410,16 +410,16 @@ class Quant2Int8MkldnnPass(object):
         graph = self._apply_pass(graph, 'multi_gru_fuse_pass')
         graph = self._apply_pass(graph, 'multi_gru_seq_fuse_pass')
         graph = self._apply_pass(graph, 'seq_concat_fc_fuse_pass')
-        graph = self._apply_pass(graph, 'squeeze2_matmul_fuse_pass')
-        graph = self._apply_pass(graph, 'reshape2_matmul_fuse_pass')
-        graph = self._apply_pass(graph, 'flatten2_matmul_fuse_pass')
+        graph = self._apply_pass(graph, 'gpu_cpu_squeeze2_matmul_fuse_pass')
+        graph = self._apply_pass(graph, 'gpu_cpu_reshape2_matmul_fuse_pass')
+        graph = self._apply_pass(graph, 'gpu_cpu_flatten2_matmul_fuse_pass')
         graph = self._apply_pass(graph, 'matmul_v2_scale_fuse_pass')
         graph = self._apply_pass(graph, 'squared_mat_sub_fuse_pass')
         graph = self._apply_pass(graph, 'is_test_pass')
-        graph = self._apply_pass(graph, 'map_matmul_v2_to_mul_pass')
-        graph = self._apply_pass(graph, 'map_matmul_v2_to_matmul_pass')
+        graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_v2_to_mul_pass')
+        graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_v2_to_matmul_pass')
         graph = self._apply_pass(graph, 'matmul_scale_fuse_pass')
-        graph = self._apply_pass(graph, 'map_matmul_to_mul_pass')
+        graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_to_mul_pass')
         graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
         graph = self._apply_pass(graph, 'mkldnn_placement_pass',
                                  ['mkldnn_enabled_op_types'], [set()])
......
@@ -174,7 +174,7 @@ class TestFlatten2MatmulFusePass(PassAutoScanTest):
             quant=False,
             max_examples=50,
             max_duration=1000,
-            passes=["flatten2_matmul_fuse_pass"])
+            passes=["gpu_cpu_flatten2_matmul_fuse_pass"])
 if __name__ == "__main__":
......
@@ -116,7 +116,7 @@ class TestMapMatmulToMulPass(PassAutoScanTest):
         self.run_and_statis(
             quant=False,
             max_examples=100,
-            passes=["map_matmul_to_mul_pass"],
+            passes=["gpu_cpu_map_matmul_to_mul_pass"],
             max_duration=180)
......
@@ -127,7 +127,7 @@ class TestMapMatmulToMulPass(PassAutoScanTest):
         self.run_and_statis(
             quant=False,
             max_examples=100,
-            passes=["map_matmul_v2_to_matmul_pass"])
+            passes=["gpu_cpu_map_matmul_v2_to_matmul_pass"])
 if __name__ == "__main__":
......
@@ -110,8 +110,9 @@ class TestMapMatmulToMulPass(PassAutoScanTest):
     def test(self):
         self.run_and_statis(
-            quant=False, max_examples=100,
-            passes=["map_matmul_v2_to_mul_pass"])
+            quant=False,
+            max_examples=100,
+            passes=["gpu_cpu_map_matmul_v2_to_mul_pass"])
 if __name__ == "__main__":
......
@@ -132,7 +132,7 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest):
         return program_config
     def sample_predictor_configs(self, program_config):
-        # map_matmul_v2_to_matmul_pass will affect the type of final fused op
+        # gpu_cpu_map_matmul_v2_to_matmul_pass will affect the type of final fused op
         fused_op = "matmul_v2"
         input1_dim1 = program_config.inputs["input_data1"].shape[0]
         input2_dim1 = program_config.inputs["input_data2"].shape[0]
......
@@ -172,7 +172,7 @@ class TestReshape2MatmulFusePass(PassAutoScanTest):
             quant=False,
             max_examples=50,
             max_duration=1000,
-            passes=["reshape2_matmul_fuse_pass"])
+            passes=["gpu_cpu_reshape2_matmul_fuse_pass"])
 if __name__ == "__main__":
......
@@ -180,7 +180,7 @@ class TestSqueeze2MatmulFusePass(PassAutoScanTest):
             quant=False,
             max_examples=50,
             max_duration=1000,
-            passes=["squeeze2_matmul_fuse_pass"])
+            passes=["gpu_cpu_squeeze2_matmul_fuse_pass"])
 if __name__ == "__main__":
......