未验证 提交 db7d129e 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle-Inference] rebuild matmul pass: trt and gpu_cpu (#39369)

* rebuild matmul pass: trt and gpu_cpu

* rebuild matmul pass: trt and gpu_cpu

* rebuild matmul pass: trt and gpu_cpu

* rebuild matmul pass: trt and gpu_cpu
上级 772be4f5
......@@ -62,7 +62,6 @@ pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base DEPS string_helper)
pass_library(fc_fuse_pass inference)
pass_library(map_matmul_to_mul_pass inference)
pass_library(attention_lstm_fuse_pass inference)
pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference)
......@@ -98,8 +97,14 @@ pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(layer_norm_fuse_pass inference)
pass_library(add_support_int8_pass inference)
pass_library(matmul_scale_fuse_pass inference)
pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)
if(WITH_TENSORRT)
pass_library(trt_map_matmul_to_mul_pass inference)
endif()
if(WITH_GPU OR WITH_ROCM)
pass_library(cudnn_placement_pass base DEPS placement_pass_base)
pass_library(embedding_eltwise_layernorm_fuse_pass inference)
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/map_matmul_to_mul_pass.h"
#include "paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.h"
#include <cmath>
#include <string>
......@@ -28,7 +28,7 @@ namespace ir {
class Node;
MapMatmul2MulPass::MapMatmul2MulPass() {
GpuCpuMapMatmul2MulPass::GpuCpuMapMatmul2MulPass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
......@@ -68,7 +68,7 @@ MapMatmul2MulPass::MapMatmul2MulPass() {
.End();
}
MapMatmulV2ToMulPass::MapMatmulV2ToMulPass() {
GpuCpuMapMatmulV2ToMulPass::GpuCpuMapMatmulV2ToMulPass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
......@@ -104,7 +104,7 @@ MapMatmulV2ToMulPass::MapMatmulV2ToMulPass() {
.End();
}
MapMatmulV2ToMatmulPass::MapMatmulV2ToMatmulPass() {
GpuCpuMapMatmulV2ToMatmulPass::GpuCpuMapMatmulV2ToMatmulPass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
......@@ -143,7 +143,7 @@ MapMatmulV2ToMatmulPass::MapMatmulV2ToMatmulPass() {
.End();
}
Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
GpuCpuFlatten2MatmulFusePass::GpuCpuFlatten2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
......@@ -197,7 +197,7 @@ Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
.End();
}
Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() {
GpuCpuSqueeze2MatmulFusePass::GpuCpuSqueeze2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
......@@ -251,10 +251,10 @@ Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() {
.End();
}
void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
void GpuCpuMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "map_matmul_to_mul_pass";
std::string name_scope = "gpu_cpu_map_matmul_to_mul_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
......@@ -264,7 +264,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "map matmul to mul";
VLOG(4) << "gpu_cpu map matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern);
......@@ -286,7 +286,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "MapMatmul2MulPass in op compat failed.";
LOG(WARNING) << "GpuCpuMapMatmul2MulPass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
......@@ -311,7 +311,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmul2MulPass in out mul op compat failed.";
LOG(WARNING) << "GpuCpuMapMatmul2MulPass in out mul op compat failed.";
return;
}
}
......@@ -321,10 +321,10 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}
void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
void GpuCpuMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "map_matmul_v2_to_mul_pass";
std::string name_scope = "gpu_cpu_map_matmul_v2_to_mul_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
......@@ -335,7 +335,7 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(3) << "map matmul_v2 to mul";
VLOG(3) << "gpu_cpu map matmul_v2 to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
matmul_v2_weight_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
......@@ -360,7 +360,7 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "MapMatmulV2ToMulPass in op compat failed.";
LOG(WARNING) << "GpuCpuMapMatmulV2ToMulPass in op compat failed.";
return;
}
OpDesc desc(matmul_v2_op->Op()->Block());
......@@ -386,7 +386,8 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmulV2ToMulPass in out mul op compat failed.";
LOG(WARNING)
<< "GpuCpuMapMatmulV2ToMulPass in out mul op compat failed.";
return;
}
}
......@@ -396,10 +397,10 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}
void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
void GpuCpuMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "map_matmul_v2_to_matmul_pass";
std::string name_scope = "gpu_cpu_map_matmul_v2_to_matmul_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
......@@ -409,7 +410,7 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "map matmul_v2 to matmul";
VLOG(4) << "gpu_cpu map matmul_v2 to matmul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
......@@ -417,7 +418,7 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op, matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_out, matmul_v2_out, matmul_v2_pattern);
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "MapMatmulV2ToMatmulPass in op compat failed.";
LOG(WARNING) << "GpuCpuMapMatmulV2ToMatmulPass in op compat failed.";
return;
}
......@@ -463,7 +464,8 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmulV2ToMatmulPass in out matmul op compat failed.";
LOG(WARNING)
<< "GpuCpuMapMatmulV2ToMatmulPass in out matmul op compat failed.";
return;
}
};
......@@ -472,10 +474,10 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}
void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
void GpuCpuSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "squeeze2_matmul_fuse_pass";
std::string name_scope = "gpu_cpu_squeeze2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
......@@ -485,7 +487,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse squeeze2+matmul to mul";
VLOG(4) << "gpu_cpu fuse squeeze2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
......@@ -518,7 +520,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Squeeze2MatmulFusePass in op compat failed.";
LOG(WARNING) << "GpuCpuSqueeze2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
......@@ -542,7 +544,8 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
GraphSafeRemoveNodes(graph, {squeeze2_op, matmul_in_x, matmul_op});
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "Squeeze2MatmulFusePass in out mul op compat failed.";
LOG(WARNING)
<< "GpuCpuSqueeze2MatmulFusePass in out mul op compat failed.";
return;
}
}
......@@ -552,7 +555,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}
Reshape2MatmulFusePass::Reshape2MatmulFusePass() {
GpuCpuReshape2MatmulFusePass::GpuCpuReshape2MatmulFusePass() {
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
......@@ -614,10 +617,10 @@ Reshape2MatmulFusePass::Reshape2MatmulFusePass() {
.End();
}
void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
void GpuCpuReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "reshape2_matmul_fuse_pass";
std::string name_scope = "gpu_cpu_reshape2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
......@@ -627,7 +630,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse reshape2+matmul to mul";
VLOG(4) << "gpu_cpu fuse reshape2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
......@@ -662,7 +665,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Reshape2MatmulFusePass in op compat failed.";
LOG(WARNING) << "GpuCpuReshape2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
......@@ -680,7 +683,8 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
matmul_op->Op()->GetAttr("out_threshold"));
}
if (!IsCompat(desc)) {
LOG(WARNING) << "Reshape2MatmulFusePass in out mul op compat failed.";
LOG(WARNING)
<< "GpuCpuReshape2MatmulFusePass in out mul op compat failed.";
return;
}
auto mul_node = g->CreateOpNode(&desc);
......@@ -696,10 +700,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}
void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
void GpuCpuFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "flatten2_matmul_fuse_pass";
std::string name_scope = "gpu_cpu_flatten2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
......@@ -709,7 +713,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse flatten2+matmul to mul";
VLOG(4) << "gpu_cpu fuse flatten2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
......@@ -749,7 +753,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
if (pattern_found) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Flatten2MatmulFusePass in op compat failed.";
LOG(WARNING) << "GpuCpuFlatten2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
......@@ -774,7 +778,8 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "Flatten2MatmulFusePass in out mul op compat failed.";
LOG(WARNING)
<< "GpuCpuFlatten2MatmulFusePass in out mul op compat failed.";
return;
}
}
......@@ -788,50 +793,51 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
} // namespace framework
} // namespace paddle
REGISTER_PASS(map_matmul_to_mul_pass, paddle::framework::ir::MapMatmul2MulPass);
REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass)
REGISTER_PASS(gpu_cpu_map_matmul_to_mul_pass,
paddle::framework::ir::GpuCpuMapMatmul2MulPass);
REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_to_mul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("mul", 0));
REGISTER_PASS(map_matmul_v2_to_mul_pass,
paddle::framework::ir::MapMatmulV2ToMulPass);
REGISTER_PASS_CAPABILITY(map_matmul_v2_to_mul_pass)
REGISTER_PASS(gpu_cpu_map_matmul_v2_to_mul_pass,
paddle::framework::ir::GpuCpuMapMatmulV2ToMulPass);
REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_v2_to_mul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("mul", 0));
REGISTER_PASS(map_matmul_v2_to_matmul_pass,
paddle::framework::ir::MapMatmulV2ToMatmulPass);
REGISTER_PASS_CAPABILITY(map_matmul_v2_to_matmul_pass)
REGISTER_PASS(gpu_cpu_map_matmul_v2_to_matmul_pass,
paddle::framework::ir::GpuCpuMapMatmulV2ToMatmulPass);
REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_v2_to_matmul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.LE("matmul", 1));
REGISTER_PASS(squeeze2_matmul_fuse_pass,
paddle::framework::ir::Squeeze2MatmulFusePass);
REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass)
REGISTER_PASS(gpu_cpu_squeeze2_matmul_fuse_pass,
paddle::framework::ir::GpuCpuSqueeze2MatmulFusePass);
REGISTER_PASS_CAPABILITY(gpu_cpu_squeeze2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("squeeze2", 0)
.EQ("mul", 0));
REGISTER_PASS(reshape2_matmul_fuse_pass,
paddle::framework::ir::Reshape2MatmulFusePass);
REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass)
REGISTER_PASS(gpu_cpu_reshape2_matmul_fuse_pass,
paddle::framework::ir::GpuCpuReshape2MatmulFusePass);
REGISTER_PASS_CAPABILITY(gpu_cpu_reshape2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("reshape2", 0)
.EQ("mul", 0));
REGISTER_PASS(flatten2_matmul_fuse_pass,
paddle::framework::ir::Flatten2MatmulFusePass);
REGISTER_PASS_CAPABILITY(flatten2_matmul_fuse_pass)
REGISTER_PASS(gpu_cpu_flatten2_matmul_fuse_pass,
paddle::framework::ir::GpuCpuFlatten2MatmulFusePass);
REGISTER_PASS_CAPABILITY(gpu_cpu_flatten2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
......
......@@ -37,22 +37,22 @@ namespace ir {
*/
class Graph;
class MapMatmul2MulPass : public FusePassBase {
class GpuCpuMapMatmul2MulPass : public FusePassBase {
public:
MapMatmul2MulPass();
virtual ~MapMatmul2MulPass() {}
GpuCpuMapMatmul2MulPass();
virtual ~GpuCpuMapMatmul2MulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Map matmul_v2 to mul, the same as MapMatmul2MulPass.
* Map matmul_v2 to mul, the same as GpuCpuMapMatmul2MulPass.
*/
class MapMatmulV2ToMulPass : public FusePassBase {
class GpuCpuMapMatmulV2ToMulPass : public FusePassBase {
public:
MapMatmulV2ToMulPass();
virtual ~MapMatmulV2ToMulPass() {}
GpuCpuMapMatmulV2ToMulPass();
virtual ~GpuCpuMapMatmulV2ToMulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
......@@ -61,10 +61,10 @@ class MapMatmulV2ToMulPass : public FusePassBase {
/*
* Map matmul_v2 to matmul, not supoort broadcast.
*/
class MapMatmulV2ToMatmulPass : public FusePassBase {
class GpuCpuMapMatmulV2ToMatmulPass : public FusePassBase {
public:
MapMatmulV2ToMatmulPass();
virtual ~MapMatmulV2ToMatmulPass() {}
GpuCpuMapMatmulV2ToMatmulPass();
virtual ~GpuCpuMapMatmulV2ToMatmulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
......@@ -89,10 +89,10 @@ class MapMatmulV2ToMatmulPass : public FusePassBase {
* the above passes to reduce the impact on other models.
*/
class Squeeze2MatmulFusePass : public FusePassBase {
class GpuCpuSqueeze2MatmulFusePass : public FusePassBase {
public:
Squeeze2MatmulFusePass();
virtual ~Squeeze2MatmulFusePass() {}
GpuCpuSqueeze2MatmulFusePass();
virtual ~GpuCpuSqueeze2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
......@@ -119,19 +119,19 @@ class Squeeze2MatmulFusePass : public FusePassBase {
* the above passes to reduce the impact on other models.
*/
class Reshape2MatmulFusePass : public FusePassBase {
class GpuCpuReshape2MatmulFusePass : public FusePassBase {
public:
Reshape2MatmulFusePass();
virtual ~Reshape2MatmulFusePass() {}
GpuCpuReshape2MatmulFusePass();
virtual ~GpuCpuReshape2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
class Flatten2MatmulFusePass : public FusePassBase {
class GpuCpuFlatten2MatmulFusePass : public FusePassBase {
public:
Flatten2MatmulFusePass();
virtual ~Flatten2MatmulFusePass() {}
GpuCpuFlatten2MatmulFusePass();
virtual ~GpuCpuFlatten2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
......
此差异已折叠。
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
class TrtMapMatmul2MulPass : public FusePassBase {
public:
TrtMapMatmul2MulPass();
virtual ~TrtMapMatmul2MulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Map matmul_v2 to mul, the same as TrtMapMatmul2MulPass.
*/
class TrtMapMatmulV2ToMulPass : public FusePassBase {
public:
TrtMapMatmulV2ToMulPass();
virtual ~TrtMapMatmulV2ToMulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Map matmul_v2 to matmul, not supoort broadcast.
*/
class TrtMapMatmulV2ToMatmulPass : public FusePassBase {
public:
TrtMapMatmulV2ToMatmulPass();
virtual ~TrtMapMatmulV2ToMatmulPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Fuse squeeze2+matmul to mul, so the optimization can use fc_fuse_pass.
* The squeeze2 op must satisfy the following conditions:
* 1. the rank of input X is 4
* 2. the axis attr is [2, 3]
* 3. the next op is only matmul
*
* The matmul op must satisfy the following conditions:
* 1. the transpose_X and transpose_Y attrs are false
* 2. the alpha attr is 1.0
* 3. the rank of input X and Y is 2
* 4. the next op of matmul is only elementwise_add
*
* Notice:
* the rank of input activation is obtained from var_desc,
* it maybe change in runtime. Therefore, the pass considers
* the above passes to reduce the impact on other models.
*/
class TrtSqueeze2MatmulFusePass : public FusePassBase {
public:
TrtSqueeze2MatmulFusePass();
virtual ~TrtSqueeze2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
* Fuse reshape2+matmul to mul, so the optimization can use fc_fuse_pass.
* The reshape2 op must satisfy the following conditions:
* 1. reshape2 has one input node, which means it don't
* have Shape or ShapeTensor input
* 2. the rank of input X is 4 and the last two dims of input X is 1
* 3. the rank of shape attr is 2
* 4. the next op is only matmul
*
* The matmul op must satisfy the following conditions:
* 1. the transpose_X and transpose_Y attrs are false
* 2. the alpha attr is 1.0
* 3. the rank of input X and Y is 2
* 4. the next op of matmul is only elementwise_add
*
* Notice:
* the shape and rank of input activation is obtained from var_desc,
* they maybe change in runtime. Therefore, the pass considers
* the above passes to reduce the impact on other models.
*/
class TrtReshape2MatmulFusePass : public FusePassBase {
public:
TrtReshape2MatmulFusePass();
virtual ~TrtReshape2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
class TrtFlatten2MatmulFusePass : public FusePassBase {
public:
TrtFlatten2MatmulFusePass();
virtual ~TrtFlatten2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -90,12 +90,12 @@ const std::vector<std::string> kTRTSubgraphPasses({
"skip_layernorm_fuse_pass", //
"conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", //
"squeeze2_matmul_fuse_pass", //
"reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", //
"map_matmul_v2_to_mul_pass", //
"map_matmul_v2_to_matmul_pass", //
"map_matmul_to_mul_pass", //
"trt_squeeze2_matmul_fuse_pass", //
"trt_reshape2_matmul_fuse_pass", //
"trt_flatten2_matmul_fuse_pass", //
"trt_map_matmul_v2_to_mul_pass", //
"trt_map_matmul_v2_to_matmul_pass", //
"trt_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
"add_support_int8_pass",
......@@ -140,12 +140,12 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_eltwiseadd_bn_fuse_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", //
"squeeze2_matmul_fuse_pass", //
"reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", //
"map_matmul_v2_to_mul_pass", //
"map_matmul_v2_to_matmul_pass", //
"map_matmul_to_mul_pass", //
"gpu_cpu_squeeze2_matmul_fuse_pass", //
"gpu_cpu_reshape2_matmul_fuse_pass", //
"gpu_cpu_flatten2_matmul_fuse_pass", //
"gpu_cpu_map_matmul_v2_to_mul_pass", //
"gpu_cpu_map_matmul_v2_to_matmul_pass", //
"gpu_cpu_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"fc_elementwise_layernorm_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
......@@ -202,14 +202,14 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"fc_gru_fuse_pass", //
"mul_gru_fuse_pass", //
"seq_concat_fc_fuse_pass", //
"squeeze2_matmul_fuse_pass", //
"reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", //
"gpu_cpu_squeeze2_matmul_fuse_pass", //
"gpu_cpu_reshape2_matmul_fuse_pass", //
"gpu_cpu_flatten2_matmul_fuse_pass", //
"matmul_v2_scale_fuse_pass", //
"map_matmul_v2_to_mul_pass", //
"map_matmul_v2_to_matmul_pass", //
"gpu_cpu_map_matmul_v2_to_mul_pass", //
"gpu_cpu_map_matmul_v2_to_matmul_pass", //
"matmul_scale_fuse_pass", //
"map_matmul_to_mul_pass", //
"gpu_cpu_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"repeated_fc_relu_fuse_pass", //
"squared_mat_sub_fuse_pass", //
......
......@@ -67,13 +67,7 @@ class FcOpConverter : public OpConverter {
nvinfer1::Dims x_dim, int x_num_col_dims) {
// add shuffle after fc
nvinfer1::Dims reshape_after_fc_dim;
if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 1) {
// If use tensorrt'oss, the x_dim and x_num_col_dims need change
reshape_after_fc_dim.nbDims = 4;
} else {
reshape_after_fc_dim.nbDims = x_num_col_dims + 1;
}
for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) {
reshape_after_fc_dim.d[i] = 0;
}
......@@ -141,7 +135,6 @@ class FcOpConverter : public OpConverter {
"The fc's weight should be a matrix with 2 dims, but "
"it's %d-dimensional.",
Y_t->dims().size())); // a matrix
size_t n_output = Y_t->dims()[1];
int m = Y_t->dims()[0];
int n = Y_t->dims()[1];
auto tranpose_weight = [](const float* src, float* dst, int m, int n) {
......@@ -175,9 +168,10 @@ class FcOpConverter : public OpConverter {
fc_layer_int8->getOutput(0), x_dim, x_num_col_dims);
if (activation_type == "relu") {
fc_after_reshape_int8->setName(
("fc_op_int8_reshape_after_fc: Shuffle (Output: " + output_name +
")")
("int8_reshape_after_fc: Shuffle (Output: " + output_name + ")")
.c_str());
engine_->SetTensorDynamicRange(fc_after_reshape_int8->getOutput(0),
out_scale);
nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *(fc_after_reshape_int8->getOutput(0)),
nvinfer1::ActivationType::kRELU);
......@@ -200,8 +194,7 @@ class FcOpConverter : public OpConverter {
fc_layer_float->getOutput(0), x_dim, x_num_col_dims);
if (activation_type == "relu") {
fc_after_reshape_float->setName(
("fc_op_float_reshape_after_fc: Shuffle (Output: " + output_name +
")")
("float_reshape_after_fc: Shuffle (Output: " + output_name + ")")
.c_str());
nvinfer1::IActivationLayer* relu_layer_float = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *(fc_after_reshape_float->getOutput(0)),
......@@ -215,14 +208,28 @@ class FcOpConverter : public OpConverter {
}
};
bool transpose_y = false;
if (op_desc.HasAttr("transpose_Y")) {
transpose_y = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_Y"));
}
int weight_w, weight_h;
if (!transpose_y) {
std::vector<float> weight_data_tmp;
weight_data_tmp.reserve(Y_t->numel());
memcpy(weight_data_tmp.data(), weight_data, Y_t->numel() * sizeof(float));
tranpose_weight(weight_data_tmp.data(), weight_data, m, n);
weight_w = n;
weight_h = m;
} else {
weight_w = m;
weight_h = n;
}
size_t n_output = weight_w;
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<size_t>(Y_t->numel())};
weight.dims.assign({n, m});
weight.dims.assign({weight_w, weight_h});
float* bias_data = nullptr;
int bias_num = 0;
if (with_bias) {
......@@ -240,11 +247,57 @@ class FcOpConverter : public OpConverter {
if (!engine_->with_dynamic_shape()) {
x_num_col_dims--;
}
// If use tensorrt'oss, the x_dim and x_num_col_dims need change
// If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can
// not add Shuffle layer in ernie's multihead.
if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 2) {
x_num_col_dims = 1;
x_dim.d[3] == 1 && x_num_col_dims == 2) {
if (enable_int8) {
// add conv1x1 layer
nvinfer1::DimsHW nv_ksize(1, 1);
auto* fc_layer_int8 =
TRT_ENGINE_ADD_LAYER(engine_, Convolution, *X, n_output, nv_ksize,
weight.get(), bias.get());
if (activation_type == "relu") {
fc_layer_int8->setName(
("ernie_fc_op_int8: Convolution (Output: " + output_name + ")")
.c_str());
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("out_threshold"), true,
platform::errors::InvalidArgument(
"must have out threshold in fc layers in int8 mode"));
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0),
out_scale);
nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *(fc_layer_int8->getOutput(0)),
nvinfer1::ActivationType::kRELU);
RreplenishLayerAndOutput(relu_layer_int8, "relu_after_ernie_fc_int8",
{output_name}, test_mode);
} else {
RreplenishLayerAndOutput(fc_layer_int8,
"ernie_fc_op_int8: Convolution",
{output_name}, test_mode);
}
} else {
// add fc layer
auto* fc_layer_float = TRT_ENGINE_ADD_LAYER(
engine_, FullyConnected, *X, n_output, weight.get(), bias.get());
if (activation_type == "relu") {
fc_layer_float->setName(
("ernie_fc_op_float: (Output: " + output_name + ")").c_str());
nvinfer1::IActivationLayer* relu_layer_float = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *(fc_layer_float->getOutput(0)),
nvinfer1::ActivationType::kRELU);
RreplenishLayerAndOutput(relu_layer_float,
"relu_after_ernie_fc_float", {output_name},
test_mode);
} else {
RreplenishLayerAndOutput(fc_layer_float, "ernie_fc_op_float",
{output_name}, test_mode);
}
}
} else { // need reshape input before and after fc
PADDLE_ENFORCE_GT(
x_dim.nbDims, x_num_col_dims,
platform::errors::InvalidArgument(
......@@ -260,6 +313,7 @@ class FcOpConverter : public OpConverter {
}
regist_fc(reshape_itensor, n_output, weight, bias);
}
}
};
} // namespace tensorrt
......
......@@ -410,16 +410,16 @@ class Quant2Int8MkldnnPass(object):
graph = self._apply_pass(graph, 'multi_gru_fuse_pass')
graph = self._apply_pass(graph, 'multi_gru_seq_fuse_pass')
graph = self._apply_pass(graph, 'seq_concat_fc_fuse_pass')
graph = self._apply_pass(graph, 'squeeze2_matmul_fuse_pass')
graph = self._apply_pass(graph, 'reshape2_matmul_fuse_pass')
graph = self._apply_pass(graph, 'flatten2_matmul_fuse_pass')
graph = self._apply_pass(graph, 'gpu_cpu_squeeze2_matmul_fuse_pass')
graph = self._apply_pass(graph, 'gpu_cpu_reshape2_matmul_fuse_pass')
graph = self._apply_pass(graph, 'gpu_cpu_flatten2_matmul_fuse_pass')
graph = self._apply_pass(graph, 'matmul_v2_scale_fuse_pass')
graph = self._apply_pass(graph, 'squared_mat_sub_fuse_pass')
graph = self._apply_pass(graph, 'is_test_pass')
graph = self._apply_pass(graph, 'map_matmul_v2_to_mul_pass')
graph = self._apply_pass(graph, 'map_matmul_v2_to_matmul_pass')
graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_v2_to_mul_pass')
graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_v2_to_matmul_pass')
graph = self._apply_pass(graph, 'matmul_scale_fuse_pass')
graph = self._apply_pass(graph, 'map_matmul_to_mul_pass')
graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_to_mul_pass')
graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
graph = self._apply_pass(graph, 'mkldnn_placement_pass',
['mkldnn_enabled_op_types'], [set()])
......
......@@ -174,7 +174,7 @@ class TestFlatten2MatmulFusePass(PassAutoScanTest):
quant=False,
max_examples=50,
max_duration=1000,
passes=["flatten2_matmul_fuse_pass"])
passes=["gpu_cpu_flatten2_matmul_fuse_pass"])
if __name__ == "__main__":
......
......@@ -116,7 +116,7 @@ class TestMapMatmulToMulPass(PassAutoScanTest):
self.run_and_statis(
quant=False,
max_examples=100,
passes=["map_matmul_to_mul_pass"],
passes=["gpu_cpu_map_matmul_to_mul_pass"],
max_duration=180)
......
......@@ -127,7 +127,7 @@ class TestMapMatmulToMulPass(PassAutoScanTest):
self.run_and_statis(
quant=False,
max_examples=100,
passes=["map_matmul_v2_to_matmul_pass"])
passes=["gpu_cpu_map_matmul_v2_to_matmul_pass"])
if __name__ == "__main__":
......
......@@ -110,8 +110,9 @@ class TestMapMatmulToMulPass(PassAutoScanTest):
def test(self):
self.run_and_statis(
quant=False, max_examples=100,
passes=["map_matmul_v2_to_mul_pass"])
quant=False,
max_examples=100,
passes=["gpu_cpu_map_matmul_v2_to_mul_pass"])
if __name__ == "__main__":
......
......@@ -132,7 +132,7 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest):
return program_config
def sample_predictor_configs(self, program_config):
# map_matmul_v2_to_matmul_pass will affect the type of final fused op
# gpu_cpu_map_matmul_v2_to_matmul_pass will affect the type of final fused op
fused_op = "matmul_v2"
input1_dim1 = program_config.inputs["input_data1"].shape[0]
input2_dim1 = program_config.inputs["input_data2"].shape[0]
......
......@@ -172,7 +172,7 @@ class TestReshape2MatmulFusePass(PassAutoScanTest):
quant=False,
max_examples=50,
max_duration=1000,
passes=["reshape2_matmul_fuse_pass"])
passes=["gpu_cpu_reshape2_matmul_fuse_pass"])
if __name__ == "__main__":
......
......@@ -180,7 +180,7 @@ class TestSqueeze2MatmulFusePass(PassAutoScanTest):
quant=False,
max_examples=50,
max_duration=1000,
passes=["squeeze2_matmul_fuse_pass"])
passes=["gpu_cpu_squeeze2_matmul_fuse_pass"])
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册