Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
20b38cfa
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
20b38cfa
编写于
6月 09, 2022
作者:
M
minghaoBD
提交者:
GitHub
6月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[sparse inference] Supporting 2:4 sparse inference (#43179)
上级
1a80b484
变更
32
展开全部
显示空白变更内容
内联
并排
Showing
32 changed file
with
3212 addition
and
91 deletion
+3212
-91
CMakeLists.txt
CMakeLists.txt
+1
-0
cmake/external/cusparselt.cmake
cmake/external/cusparselt.cmake
+61
-0
cmake/inference_lib.cmake
cmake/inference_lib.cmake
+8
-0
cmake/third_party.cmake
cmake/third_party.cmake
+5
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+10
-0
paddle/fluid/framework/ir/dense_fc_to_sparse_pass.cc
paddle/fluid/framework/ir/dense_fc_to_sparse_pass.cc
+148
-0
paddle/fluid/framework/ir/dense_fc_to_sparse_pass.h
paddle/fluid/framework/ir/dense_fc_to_sparse_pass.h
+61
-0
paddle/fluid/framework/ir/dense_fc_to_sparse_pass_tester.cc
paddle/fluid/framework/ir/dense_fc_to_sparse_pass_tester.cc
+107
-0
paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.cc
...uid/framework/ir/dense_multihead_matmul_to_sparse_pass.cc
+174
-0
paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.h
...luid/framework/ir/dense_multihead_matmul_to_sparse_pass.h
+61
-0
paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass_tester.cc
...mework/ir/dense_multihead_matmul_to_sparse_pass_tester.cc
+152
-0
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+4
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+4
-2
paddle/fluid/inference/tensorrt/CMakeLists.txt
paddle/fluid/inference/tensorrt/CMakeLists.txt
+4
-3
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+68
-59
paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc
paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc
+371
-0
paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc
.../inference/tensorrt/convert/sparse_multihead_matmul_op.cc
+441
-0
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+16
-0
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+36
-27
paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu
paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu
+923
-0
paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h
paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h
+158
-0
paddle/fluid/inference/tensorrt/test_dynamic_engine.cc
paddle/fluid/inference/tensorrt/test_dynamic_engine.cc
+175
-0
paddle/fluid/platform/dynload/CMakeLists.txt
paddle/fluid/platform/dynload/CMakeLists.txt
+4
-0
paddle/fluid/platform/dynload/cusparseLt.cc
paddle/fluid/platform/dynload/cusparseLt.cc
+29
-0
paddle/fluid/platform/dynload/cusparseLt.h
paddle/fluid/platform/dynload/cusparseLt.h
+60
-0
paddle/fluid/platform/dynload/dynamic_loader.cc
paddle/fluid/platform/dynload/dynamic_loader.cc
+4
-0
paddle/fluid/platform/dynload/dynamic_loader.h
paddle/fluid/platform/dynload/dynamic_loader.h
+1
-0
paddle/phi/backends/dynload/CMakeLists.txt
paddle/phi/backends/dynload/CMakeLists.txt
+4
-0
paddle/phi/backends/dynload/cusparseLt.cc
paddle/phi/backends/dynload/cusparseLt.cc
+28
-0
paddle/phi/backends/dynload/cusparseLt.h
paddle/phi/backends/dynload/cusparseLt.h
+78
-0
paddle/phi/backends/dynload/dynamic_loader.cc
paddle/phi/backends/dynload/dynamic_loader.cc
+15
-0
paddle/phi/backends/dynload/dynamic_loader.h
paddle/phi/backends/dynload/dynamic_loader.h
+1
-0
未找到文件。
CMakeLists.txt
浏览文件 @
20b38cfa
...
...
@@ -60,6 +60,7 @@ option(WITH_IPU "Compile PaddlePaddle with Graphcore IPU" OFF)
option
(
WITH_ASCEND_CL
"Compile PaddlePaddle with ASCEND CL"
${
WITH_ASCEND
}
)
option
(
WITH_ASCEND_CXX11
"Compile PaddlePaddle with ASCEND and CXX11 ABI"
OFF
)
option
(
WITH_ONNXRUNTIME
"Compile PaddlePaddle with ONNXRUNTIME"
OFF
)
option
(
WITH_CUSPARSELT
"Compile PaddlePaddle with CUSPARSELT"
OFF
)
# Note(zhouwei): It use option above, so put here
include
(
init
)
include
(
generic
)
# simplify cmake module
...
...
cmake/external/cusparselt.cmake
0 → 100644
浏览文件 @
20b38cfa
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if
(
NOT
(
WITH_CUSPARSELT AND WITH_TENSORRT
))
return
()
endif
()
if
(
WITH_ARM OR WIN32
)
message
(
SEND_ERROR
"The current sparselt support linux only"
)
return
()
endif
()
include
(
ExternalProject
)
set
(
CUSPARSELT_PROJECT
"extern_cusparselt"
)
set
(
CUSPARSELT_P
"https://developer.download.nvidia.com/compute"
)
set
(
CUSPARSELT_F
"libcusparse_lt-linux-x86_64-0.2.0.1.tar.gz"
)
set
(
CUSPARSELT_URL
"
${
CUSPARSELT_P
}
/libcusparse-lt/0.2.0/local_installers/
${
CUSPARSELT_F
}
"
CACHE STRING
""
FORCE
)
set
(
CUSPARSELT_PREFIX_DIR
${
THIRD_PARTY_PATH
}
/cusparselt
)
set
(
CUSPARSELT_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/cusparselt
)
set
(
CUSPARSELT_INC_DIR
"
${
CUSPARSELT_INSTALL_DIR
}
/include"
CACHE PATH
"sparselt include directory."
FORCE
)
set
(
CUSPARSELT_LIB_DIR
"
${
CUSPARSELT_INSTALL_DIR
}
/lib64"
CACHE PATH
"sparselt lib directory."
FORCE
)
set_directory_properties
(
PROPERTIES CLEAN_NO_CUSTOM 1
)
include_directories
(
${
CUSPARSELT_INC_DIR
}
)
ExternalProject_Add
(
${
CUSPARSELT_PROJECT
}
${
EXTERNAL_PROJECT_LOG_ARGS
}
URL
${
CUSPARSELT_URL
}
PREFIX
${
CUSPARSELT_PREFIX_DIR
}
DOWNLOAD_NO_PROGRESS 1
CONFIGURE_COMMAND
""
BUILD_COMMAND
""
INSTALL_COMMAND
${
CMAKE_COMMAND
}
-E copy_directory
${
CUSPARSELT_PREFIX_DIR
}
/src/extern_cusparselt/lib64
${
CUSPARSELT_LIB_DIR
}
&&
${
CMAKE_COMMAND
}
-E copy_directory
${
CUSPARSELT_PREFIX_DIR
}
/src/extern_cusparselt/include
${
CUSPARSELT_INC_DIR
}
UPDATE_COMMAND
""
)
add_library
(
cusparselt INTERFACE
)
add_dependencies
(
cusparselt
${
CUSPARSELT_PROJECT
}
)
set
(
CUSPARSELT_FOUND ON
)
add_definitions
(
-DPADDLE_WITH_CUSPARSELT
)
cmake/inference_lib.cmake
浏览文件 @
20b38cfa
...
...
@@ -108,6 +108,14 @@ function(copy_part_of_thrid_party TARGET DST)
SRCS
${
CBLAS_INSTALL_DIR
}
/lib
${
CBLAS_INSTALL_DIR
}
/include
DSTS
${
dst_dir
}
${
dst_dir
}
)
endif
()
if
(
WITH_SPARSELT
)
set
(
dst_dir
"
${
DST
}
/third_party/install/cusparselt"
)
copy
(
${
TARGET
}
SRCS
${
CUSPARSELT_INC_DIR
}
${
CUSPARSELT_LIB_DIR
}
DSTS
${
dst_dir
}
${
dst_dir
}
)
endif
()
endif
()
if
(
WITH_MKLDNN
)
...
...
cmake/third_party.cmake
浏览文件 @
20b38cfa
...
...
@@ -496,4 +496,9 @@ if(WITH_IPU)
list
(
APPEND third_party_deps extern_poplar
)
endif
()
if
(
WITH_CUSPARSELT
)
include
(
external/cusparselt
)
# download, build, install cusparselt
list
(
APPEND third_party_deps extern_cusparselt
)
endif
()
add_custom_target
(
third_party ALL DEPENDS
${
third_party_deps
}
)
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
20b38cfa
...
...
@@ -156,6 +156,8 @@ pass_library(add_support_int8_pass inference)
pass_library
(
matmul_scale_fuse_pass inference
)
pass_library
(
gpu_cpu_map_matmul_to_mul_pass inference
)
pass_library
(
mixed_precision_configure_pass inference
)
pass_library
(
dense_fc_to_sparse_pass inference
)
pass_library
(
dense_multihead_matmul_to_sparse_pass inference
)
pass_library
(
generate_pass DEPS pass_desc_proto
)
target_link_libraries
(
generate_pass pass_desc_proto
)
...
...
@@ -379,6 +381,14 @@ if(NOT WIN32)
test_sync_batch_norm_pass
SRCS sync_batch_norm_pass_tester.cc
DEPS sync_batch_norm_pass
)
cc_test
(
test_dense_fc_to_sparse_pass_cc
SRCS dense_fc_to_sparse_pass_tester.cc
DEPS fc_fuse_pass dense_fc_to_sparse_pass framework_proto
)
cc_test
(
test_dense_multihead_matmul_to_sparse_pass
SRCS dense_multihead_matmul_to_sparse_pass_tester.cc
DEPS multihead_matmul_fuse_pass dense_multihead_matmul_to_sparse_pass
)
endif
()
if
(
WITH_MKLDNN
)
cc_test
(
...
...
paddle/fluid/framework/ir/dense_fc_to_sparse_pass.cc
0 → 100644
浏览文件 @
20b38cfa
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/dense_fc_to_sparse_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
PDNode
*
patterns
::
DenseFC
::
operator
()()
{
auto
*
fc
=
pattern
->
NewNode
(
fc_repr
())
->
assert_is_op
(
"fc"
);
// Input
auto
*
fc_input
=
pattern
->
NewNode
(
fc_input_repr
())
->
AsInput
()
->
assert_is_op_input
(
"fc"
,
"Input"
);
// Filter
auto
*
fc_weights
=
pattern
->
NewNode
(
fc_weights_repr
())
->
AsInput
()
->
assert_is_op_input
(
"fc"
,
"W"
);
// Bias
auto
*
fc_bias
=
pattern
->
NewNode
(
fc_bias_repr
())
->
AsInput
()
->
assert_is_op_input
(
"fc"
,
"Bias"
);
// Output
auto
*
fc_out
=
pattern
->
NewNode
(
fc_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"fc"
,
"Out"
)
->
assert_is_only_output_of_op
(
"fc"
);
fc
->
LinksFrom
({
fc_input
,
fc_weights
,
fc_bias
}).
LinksTo
({
fc_out
});
return
fc_out
;
}
}
// namespace patterns
DenseFCToSparsePass
::
DenseFCToSparsePass
()
{
AddOpCompat
(
OpCompat
(
"fc"
))
.
AddInput
(
"Input"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"W"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
();
}
void
DenseFCToSparsePass
::
ApplyImpl
(
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
std
::
string
name_scope
=
"dense_fc_to_sparse_pass"
;
FusePassBase
::
Init
(
name_scope
,
graph
);
GraphPatternDetector
gpd
;
patterns
::
DenseFC
dense_fc_pattern
(
gpd
.
mutable_pattern
(),
"dense_fc_replace_pass"
);
dense_fc_pattern
();
int
found_dense_fc_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"Replace dense fc with sparse_fc."
;
/* if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}*/
GET_IR_NODE_FROM_SUBGRAPH
(
fc_out
,
fc_out
,
dense_fc_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc
,
fc
,
dense_fc_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc_input
,
fc_input
,
dense_fc_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc_weights
,
fc_weights
,
dense_fc_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc_bias
,
fc_bias
,
dense_fc_pattern
);
auto
*
fc_op
=
fc
->
Op
();
auto
w_name
=
fc_op
->
Input
(
"W"
)[
0
];
// recognize sparse op by name
if
(
w_name
.
find
(
"sparse_2_4"
)
!=
w_name
.
npos
)
{
// fake op
OpDesc
desc
(
fc_op
->
Block
());
desc
.
SetType
(
"sparse_fc"
);
desc
.
SetInput
(
"Input"
,
{
fc_input
->
Name
()});
desc
.
SetInput
(
"W"
,
{
fc_weights
->
Name
()});
desc
.
SetInput
(
"Bias"
,
{
fc_bias
->
Name
()});
desc
.
SetOutput
(
"Out"
,
{
fc_out
->
Name
()});
// copy all attr
if
(
fc_op
->
HasAttr
(
"x_num_col_dims"
))
{
desc
.
SetAttr
(
"x_num_col_dims"
,
fc_op
->
GetAttr
(
"x_num_col_dims"
));
}
if
(
fc_op
->
HasAttr
(
"in_num_col_dims"
))
{
desc
.
SetAttr
(
"in_num_col_dims"
,
fc_op
->
GetAttr
(
"in_num_col_dims"
));
}
desc
.
SetAttr
(
"activation_type"
,
fc_op
->
GetAttr
(
"activation_type"
));
if
(
fc_op
->
HasAttr
(
"enable_int8"
))
{
desc
.
SetAttr
(
"enable_int8"
,
fc_op
->
GetAttr
(
"enable_int8"
));
}
if
(
fc_op
->
HasAttr
(
"Input_scale"
))
{
desc
.
SetAttr
(
"Input_scale"
,
fc_op
->
GetAttr
(
"Input_scale"
));
}
if
(
fc_op
->
HasAttr
(
"support_int8"
))
{
desc
.
SetAttr
(
"support_int8"
,
fc_op
->
GetAttr
(
"support_int8"
));
}
if
(
fc_op
->
HasAttr
(
"out_threshold"
))
{
desc
.
SetAttr
(
"out_threshold"
,
fc_op
->
GetAttr
(
"out_threshold"
));
}
desc
.
Flush
();
GraphSafeRemoveNodes
(
g
,
{
fc
});
auto
sparse_fc_node
=
g
->
CreateOpNode
(
&
desc
);
IR_NODE_LINK_TO
(
fc_input
,
sparse_fc_node
);
IR_NODE_LINK_TO
(
fc_weights
,
sparse_fc_node
);
IR_NODE_LINK_TO
(
fc_bias
,
sparse_fc_node
);
IR_NODE_LINK_TO
(
sparse_fc_node
,
fc_out
);
found_dense_fc_count
++
;
}
};
gpd
(
graph
,
handler
);
AddStatis
(
found_dense_fc_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
dense_fc_to_sparse_pass
,
paddle
::
framework
::
ir
::
DenseFCToSparsePass
);
paddle/fluid/framework/ir/dense_fc_to_sparse_pass.h
0 → 100644
浏览文件 @
20b38cfa
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
struct
DenseFC
:
public
PatternBase
{
DenseFC
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"dense_fc"
)
{}
PDNode
*
operator
()();
// declare operator node's name
PATTERN_DECL_NODE
(
fc
);
PATTERN_DECL_NODE
(
fc_out
);
PATTERN_DECL_NODE
(
fc_input
);
PATTERN_DECL_NODE
(
fc_weights
);
PATTERN_DECL_NODE
(
fc_bias
);
};
}
// namespace patterns
/**
* Replace dense op with sparse op
*/
class
Graph
;
class
DenseFCToSparsePass
:
public
FusePassBase
{
public:
DenseFCToSparsePass
();
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
const
std
::
string
name_scope_
{
"dense_fc_to_sparse_pass"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/dense_fc_to_sparse_pass_tester.cc
0 → 100644
浏览文件 @
20b38cfa
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/dense_fc_to_sparse_pass.h"
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
AddVarToScope
(
Scope
*
param_scope
,
const
std
::
string
&
name
,
const
DDim
&
dims
)
{
auto
*
tensor
=
param_scope
->
Var
(
name
)
->
GetMutable
<
LoDTensor
>
();
tensor
->
Resize
(
dims
);
tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
}
Scope
*
CreateParamScope
()
{
auto
param_scope
=
new
Scope
();
AddVarToScope
(
param_scope
,
"conv2d_filters_0"
,
{});
AddVarToScope
(
param_scope
,
"conv2d_bias_0"
,
{});
AddVarToScope
(
param_scope
,
"weights_0_sparse_2_4"
,
{});
AddVarToScope
(
param_scope
,
"weights_1"
,
{});
AddVarToScope
(
param_scope
,
"bias_1"
,
{});
AddVarToScope
(
param_scope
,
"bias_2"
,
{});
return
param_scope
;
}
TEST
(
FCFusePass
,
basic
)
{
// inputs operator output
// --------------------------------------------------------
// (a, filters_0 bias_0) conv2d -> conv2d_out
// conv2d_out relu -> relu_out_0
// (relu_out_0, weights_0_sparse_2_4) mul -> mul_out_0
// (mul_out_0, bias_1) elementwise_add -> add_out_0
// add_out_0 relu -> relu_out_1
// (relu_out_1, weights_1) mul -> mul_out_1
// (mul_out_1, bias_2) elementwise_add -> add_out_1
Layers
layers
;
auto
*
a
=
layers
.
data
(
"a"
);
auto
*
filters_0
=
layers
.
data
(
"conv2d_filters_0"
,
{},
true
);
auto
*
bias_0
=
layers
.
data
(
"conv2d_bias_0"
,
{},
true
);
auto
*
conv2d_out
=
layers
.
conv2d
(
a
,
filters_0
,
bias_0
,
false
);
auto
*
relu_out_0
=
layers
.
relu
(
conv2d_out
);
auto
*
weights_0
=
layers
.
data
(
"weights_0_sparse_2_4"
,
{
5
,
4
},
true
);
auto
*
mul_out_0
=
layers
.
mul
(
relu_out_0
,
weights_0
);
auto
*
bias_1
=
layers
.
data
(
"bias_1"
,
{
4
},
true
);
auto
*
add_out_0
=
layers
.
elementwise_add
(
mul_out_0
,
bias_1
,
nullptr
,
1
);
auto
*
relu_out_1
=
layers
.
relu
(
add_out_0
);
auto
*
weights_1
=
layers
.
data
(
"weights_1"
,
{
8
,
9
},
true
);
auto
*
mul_out_1
=
layers
.
mul
(
relu_out_1
,
weights_1
);
auto
*
bias_2
=
layers
.
data
(
"bias_2"
,
{
1
,
9
},
true
);
auto
*
add_out_1
=
layers
.
elementwise_add
(
mul_out_1
,
bias_2
,
nullptr
,
1
);
VLOG
(
4
)
<<
add_out_1
;
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
auto
fuse_pass
=
PassRegistry
::
Instance
().
Get
(
"fc_fuse_pass"
);
auto
sparse_pass
=
PassRegistry
::
Instance
().
Get
(
"dense_fc_to_sparse_pass"
);
fuse_pass
->
Set
(
"use_gpu"
,
new
bool
(
true
));
sparse_pass
->
Set
(
"use_gpu"
,
new
bool
(
true
));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
int
num_nodes_before
=
graph
->
Nodes
().
size
();
int
num_mul_nodes_before
=
GetNumOpNodes
(
graph
,
"mul"
);
VLOG
(
3
)
<<
DebugString
(
graph
);
graph
.
reset
(
fuse_pass
->
Apply
(
graph
.
release
()));
graph
.
reset
(
sparse_pass
->
Apply
(
graph
.
release
()));
int
num_nodes_after
=
graph
->
Nodes
().
size
();
int
num_fc_nodes_after
=
GetNumOpNodes
(
graph
,
"fc"
);
int
num_sparse_fc_nodes_after
=
GetNumOpNodes
(
graph
,
"sparse_fc"
);
VLOG
(
3
)
<<
DebugString
(
graph
);
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
6
,
platform
::
errors
::
InvalidArgument
(
"num_nodes_before=%d, num_nodes_after=%d."
,
num_nodes_before
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_fc_nodes_after
,
1
,
platform
::
errors
::
InvalidArgument
(
"num_fc_nodes_after=%d."
,
num_fc_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_mul_nodes_before
,
num_fc_nodes_after
+
num_sparse_fc_nodes_after
,
platform
::
errors
::
InvalidArgument
(
"num_mul_nodes_before=%d, num_fc_nodes_after=%d + "
"num_sparse_fc_nodes_after=%d."
,
num_mul_nodes_before
,
num_fc_nodes_after
,
num_sparse_fc_nodes_after
));
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
fc_fuse_pass
);
USE_PASS
(
dense_fc_to_sparse_pass
);
paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.cc
0 → 100644
浏览文件 @
20b38cfa
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
PDNode
*
patterns
::
DenseMultiheadMatmul
::
operator
()()
{
auto
*
multihead_matmul
=
pattern
->
NewNode
(
multihead_matmul_repr
())
->
assert_is_op
(
"multihead_matmul"
);
// Input
auto
*
multihead_matmul_input
=
pattern
->
NewNode
(
multihead_matmul_input_repr
())
->
AsInput
()
->
assert_is_op_input
(
"multihead_matmul"
,
"Input"
);
// Filter
auto
*
multihead_matmul_weights
=
pattern
->
NewNode
(
multihead_matmul_weights_repr
())
->
AsInput
()
->
assert_is_op_input
(
"multihead_matmul"
,
"W"
);
// Bias
auto
*
multihead_matmul_bias
=
pattern
->
NewNode
(
multihead_matmul_bias_repr
())
->
AsInput
()
->
assert_is_op_input
(
"multihead_matmul"
,
"Bias"
);
// BiasQK
auto
*
multihead_matmul_biasqk
=
pattern
->
NewNode
(
multihead_matmul_biasqk_repr
())
->
AsInput
()
->
assert_is_op_input
(
"multihead_matmul"
,
"BiasQK"
);
// Output
auto
*
multihead_matmul_out
=
pattern
->
NewNode
(
multihead_matmul_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"multihead_matmul"
,
"Out"
)
->
assert_is_only_output_of_op
(
"multihead_matmul"
);
multihead_matmul
->
LinksFrom
({
multihead_matmul_input
,
multihead_matmul_weights
,
multihead_matmul_bias
,
multihead_matmul_biasqk
})
.
LinksTo
({
multihead_matmul_out
});
return
multihead_matmul_out
;
}
}
// namespace patterns
DenseMultiheadMatmulToSparsePass
::
DenseMultiheadMatmulToSparsePass
()
{
AddOpCompat
(
OpCompat
(
"multihead_matmul"
))
.
AddInput
(
"Input"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"W"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"BiasQK"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
();
}
void
DenseMultiheadMatmulToSparsePass
::
ApplyImpl
(
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
std
::
string
name_scope
=
"dense_multihead_matmul_to_sparse_pass"
;
FusePassBase
::
Init
(
name_scope
,
graph
);
GraphPatternDetector
gpd
;
patterns
::
DenseMultiheadMatmul
multihead_matmul_pattern
(
gpd
.
mutable_pattern
(),
"dense_multihead_matmul_replace_pass"
);
multihead_matmul_pattern
();
int
found_multihead_matmul_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"Replace dense multihead matmul with sparse multihead matmul."
;
/* if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}*/
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_matmul_out
,
multihead_matmul_out
,
multihead_matmul_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_matmul
,
multihead_matmul
,
multihead_matmul_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_matmul_input
,
multihead_matmul_input
,
multihead_matmul_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_matmul_weights
,
multihead_matmul_weights
,
multihead_matmul_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_matmul_bias
,
multihead_matmul_bias
,
multihead_matmul_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_matmul_biasqk
,
multihead_matmul_biasqk
,
multihead_matmul_pattern
);
auto
*
multihead_matmul_op
=
multihead_matmul
->
Op
();
auto
w_name
=
multihead_matmul_op
->
Input
(
"W"
)[
0
];
// recognize sparse op by name
if
(
w_name
.
find
(
"sparse_2_4"
)
!=
w_name
.
npos
)
{
// fake op
OpDesc
desc
(
multihead_matmul_op
->
Block
());
desc
.
SetType
(
"sparse_multihead_matmul"
);
desc
.
SetInput
(
"Input"
,
{
multihead_matmul_input
->
Name
()});
desc
.
SetInput
(
"W"
,
{
multihead_matmul_weights
->
Name
()});
desc
.
SetInput
(
"Bias"
,
{
multihead_matmul_bias
->
Name
()});
desc
.
SetInput
(
"BiasQK"
,
{
multihead_matmul_biasqk
->
Name
()});
desc
.
SetOutput
(
"Out"
,
{
multihead_matmul_out
->
Name
()});
// copy all attr
desc
.
SetAttr
(
"alpha"
,
multihead_matmul_op
->
GetAttr
(
"alpha"
));
desc
.
SetAttr
(
"head_number"
,
multihead_matmul_op
->
GetAttr
(
"head_number"
));
if
(
multihead_matmul_op
->
HasAttr
(
"Input_scale"
))
{
desc
.
SetAttr
(
"Input_scale"
,
multihead_matmul_op
->
GetAttr
(
"Input_scale"
));
}
if
(
multihead_matmul_op
->
HasAttr
(
"fc_out_threshold"
))
{
desc
.
SetAttr
(
"fc_out_threshold"
,
multihead_matmul_op
->
GetAttr
(
"fc_out_threshold"
));
}
if
(
multihead_matmul_op
->
HasAttr
(
"qkv2context_plugin_int8"
))
{
desc
.
SetAttr
(
"qkv2context_plugin_int8"
,
multihead_matmul_op
->
GetAttr
(
"qkv2context_plugin_int8"
));
}
if
(
multihead_matmul_op
->
HasAttr
(
"dp_probs"
))
{
desc
.
SetAttr
(
"dp_probs"
,
multihead_matmul_op
->
GetAttr
(
"dp_probs"
));
}
if
(
multihead_matmul_op
->
HasAttr
(
"out_threshold"
))
{
desc
.
SetAttr
(
"out_threshold"
,
multihead_matmul_op
->
GetAttr
(
"out_threshold"
));
}
desc
.
Flush
();
GraphSafeRemoveNodes
(
g
,
{
multihead_matmul
});
auto
sparse_multihead_matmul_node
=
g
->
CreateOpNode
(
&
desc
);
IR_NODE_LINK_TO
(
multihead_matmul_input
,
sparse_multihead_matmul_node
);
IR_NODE_LINK_TO
(
multihead_matmul_weights
,
sparse_multihead_matmul_node
);
IR_NODE_LINK_TO
(
multihead_matmul_bias
,
sparse_multihead_matmul_node
);
IR_NODE_LINK_TO
(
multihead_matmul_biasqk
,
sparse_multihead_matmul_node
);
IR_NODE_LINK_TO
(
sparse_multihead_matmul_node
,
multihead_matmul_out
);
found_multihead_matmul_count
++
;
}
};
gpd
(
graph
,
handler
);
AddStatis
(
found_multihead_matmul_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
dense_multihead_matmul_to_sparse_pass
,
paddle
::
framework
::
ir
::
DenseMultiheadMatmulToSparsePass
);
paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.h
0 → 100644
浏览文件 @
20b38cfa
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
struct
DenseMultiheadMatmul
:
public
PatternBase
{
DenseMultiheadMatmul
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"dense_multihead_matmul"
)
{}
PDNode
*
operator
()();
// declare operator node's name
PATTERN_DECL_NODE
(
multihead_matmul
);
PATTERN_DECL_NODE
(
multihead_matmul_out
);
PATTERN_DECL_NODE
(
multihead_matmul_input
);
PATTERN_DECL_NODE
(
multihead_matmul_weights
);
PATTERN_DECL_NODE
(
multihead_matmul_bias
);
PATTERN_DECL_NODE
(
multihead_matmul_biasqk
);
};
}
// namespace patterns
/**
* Replace dense multihead_matmul op with sparse multihead_matmul op
*/
class
Graph
;
class
DenseMultiheadMatmulToSparsePass
:
public
FusePassBase
{
public:
DenseMultiheadMatmulToSparsePass
();
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
const
std
::
string
name_scope_
{
"dense_multihead_matmul_to_sparse_pass"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass_tester.cc
0 → 100644
浏览文件 @
20b38cfa
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.h" // NOLINT
#include "paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h" // NOLINT
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
AddVarToScope
(
Scope
*
param_scope
,
const
std
::
string
&
name
,
const
DDim
&
dims
)
{
auto
*
tensor
=
param_scope
->
Var
(
name
)
->
GetMutable
<
LoDTensor
>
();
tensor
->
Resize
(
dims
);
tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
}
Scope
*
CreateParamScope
()
{
auto
param_scope
=
new
Scope
();
AddVarToScope
(
param_scope
,
"weights0_sparse_2_4"
,
{
768
,
768
});
AddVarToScope
(
param_scope
,
"weights1_sparse_2_4"
,
{
768
,
768
});
AddVarToScope
(
param_scope
,
"weights2_sparse_2_4"
,
{
768
,
768
});
AddVarToScope
(
param_scope
,
"bias_0"
,
{
768
});
AddVarToScope
(
param_scope
,
"bias_1"
,
{
768
});
AddVarToScope
(
param_scope
,
"bias_2"
,
{
768
});
AddVarToScope
(
param_scope
,
"biasqk"
,
{
768
});
AddVarToScope
(
param_scope
,
"weightsl"
,
{
768
,
768
});
return
param_scope
;
}
TEST
(
DenseMultiHeadMatmulToSparsePass
,
basic
)
{
// inputs operator output
// --------------------------------------------------------------------
// (x) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0_sparse_2_4) mul -> mul_out0
// (layer_norm_out, weights_1_sparse_2_4) mul -> mul_out1
// (layer_norm_out, weights_2_sparse_2_4) mul -> mul_out2
// (mul_out0, bias_0) elementweise_add -> eltadd_0
// (mul_out1, bias_1) elementweise_add -> eltadd_1
// (mul_out2, bias_2) elementweise_add -> eltadd_2
// (eltadd_0) reshape2 -> reshape_0
// (eltadd_1) reshape2 -> reshape_1
// (eltadd_2) reshape2 -> reshape_2
// (reshape_0) transpose2 -> transpose_0
// (reshape_1) transpose2 -> transpose_1
// (reshape_2) transpose2 -> transpose_2
// (transpose_0) scale -> scale_0
// (scale_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) mul -> mul_qkv
Layers
layers
;
auto
*
x
=
layers
.
data
(
"x"
,
{
1
,
128
,
768
});
auto
out
=
layers
.
layer_norm
(
x
);
auto
*
layer_out
=
out
[
0
];
auto
*
weights_0
=
layers
.
data
(
"weights0_sparse_2_4"
,
{
768
,
768
},
true
);
auto
*
weights_1
=
layers
.
data
(
"weights1_sparse_2_4"
,
{
768
,
768
},
true
);
auto
*
weights_2
=
layers
.
data
(
"weights2_sparse_2_4"
,
{
768
,
768
},
true
);
auto
*
mul_out_0
=
layers
.
mul
(
layer_out
,
weights_0
,
nullptr
,
2
);
auto
*
mul_out_1
=
layers
.
mul
(
layer_out
,
weights_1
,
nullptr
,
2
);
auto
*
mul_out_2
=
layers
.
mul
(
layer_out
,
weights_2
,
nullptr
,
2
);
auto
*
b0
=
layers
.
data
(
"bias_0"
,
{
768
},
true
);
auto
*
b1
=
layers
.
data
(
"bias_1"
,
{
768
},
true
);
auto
*
b2
=
layers
.
data
(
"bias_2"
,
{
768
},
true
);
auto
*
elementwise_out_0
=
layers
.
elementwise_add
(
mul_out_0
,
b0
,
nullptr
,
2
);
auto
*
elementwise_out_1
=
layers
.
elementwise_add
(
mul_out_1
,
b1
,
nullptr
,
2
);
auto
*
elementwise_out_2
=
layers
.
elementwise_add
(
mul_out_2
,
b2
,
nullptr
,
2
);
std
::
vector
<
int
>
shape
=
{
1
,
128
,
12
,
64
};
auto
*
reshape_0
=
layers
.
reshape2
(
elementwise_out_0
,
shape
,
true
);
auto
*
reshape_1
=
layers
.
reshape2
(
elementwise_out_1
,
shape
,
true
);
auto
*
reshape_2
=
layers
.
reshape2
(
elementwise_out_2
,
shape
,
true
);
std
::
vector
<
int
>
axis
=
{
0
,
2
,
1
,
3
};
auto
*
transpose_0
=
layers
.
transpose2
(
reshape_0
,
axis
,
true
);
auto
*
transpose_1
=
layers
.
transpose2
(
reshape_1
,
axis
,
true
);
auto
*
transpose_2
=
layers
.
transpose2
(
reshape_2
,
axis
,
true
);
auto
*
scale_0
=
layers
.
scale
(
transpose_0
,
0.125
,
0
,
false
);
auto
*
matmul_qk
=
layers
.
matmul
(
scale_0
,
transpose_1
,
nullptr
,
false
,
true
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
12
,
128
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul_qk
,
bqk
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
auto
*
matmul_qkv
=
layers
.
matmul
(
softmax_qk
,
transpose_2
);
auto
*
transpose_qkv
=
layers
.
transpose2
(
matmul_qkv
,
{
0
,
2
,
1
,
3
},
true
);
auto
*
reshape_qkv_out
=
layers
.
reshape2
(
transpose_qkv
,
{
1
,
128
,
768
},
true
);
auto
*
weights_l
=
layers
.
data
(
"weightsl"
,
{
768
,
768
},
true
);
layers
.
mul
(
reshape_qkv_out
,
weights_l
,
nullptr
,
2
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
auto
fuse_pass
=
PassRegistry
::
Instance
().
Get
(
"multihead_matmul_fuse_pass_v2"
);
auto
sparse_pass
=
PassRegistry
::
Instance
().
Get
(
"dense_multihead_matmul_to_sparse_pass"
);
if
(
fuse_pass
.
get
()
==
nullptr
||
sparse_pass
.
get
()
==
nullptr
)
LOG
(
INFO
)
<<
"asdfasdf"
;
int
num_nodes_before
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
graph
.
reset
(
fuse_pass
->
Apply
(
graph
.
release
()));
graph
.
reset
(
sparse_pass
->
Apply
(
graph
.
release
()));
int
num_nodes_after
=
graph
->
Nodes
().
size
();
int
num_fused_nodes_after
=
GetNumOpNodes
(
graph
,
"sparse_multihead_matmul"
);
VLOG
(
3
)
<<
DebugString
(
graph
);
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
39
,
platform
::
errors
::
InvalidArgument
(
"After the multihead_matmul pass and sparse pass, The "
"node num in graph "
"should be %d, but the result is %d"
,
num_nodes_before
-
39
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_fused_nodes_after
,
1
,
platform
::
errors
::
InvalidArgument
(
"After the multihead_matmul pass and sparse pass, "
"there should be one "
"sparse_multihead_matmul op, but the result is %d"
,
num_fused_nodes_after
));
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
multihead_matmul_fuse_pass
);
USE_PASS
(
multihead_matmul_fuse_pass_v2
);
USE_PASS
(
dense_multihead_matmul_to_sparse_pass
);
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
20b38cfa
...
...
@@ -1960,6 +1960,10 @@ USE_TRT_CONVERTER(strided_slice)
USE_TRT_CONVERTER
(
transformer_input_convert
)
USE_TRT_CONVERTER
(
recover_padding
)
USE_TRT_CONVERTER
(
remove_padding
)
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
USE_TRT_CONVERTER
(
sparse_fc
)
USE_TRT_CONVERTER
(
sparse_multihead_matmul
)
#endif
#endif
namespace
paddle_infer
{
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
20b38cfa
...
...
@@ -115,6 +115,8 @@ const std::vector<std::string> kTRTSubgraphPasses({
"remove_padding_recover_padding_pass"
,
//
"delete_remove_padding_recover_padding_pass"
,
//
// "yolo_box_fuse_pass", //
"dense_fc_to_sparse_pass"
,
//
"dense_multihead_matmul_to_sparse_pass"
,
//
"tensorrt_subgraph_pass"
,
//
"conv_bn_fuse_pass"
,
//
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
...
...
paddle/fluid/inference/tensorrt/CMakeLists.txt
浏览文件 @
20b38cfa
# Compiling with WITH_PYTHON=ON and WITH_TENSORRT=ON failed on windows. Temporarily add paddle_inference_api dependency to solve the problem
# Compiling with WITH_PYTHON=ON and WITH_TENSORRT=ON failed on windows.
# Temporarily add paddle_inference_api dependency to solve the problem
if
(
WIN32
)
nv_library
(
tensorrt_engine
...
...
@@ -21,7 +22,7 @@ nv_test(
DEPS dynload_cuda device_context dynamic_loader
)
nv_test
(
test_tensorrt_engine
SRCS test_engine.cc
DEPS dynload_cuda tensorrt_engine
)
SRCS test_engine.cc
test_dynamic_engine.cc
DEPS dynload_cuda tensorrt_engine
tensorrt_plugin
)
add_subdirectory
(
plugin
)
add_subdirectory
(
convert
)
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
20b38cfa
# Add TRT tests
nv_library
(
tensorrt_converter
SRCS matmul_op.cc
list
(
APPEND
CONVERT_FILES
matmul_op.cc
conv2d_op.cc
fc_op.cc
pool2d_op.cc
...
...
@@ -59,7 +60,15 @@ nv_library(
roll_op.cc
transformer_input_convert_op.cc
remove_padding_op.cc
recover_padding_op.cc
recover_padding_op.cc
)
if
(
CUSPARSELT_FOUND AND
${
TENSORRT_MAJOR_VERSION
}
GREATER_EQUAL 8
)
list
(
APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc
)
endif
()
nv_library
(
tensorrt_converter
SRCS
${
CONVERT_FILES
}
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto
op_registry
)
...
...
paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc
0 → 100644
浏览文件 @
20b38cfa
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
namespace
proto
{
class
OpDesc
;
}
// namespace proto
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
/*
* FC converter convert a sparse_fc op to a sparse_fc plugin in TRT.
*/
class
SparseFcOpConverter
:
public
OpConverter
{
public:
nvinfer1
::
ILayer
*
reshape_before_fc
(
nvinfer1
::
ITensor
*
before_fc
,
nvinfer1
::
Dims
x_dim
,
int
x_num_col_dims
,
std
::
string
output_name
)
{
// add shuffle before fc
nvinfer1
::
Dims
reshape_before_fc_dim
;
reshape_before_fc_dim
.
nbDims
=
x_num_col_dims
+
3
;
// padding shape "* x q x 1 x 1"
for
(
int
i
=
0
;
i
<
reshape_before_fc_dim
.
nbDims
;
i
++
)
{
reshape_before_fc_dim
.
d
[
i
]
=
1
;
}
for
(
int
i
=
0
;
i
<
x_dim
.
nbDims
;
i
++
)
{
if
(
i
<
x_num_col_dims
)
{
reshape_before_fc_dim
.
d
[
i
]
=
0
;
}
else
{
if
(
x_dim
.
d
[
i
]
<
0
)
{
reshape_before_fc_dim
.
d
[
x_num_col_dims
]
=
-
1
;
break
;
}
reshape_before_fc_dim
.
d
[
x_num_col_dims
]
*=
x_dim
.
d
[
i
];
}
}
auto
*
reshape_before_fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
before_fc
);
reshape_before_fc_layer
->
setReshapeDimensions
(
reshape_before_fc_dim
);
reshape_before_fc_layer
->
setName
(
(
"sparse_fc_op_reshape_before_fc: Shuffle (Output: "
+
output_name
+
")"
)
.
c_str
());
return
reshape_before_fc_layer
;
}
nvinfer1
::
ILayer
*
reshape_after_fc
(
nvinfer1
::
ITensor
*
after_fc
,
nvinfer1
::
Dims
x_dim
,
int
x_num_col_dims
)
{
// add shuffle after fc
nvinfer1
::
Dims
reshape_after_fc_dim
;
reshape_after_fc_dim
.
nbDims
=
x_num_col_dims
+
1
;
for
(
int
i
=
0
;
i
<
reshape_after_fc_dim
.
nbDims
;
i
++
)
{
reshape_after_fc_dim
.
d
[
i
]
=
0
;
}
auto
*
reshape_after_fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
after_fc
);
reshape_after_fc_layer
->
setReshapeDimensions
(
reshape_after_fc_dim
);
return
reshape_after_fc_layer
;
}
plugin
::
SpmmPluginDynamic
*
new_spmm_plugin
(
TensorRTEngine
::
Weight
*
weight
,
TensorRTEngine
::
Weight
*
bias
,
const
std
::
string
&
activation_type
,
nvinfer1
::
DataType
type
,
int
outdim
)
{
plugin
::
SpmmPluginDynamic
::
Activation
act
=
plugin
::
SpmmPluginDynamic
::
Activation
::
kNone
;
if
(
activation_type
==
"relu"
)
{
act
=
plugin
::
SpmmPluginDynamic
::
Activation
::
kRelu
;
}
else
if
(
activation_type
==
"gelu"
)
{
act
=
plugin
::
SpmmPluginDynamic
::
Activation
::
kGelu
;
}
else
if
(
activation_type
!=
""
)
{
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
"unknown activation_type %s"
,
activation_type
.
c_str
()));
}
return
new
plugin
::
SpmmPluginDynamic
(
"CustomSpmmPluginDynamic"
,
type
,
outdim
,
weight
->
get
(),
bias
->
get
(),
act
);
}
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a sparse_fc op to tensorrt sparse_fc plugin"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
auto
input_names
=
op_desc
.
InputNames
();
bool
with_bias
=
input_names
.
size
()
>=
3
;
std
::
string
w_name
=
"Y"
;
std
::
string
i_name
=
"X"
;
if
(
with_bias
)
{
w_name
=
"W"
;
i_name
=
"Input"
;
}
// Declare inputs
auto
*
X
=
engine_
->
GetITensor
(
op_desc
.
Input
(
i_name
).
front
());
auto
x_dim
=
X
->
getDimensions
();
// Declare weights
auto
*
Y_v
=
scope
.
FindVar
(
op_desc
.
Input
(
w_name
).
front
());
PADDLE_ENFORCE_NOT_NULL
(
Y_v
,
platform
::
errors
::
NotFound
(
"Can not find %s presistale var of sparse_fc in scope."
,
w_name
));
auto
*
Y_t
=
Y_v
->
GetMutable
<
framework
::
LoDTensor
>
();
int
x_num_col_dims
=
op_desc
.
HasAttr
(
"x_num_col_dims"
)
?
BOOST_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"x_num_col_dims"
))
:
(
op_desc
.
HasAttr
(
"in_num_col_dims"
)
?
BOOST_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"in_num_col_dims"
))
:
1
);
const
std
::
string
activation_type
=
op_desc
.
HasAttr
(
"activation_type"
)
?
BOOST_GET_CONST
(
std
::
string
,
op_desc
.
GetAttr
(
"activation_type"
))
:
""
;
float
*
weight_data
=
nullptr
;
bool
enable_int8
=
op_desc
.
HasAttr
(
"enable_int8"
);
bool
support_int8
=
false
;
if
(
op_desc
.
HasAttr
(
"support_int8"
))
{
support_int8
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"support_int8"
));
}
float
in_scale
=
0
;
if
(
enable_int8
||
support_int8
)
{
if
(
enable_int8
)
{
in_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"Input_scale"
));
}
else
{
// attr X is generated by add_support_int8_pass
in_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"X"
));
}
engine_
->
SetTensorDynamicRange
(
X
,
in_scale
);
}
weight_data
=
engine_
->
GetWeightCPUData
(
op_desc
.
Input
(
w_name
).
front
(),
Y_t
);
PADDLE_ENFORCE_EQ
(
Y_t
->
dims
().
size
(),
2UL
,
platform
::
errors
::
InvalidArgument
(
"The sparse_fc's weight should be a matrix with 2 dims, but "
"it's %d-dimensional."
,
Y_t
->
dims
().
size
()));
// a matrix
int
m
=
Y_t
->
dims
()[
0
];
int
n
=
Y_t
->
dims
()[
1
];
auto
tranpose_weight
=
[](
const
float
*
src
,
float
*
dst
,
int
m
,
int
n
)
{
for
(
int
i
=
0
;
i
<
m
;
i
++
)
{
for
(
int
j
=
0
;
j
<
n
;
j
++
)
{
dst
[
j
*
m
+
i
]
=
src
[
i
*
n
+
j
];
}
}
};
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
auto
regist_fc
=
[
&
](
nvinfer1
::
ITensor
*
inputs
,
int
n_output
,
TensorRTEngine
::
Weight
&
weight
,
TensorRTEngine
::
Weight
&
bias
)
{
if
(
enable_int8
||
support_int8
)
{
// add conv1x1 layer
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
auto
*
fc_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
X
,
n_output
,
nv_ksize
,
weight
.
get
(),
bias
.
get
());
if
(
activation_type
==
"relu"
)
{
fc_layer_int8
->
setName
(
(
"ernie_fc_op_int8: Convolution (Output: "
+
output_name
+
")"
)
.
c_str
());
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
"must have out threshold in fc layers in int8 mode"
));
float
out_scale
=
0
;
if
(
enable_int8
)
{
out_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"out_threshold"
));
}
else
{
out_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"Out"
));
}
engine_
->
SetTensorDynamicRange
(
fc_layer_int8
->
getOutput
(
0
),
out_scale
);
nvinfer1
::
IActivationLayer
*
relu_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Activation
,
*
(
fc_layer_int8
->
getOutput
(
0
)),
nvinfer1
::
ActivationType
::
kRELU
);
RreplenishLayerAndOutput
(
relu_layer_int8
,
"relu_after_ernie_fc_int8"
,
{
output_name
},
test_mode
);
}
else
{
RreplenishLayerAndOutput
(
fc_layer_int8
,
"ernie_fc_op_int8: Convolution"
,
{
output_name
},
test_mode
);
}
}
else
{
// add fc layer
auto
*
fc_layer_float
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
X
,
n_output
,
weight
.
get
(),
bias
.
get
());
if
(
activation_type
==
"relu"
)
{
fc_layer_float
->
setName
(
(
"ernie_fc_op_float: (Output: "
+
output_name
+
")"
).
c_str
());
nvinfer1
::
IActivationLayer
*
relu_layer_float
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Activation
,
*
(
fc_layer_float
->
getOutput
(
0
)),
nvinfer1
::
ActivationType
::
kRELU
);
RreplenishLayerAndOutput
(
relu_layer_float
,
"relu_after_ernie_fc_float"
,
{
output_name
},
test_mode
);
}
else
{
RreplenishLayerAndOutput
(
fc_layer_float
,
"ernie_fc_op_float"
,
{
output_name
},
test_mode
);
}
}
};
auto
regist_sparse_fc
=
[
&
](
nvinfer1
::
ITensor
*
inputs
,
int
n_output
,
TensorRTEngine
::
Weight
*
weight
,
TensorRTEngine
::
Weight
*
bias
)
{
if
(
enable_int8
||
support_int8
)
{
// add conv layer
float
out_scale
=
0
;
if
(
enable_int8
)
{
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
"must have out threshold in sparse_fc layers in int8 mode"
));
out_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"out_threshold"
));
}
else
{
out_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"Out"
));
}
plugin
::
SpmmPluginDynamic
*
plugin
=
new_spmm_plugin
(
weight
,
bias
,
activation_type
,
nvinfer1
::
DataType
::
kINT8
,
n
);
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
emplace_back
(
inputs
);
auto
fc_layer_int8
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
fc_layer_int8
->
setName
(
(
"sparse_fc_op_int8: (Output: "
+
output_name
+
")"
).
c_str
());
engine_
->
SetTensorDynamicRange
(
fc_layer_int8
->
getOutput
(
0
),
out_scale
);
auto
*
fc_after_reshape_int8
=
reshape_after_fc
(
fc_layer_int8
->
getOutput
(
0
),
x_dim
,
x_num_col_dims
);
RreplenishLayerAndOutput
(
fc_after_reshape_int8
,
"sparse_fc_op_int8_reshape_after_fc: Shuffle"
,
{
output_name
},
test_mode
);
}
else
{
plugin
::
SpmmPluginDynamic
*
plugin
=
new_spmm_plugin
(
weight
,
bias
,
activation_type
,
with_fp16
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
,
n
);
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
emplace_back
(
inputs
);
auto
fc_layer_float
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
fc_layer_float
->
setName
(
(
"sparse_fc_op_float: FullyConnected (Output: "
+
output_name
+
")"
)
.
c_str
());
auto
*
fc_after_reshape_float
=
reshape_after_fc
(
fc_layer_float
->
getOutput
(
0
),
x_dim
,
x_num_col_dims
);
RreplenishLayerAndOutput
(
fc_after_reshape_float
,
"shuffle_after_sparse_fc"
,
{
output_name
},
test_mode
);
}
};
bool
transpose_y
=
false
;
if
(
op_desc
.
HasAttr
(
"transpose_Y"
))
{
transpose_y
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"transpose_Y"
));
}
int
weight_w
,
weight_h
;
if
(
!
transpose_y
)
{
std
::
vector
<
float
>
weight_data_tmp
;
weight_data_tmp
.
reserve
(
Y_t
->
numel
());
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
Y_t
->
numel
()
*
sizeof
(
float
));
tranpose_weight
(
weight_data_tmp
.
data
(),
weight_data
,
m
,
n
);
weight_w
=
n
;
weight_h
=
m
;
}
else
{
weight_w
=
m
;
weight_h
=
n
;
}
size_t
n_output
=
weight_w
;
float
*
bias_data
=
nullptr
;
int
bias_num
=
0
;
if
(
with_bias
)
{
auto
*
b_v
=
scope
.
GetVar
(
op_desc
.
Input
(
"Bias"
).
front
());
auto
*
b_t
=
b_v
->
GetMutable
<
framework
::
LoDTensor
>
();
bias_data
=
engine_
->
GetWeightCPUData
(
op_desc
.
Input
(
"Bias"
).
front
(),
b_t
);
bias_num
=
b_t
->
numel
();
}
// Running the TRT Static Shape mode: x_num_col_dims-1
if
(
!
engine_
->
with_dynamic_shape
())
{
x_num_col_dims
--
;
}
// If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can
// not add Shuffle layer in ernie's multihead.
// Sparse inference doesn't support variable length for now.
if
(
x_dim
.
nbDims
==
4
&&
x_num_col_dims
==
1
)
{
TensorRTEngine
::
Weight
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
static_cast
<
size_t
>
(
Y_t
->
numel
())};
weight
.
dims
.
assign
({
weight_w
,
weight_h
});
TensorRTEngine
::
Weight
bias
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
bias_data
),
static_cast
<
size_t
>
(
bias_num
)};
regist_fc
(
X
,
n_output
,
weight
,
bias
);
}
else
{
// need reshape input before and after fc
PADDLE_ENFORCE_GT
(
x_dim
.
nbDims
,
x_num_col_dims
,
platform
::
errors
::
InvalidArgument
(
"Params and input dims mismatch. Paddle-TRT FC "
"converter expects x_dim.nbDims > x_num_col_dims, but "
"x_dim.nbDims : %d, x_num_col_dims : %d."
,
x_dim
.
nbDims
,
x_num_col_dims
));
half
*
half_data
=
nullptr
;
void
*
w_data
=
nullptr
;
if
(
with_fp16
)
{
half_data
=
new
half
[
Y_t
->
numel
()];
for
(
int
i
=
0
;
i
<
Y_t
->
numel
();
i
++
)
{
half_data
[
i
]
=
static_cast
<
half
>
(
weight_data
[
i
]);
}
w_data
=
static_cast
<
void
*>
(
half_data
);
}
else
{
w_data
=
static_cast
<
void
*>
(
weight_data
);
}
TensorRTEngine
::
Weight
weight
{
with_fp16
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
,
w_data
,
static_cast
<
size_t
>
(
Y_t
->
numel
())};
weight
.
dims
.
assign
({
weight_w
,
weight_h
});
void
*
b_data
=
nullptr
;
if
(
with_bias
)
{
half
*
half_bias_data
=
nullptr
;
if
(
with_fp16
)
{
half_bias_data
=
new
half
[
bias_num
];
for
(
int
i
=
0
;
i
<
bias_num
;
i
++
)
{
half_bias_data
[
i
]
=
static_cast
<
half
>
(
bias_data
[
i
]);
}
b_data
=
static_cast
<
void
*>
(
half_bias_data
);
}
else
{
b_data
=
static_cast
<
void
*>
(
bias_data
);
}
}
TensorRTEngine
::
Weight
bias
{
with_fp16
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
,
b_data
,
static_cast
<
size_t
>
(
bias_num
)};
auto
*
reshape_before_fc_layer
=
reshape_before_fc
(
X
,
x_dim
,
x_num_col_dims
,
output_name
);
auto
*
reshape_itensor
=
reshape_before_fc_layer
->
getOutput
(
0
);
if
(
enable_int8
||
support_int8
)
{
engine_
->
SetTensorDynamicRange
(
reshape_itensor
,
in_scale
);
}
regist_sparse_fc
(
reshape_itensor
,
n_output
,
&
weight
,
&
bias
);
}
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
sparse_fc
,
SparseFcOpConverter
);
paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc
0 → 100644
浏览文件 @
20b38cfa
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See
the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
SparseMultiheadMatMulOpConverter
:
public
OpConverter
{
public:
plugin
::
SpmmPluginDynamic
*
new_spmm_plugin
(
TensorRTEngine
::
Weight
*
weight
,
TensorRTEngine
::
Weight
*
bias
,
nvinfer1
::
DataType
type
,
int
outdim
)
{
plugin
::
SpmmPluginDynamic
::
Activation
act
=
plugin
::
SpmmPluginDynamic
::
Activation
::
kNone
;
return
new
plugin
::
SpmmPluginDynamic
(
"CustomSpmmPluginDynamic"
,
type
,
outdim
,
weight
->
get
(),
bias
->
get
(),
act
);
}
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a fluid sparse_multihead_matmul op to a corresponding "
"tensorrt "
"network structure"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Input"
).
front
());
// fc weights and fc bias
auto
weight_name
=
op_desc
.
Input
(
"W"
).
front
();
auto
bias_name
=
op_desc
.
Input
(
"Bias"
).
front
();
auto
*
weight_v
=
scope
.
FindVar
(
weight_name
);
auto
*
weight_t
=
weight_v
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
bias_v
=
scope
.
FindVar
(
bias_name
);
auto
*
bias_t
=
bias_v
->
GetMutable
<
framework
::
LoDTensor
>
();
float
*
weight_data
=
nullptr
;
bool
qkv2context_plugin_int8
=
op_desc
.
HasAttr
(
"qkv2context_plugin_int8"
);
float
in_scale
=
0.
;
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
in_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"Input_scale"
));
engine_
->
SetTensorDynamicRange
(
input
,
in_scale
);
}
weight_data
=
engine_
->
GetWeightCPUData
(
weight_name
,
weight_t
);
float
*
bias_data
=
engine_
->
GetWeightCPUData
(
bias_name
,
bias_t
);
std
::
vector
<
float
>
weight_data_tmp
;
weight_data_tmp
.
reserve
(
weight_t
->
numel
());
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
weight_t
->
numel
()
*
sizeof
(
float
));
// (hidden_in, 3, hidden_out)
const
auto
&
weight_dims
=
weight_t
->
dims
();
int
hidden_in
=
weight_dims
[
0
];
// channels_in
int
three
=
weight_dims
[
1
];
// channels_out
int
hidden_out
=
weight_dims
[
2
];
// channels_out
int
m
=
hidden_in
;
int
n
=
three
*
hidden_out
;
auto
tranpose_weight
=
[](
const
float
*
src
,
float
*
dst
,
int
m
,
int
n
)
{
for
(
int
i
=
0
;
i
<
m
;
i
++
)
{
for
(
int
j
=
0
;
j
<
n
;
j
++
)
{
dst
[
j
*
m
+
i
]
=
src
[
i
*
n
+
j
];
}
}
};
tranpose_weight
(
weight_data_tmp
.
data
(),
weight_data
,
m
,
n
);
int
head_number
=
BOOST_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"head_number"
));
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
nvinfer1
::
ILayer
*
layer
=
nullptr
;
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
bool
flag_varseqlen
=
engine_
->
use_varseqlen
()
&&
engine_
->
tensorrt_transformer_posid
()
!=
""
&&
engine_
->
tensorrt_transformer_maskid
()
!=
""
;
if
(
engine_
->
with_dynamic_shape
())
{
if
(
flag_varseqlen
)
{
if
(
engine_
->
precision
()
==
AnalysisConfig
::
Precision
::
kFloat32
)
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"use use_varseqlen must be int8 or half, not float32."
));
}
nvinfer1
::
Weights
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
static_cast
<
int32_t
>
(
weight_t
->
numel
())};
nvinfer1
::
Weights
bias
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
bias_data
),
static_cast
<
int32_t
>
(
bias_t
->
numel
())};
if
(
engine_
->
with_interleaved
())
{
VLOG
(
4
)
<<
"fused multihead_matmul op: use_varseqlen and "
"with_interleaved"
;
if
(
!
op_desc
.
HasAttr
(
"Input_scale"
))
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"use with_interleaved must be int8."
));
}
nvinfer1
::
ILayer
*
fc_layer
=
nullptr
;
float
dp_probs
=
1.0
/
127.0
;
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
input
,
n
,
nv_ksize
,
weight
,
bias
);
fc_layer
->
setName
(
(
"Multihead: Convolution/FullyConnected: (Output: "
+
output_name
+
")"
)
.
c_str
());
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
"must have out_threshold in multihead layers in int8 mode"
));
float
out_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"fc_out_threshold"
));
engine_
->
SetTensorDynamicRange
(
fc_layer
->
getOutput
(
0
),
out_scale
);
if
(
qkv2context_plugin_int8
)
{
dp_probs
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"dp_probs"
))
/
127.0
;
}
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomQKVToContextPluginDynamic"
,
"3"
);
assert
(
creator
!=
nullptr
);
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"hidden_size"
,
&
hidden_out
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
}};
if
(
qkv2context_plugin_int8
)
{
fields
.
push_back
({
"dq_probs"
,
&
dp_probs
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
1
});
}
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
*
plugin_collection
)
+
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
// remember to free
plugin_collection
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
plugin_collection
->
fields
=
fields
.
data
();
auto
plugin
=
creator
->
createPlugin
(
"CustomQKVToContextPluginDynamic"
,
plugin_collection
);
free
(
plugin_collection
);
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
emplace_back
(
fc_layer
->
getOutput
(
0
));
if
(
engine_
->
Has
(
"ernie_pos_name"
))
{
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
engine_
->
Get
<
std
::
string
>
(
"ernie_pos_name"
)));
}
else
{
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
2
)
->
getName
()));
// cu_seqlens, eval_placeholder_2
}
auto
max_seqlen_tensor
=
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
3
)
->
getName
());
engine_
->
SetTensorDynamicRange
(
max_seqlen_tensor
,
1.0
f
);
auto
*
shuffle_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
max_seqlen_tensor
));
nvinfer1
::
Dims
shape_dim
;
shape_dim
.
nbDims
=
1
;
shape_dim
.
d
[
0
]
=
-
1
;
shuffle_layer
->
setReshapeDimensions
(
shape_dim
);
engine_
->
SetTensorDynamicRange
(
shuffle_layer
->
getOutput
(
0
),
1.0
f
);
plugin_inputs
.
emplace_back
(
shuffle_layer
->
getOutput
(
0
));
// max_seqlen, eval_placeholder_3
shuffle_layer
->
setName
(
(
"Multihead: Shuffle: (Output: "
+
output_name
+
")"
).
c_str
());
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
layer
=
plugin_layer
;
}
else
{
int
head_size
=
hidden_out
/
head_number
;
// [3, head_number, head_size, hidden_in] -> [head_number, 3,
// head_size,
// hidden_in]
auto
transpose_weight_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
three
,
int
head_number
,
int
head_size
,
int
hidden_in
)
{
const
int
HH
=
head_size
*
hidden_in
;
for
(
int
i
=
0
;
i
<
three
;
++
i
)
{
for
(
int
n
=
0
;
n
<
head_number
;
++
n
)
{
for
(
int
hh
=
0
;
hh
<
HH
;
++
hh
)
{
dst
[
n
*
three
*
HH
+
i
*
HH
+
hh
]
=
src
[
i
*
head_number
*
HH
+
n
*
HH
+
hh
];
}
}
}
};
// [3, head_number, head_size] -> [head_number, 3, head_size]
auto
transpose_bias_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
N
,
int
H
)
{
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
h
=
0
;
h
<
H
;
++
h
)
{
dst
[
n
*
3
*
H
+
i
*
H
+
h
]
=
src
[
i
*
N
*
H
+
n
*
H
+
h
];
}
}
}
};
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
weight_t
->
numel
()
*
sizeof
(
float
));
transpose_weight_v2
(
weight_data_tmp
.
data
(),
weight_data
,
three
,
head_number
,
head_size
,
hidden_in
);
std
::
vector
<
float
>
bias_data_tmp
;
bias_data_tmp
.
reserve
(
bias_t
->
numel
());
memcpy
(
bias_data_tmp
.
data
(),
bias_data
,
bias_t
->
numel
()
*
sizeof
(
float
));
transpose_bias_v2
(
bias_data_tmp
.
data
(),
bias_data
,
head_number
,
head_size
);
nvinfer1
::
ILayer
*
fc_layer
=
nullptr
;
float
dp_probs
=
1.0
/
127.0
;
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
input
,
n
,
nv_ksize
,
weight
,
bias
);
}
else
{
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
input
,
n
,
weight
,
bias
);
}
if
(
op_desc
.
HasAttr
(
"fc_out_threshold"
))
{
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
"must have out threshold in multihead layers "
"in int8 mode"
));
float
out_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"fc_out_threshold"
));
engine_
->
SetTensorDynamicRange
(
fc_layer
->
getOutput
(
0
),
out_scale
);
if
(
qkv2context_plugin_int8
)
{
dp_probs
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"dp_probs"
))
/
127.0
;
}
}
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomQKVToContextPluginDynamic"
,
"2"
);
assert
(
creator
!=
nullptr
);
int
type
=
static_cast
<
int
>
(
nvinfer1
::
DataType
::
kHALF
);
if
(
qkv2context_plugin_int8
&&
(
engine_
->
precision
()
==
AnalysisConfig
::
Precision
::
kInt8
))
{
type
=
static_cast
<
int
>
(
nvinfer1
::
DataType
::
kINT8
);
}
bool
has_mask
=
true
;
int
var_seqlen
=
1
;
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"hidden_size"
,
&
hidden_out
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"var_seqlen"
,
&
var_seqlen
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
}};
if
(
qkv2context_plugin_int8
)
{
fields
.
push_back
({
"dq_probs"
,
&
dp_probs
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
1
});
}
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
*
plugin_collection
)
+
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
// remember to free
plugin_collection
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
plugin_collection
->
fields
=
fields
.
data
();
auto
plugin
=
creator
->
createPlugin
(
"CustomQKVToContextPluginDynamic"
,
plugin_collection
);
free
(
plugin_collection
);
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
emplace_back
(
fc_layer
->
getOutput
(
0
));
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"qkv_plugin_mask"
));
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"pos_id"
));
auto
max_seqlen_tensor
=
engine_
->
GetITensor
(
"mask_id"
);
auto
*
shuffle_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
max_seqlen_tensor
));
nvinfer1
::
Dims
shape_dim
;
shape_dim
.
nbDims
=
1
;
shape_dim
.
d
[
0
]
=
-
1
;
shuffle_layer
->
setReshapeDimensions
(
shape_dim
);
engine_
->
SetTensorDynamicRange
(
shuffle_layer
->
getOutput
(
0
),
1.0
f
);
plugin_inputs
.
emplace_back
(
shuffle_layer
->
getOutput
(
0
));
// max_seqlen, eval_placeholder_3
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
layer
=
plugin_layer
;
}
}
else
{
PADDLE_ENFORCE_EQ
(
input
->
getDimensions
().
nbDims
,
3
,
platform
::
errors
::
InvalidArgument
(
"The Input dim of the SparseMultiheadMatMul should be 3, "
"but it's (%d) now."
,
input
->
getDimensions
().
nbDims
));
// transpose weight_data from m * n to n * m
auto
*
input_bias_qk
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"BiasQK"
).
front
());
half
*
half_data
=
nullptr
;
void
*
w_data
=
nullptr
;
if
(
with_fp16
)
{
half_data
=
new
half
[
weight_t
->
numel
()];
for
(
int
i
=
0
;
i
<
weight_t
->
numel
();
i
++
)
{
half_data
[
i
]
=
static_cast
<
half
>
(
weight_data
[
i
]);
}
w_data
=
static_cast
<
void
*>
(
half_data
);
}
else
{
w_data
=
static_cast
<
void
*>
(
weight_data
);
}
TensorRTEngine
::
Weight
weight
{
with_fp16
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
w_data
),
static_cast
<
size_t
>
(
weight_t
->
numel
())};
weight
.
dims
.
assign
({
n
,
m
});
half
*
half_bias_data
=
nullptr
;
void
*
b_data
=
nullptr
;
if
(
with_fp16
)
{
half_bias_data
=
new
half
[
bias_t
->
numel
()];
for
(
int
i
=
0
;
i
<
bias_t
->
numel
();
i
++
)
{
half_bias_data
[
i
]
=
static_cast
<
half
>
(
bias_data
[
i
]);
}
b_data
=
static_cast
<
void
*>
(
half_bias_data
);
}
else
{
b_data
=
static_cast
<
void
*>
(
bias_data
);
}
TensorRTEngine
::
Weight
bias
{
with_fp16
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
,
b_data
,
static_cast
<
size_t
>
(
bias_t
->
numel
())};
// add shuffle before fc
nvinfer1
::
Dims
reshape_before_fc_dim
;
reshape_before_fc_dim
.
nbDims
=
5
;
reshape_before_fc_dim
.
d
[
0
]
=
0
;
reshape_before_fc_dim
.
d
[
1
]
=
0
;
reshape_before_fc_dim
.
d
[
2
]
=
0
;
reshape_before_fc_dim
.
d
[
3
]
=
1
;
reshape_before_fc_dim
.
d
[
4
]
=
1
;
auto
*
reshape_before_fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
input
);
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
engine_
->
SetTensorDynamicRange
(
reshape_before_fc_layer
->
getOutput
(
0
),
in_scale
);
}
reshape_before_fc_layer
->
setReshapeDimensions
(
reshape_before_fc_dim
);
reshape_before_fc_layer
->
setName
(
(
"shuffle_before_sparse_multihead_mamul(Output: "
+
output_name
+
")"
)
.
c_str
());
// add layer fc
nvinfer1
::
ILayer
*
fc_layer
=
nullptr
;
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
plugin
::
SpmmPluginDynamic
*
plugin
=
new_spmm_plugin
(
&
weight
,
&
bias
,
nvinfer1
::
DataType
::
kINT8
,
n
);
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
emplace_back
(
reshape_before_fc_layer
->
getOutput
(
0
));
fc_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
}
else
{
plugin
::
SpmmPluginDynamic
*
plugin
=
new_spmm_plugin
(
&
weight
,
&
bias
,
with_fp16
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
,
n
);
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
emplace_back
(
reshape_before_fc_layer
->
getOutput
(
0
));
fc_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
}
if
(
op_desc
.
HasAttr
(
"fc_out_threshold"
))
{
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
"must have out threshold in multihead layers in int8 mode"
));
float
out_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"fc_out_threshold"
));
engine_
->
SetTensorDynamicRange
(
fc_layer
->
getOutput
(
0
),
out_scale
);
}
fc_layer
->
setName
(
(
"sparse_multihead_mamul_fc(Output: "
+
output_name
+
")"
).
c_str
());
// no need to add shuffle after fc, just change it in
// QkvToContextPluginDynamic
// add qkv to context
int
head_size
=
hidden_out
/
head_number
;
float
scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"alpha"
));
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
push_back
(
fc_layer
->
getOutput
(
0
));
plugin_inputs
.
push_back
(
input_bias_qk
);
if
(
engine_
->
precision
()
==
AnalysisConfig
::
Precision
::
kInt8
)
{
with_fp16
=
true
;
}
plugin
::
DynamicPluginTensorRT
*
plugin
=
new
plugin
::
QkvToContextPluginDynamic
(
hidden_in
,
head_number
,
head_size
,
scale
,
with_fp16
);
layer
=
engine_
->
AddDynamicPlugin
(
plugin_inputs
.
data
(),
2
,
plugin
);
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the Ernie(Bert) model in static shape mode, which "
"is not supported for the time being.
\n
"
"You can use the config.SetTRTDynamicShapeInfo(...) interface to set "
"the shape information to run the dynamic shape mode."
));
}
RreplenishLayerAndOutput
(
layer
,
"multihead_matmul"
,
{
output_name
},
test_mode
);
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
sparse_multihead_matmul
,
SparseMultiheadMatMulOpConverter
);
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
20b38cfa
...
...
@@ -46,6 +46,12 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set
.
insert
(
"reshape2"
);
int8_teller_set
.
insert
(
"reshape"
);
int8_teller_set
.
insert
(
"reshape2"
);
#endif
#if IS_TRT_VERSION_GE(8000)
teller_set
.
insert
(
"sparse_fc"
);
int8_teller_set
.
insert
(
"sparse_fc"
);
teller_set
.
insert
(
"sparse_multihead_matmul"
);
int8_teller_set
.
insert
(
"sparse_multihead_matmul"
);
#endif
}
...
...
@@ -1753,6 +1759,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
}
#if IS_TRT_VERSION_GE(8000)
if
(
op_type
==
"sparse_fc"
||
op_type
==
"sparse_multihead_matmul"
)
{
if
(
!
with_dynamic_shape
)
{
VLOG
(
3
)
<<
"the sparse_fc and sparse_multihead_matmul does not support "
"static shape yet"
;
return
false
;
}
}
#endif
if
((
*
teller
)(
op_type
,
desc
,
use_no_calib_int8
))
return
true
;
}
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
20b38cfa
nv_library
(
tensorrt_plugin
SRCS trt_plugin.cc
list
(
APPEND
TRT_FILES
trt_plugin.cc
split_op_plugin.cu
elementwise_op_plugin.cu
prelu_op_plugin.cu
...
...
@@ -26,7 +27,15 @@ nv_library(
matmul_op_int8_plugin.cu
transformer_input_convert_plugin.cu
remove_padding_plugin.cu
recover_padding_plugin.cu
recover_padding_plugin.cu
)
if
(
CUSPARSELT_FOUND AND
${
TENSORRT_MAJOR_VERSION
}
GREATER_EQUAL 8
)
list
(
APPEND TRT_FILES spmm_plugin.cu
)
endif
()
nv_library
(
tensorrt_plugin
SRCS
${
TRT_FILES
}
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor
)
nv_test
(
...
...
paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu
0 → 100644
浏览文件 @
20b38cfa
此差异已折叠。
点击以展开。
paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h
0 → 100644
浏览文件 @
20b38cfa
/* Copyright (c) 2022, PaddlePaddle Authors, NVIDIA CORPORATION. All rights
reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#pragma once
#include <algorithm>
#include <cassert>
#include <iostream>
#include <stdexcept>
#include <vector>
#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/dynload/cusparseLt.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
class
SpmmPluginDynamic
:
public
nvinfer1
::
IPluginV2DynamicExt
{
public:
enum
class
Activation
{
kNone
,
kRelu
,
kGelu
};
SpmmPluginDynamic
(
const
std
::
string
&
name
,
const
nvinfer1
::
DataType
precision
,
const
int
out_dim
,
const
nvinfer1
::
Weights
&
weight
,
const
nvinfer1
::
Weights
&
bias
,
Activation
activation
);
// The second constructor is for clone member function
SpmmPluginDynamic
(
const
std
::
string
&
name
,
const
nvinfer1
::
DataType
precision
,
const
int
out_dim
,
const
int
k
,
const
void
*
weight
,
size_t
compressed_size
,
const
void
*
bias
,
bool
is_configured
,
const
int
m_max
,
const
int
optim_alg
,
Activation
activation
);
SpmmPluginDynamic
(
const
std
::
string
name
,
const
void
*
data
,
size_t
length
);
SpmmPluginDynamic
()
=
delete
;
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
noexcept
override
;
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
noexcept
override
;
bool
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
noexcept
override
;
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
noexcept
override
;
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
noexcept
override
;
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
noexcept
override
;
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
noexcept
override
;
const
char
*
getPluginType
()
const
noexcept
override
;
const
char
*
getPluginVersion
()
const
noexcept
override
;
int
getNbOutputs
()
const
noexcept
override
;
int
initialize
()
noexcept
override
;
void
terminate
()
noexcept
override
;
size_t
getSerializationSize
()
const
noexcept
override
;
void
serialize
(
void
*
buffer
)
const
noexcept
override
;
void
destroy
()
noexcept
override
;
void
setPluginNamespace
(
const
char
*
pluginNamespace
)
noexcept
override
;
const
char
*
getPluginNamespace
()
const
noexcept
override
;
private:
struct
cusparseLtContext
{
cusparseLtHandle_t
handle
;
cusparseLtMatDescriptor_t
matA
;
cusparseLtMatDescriptor_t
matB
;
cusparseLtMatDescriptor_t
matC
;
cusparseLtMatmulDescriptor_t
matmul
;
cusparseLtMatmulAlgSelection_t
alg_sel
;
cusparseLtMatmulPlan_t
plan
;
cusparseLtContext
();
~
cusparseLtContext
();
size_t
workspace_size
{
0
};
bool
is_initialized
{
false
};
int
activation
{
0
};
float
relu_upper_bound
{
0
};
float
relu_threshold
{
0
};
void
init
(
int
m
,
int
n
,
int
k
,
cudaDataType_t
type
,
void
*
bias_ptr
,
SpmmPluginDynamic
::
Activation
activation
);
void
setAlgo
(
int
id
);
void
destroy
();
void
compressMatB
(
int
n
,
int
k
,
cudaDataType_t
type
,
void
*
src
,
void
**
dest
,
size_t
*
compressed_size
);
};
// struct SpmmPluginDynamic::cusparseLtContext
const
std
::
string
layer_name_
;
std
::
string
namespace_
;
nvinfer1
::
DataType
precision_
;
size_t
precision_size_
;
size_t
element_size_
;
// size of weight (float if INT8 or FLOAT; half if HALF)
int
out_dim_
;
int
k_
;
int
m_max_
;
bool
is_configured_
;
// already get m, scale bias, and search the optim alg
// or not
int
optim_alg_
;
// the index of optimal algorithm
float
weight_scale_
;
// record the weight scale from constructor
void
*
weight_compressed_
;
// host compressed weight
void
*
weight_compressed_dev_
;
// device compressed weight
std
::
shared_ptr
<
void
>
weight_compressed_dev_global_
;
// shared pointer to the
// device compressed weight
size_t
compressed_size_
;
// size of compressed weight
bool
has_bias_
;
// there is bias or not
void
*
bias_
;
// host bias
void
*
bias_dev_
;
// device bias
Activation
activation_
;
// record the activation type
cusparseLtContext
spmm_context_
;
};
// class SpmmPluginDynamic
class
SpmmPluginDynamicCreator
:
public
nvinfer1
::
IPluginCreator
{
public:
SpmmPluginDynamicCreator
();
const
char
*
getPluginName
()
const
noexcept
override
;
const
char
*
getPluginVersion
()
const
noexcept
override
;
const
nvinfer1
::
PluginFieldCollection
*
getFieldNames
()
noexcept
override
;
nvinfer1
::
IPluginV2
*
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
noexcept
override
;
nvinfer1
::
IPluginV2
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serialData
,
size_t
serialLength
)
noexcept
override
;
void
setPluginNamespace
(
const
char
*
pluginNamespace
)
noexcept
override
;
const
char
*
getPluginNamespace
()
const
noexcept
override
;
private:
static
nvinfer1
::
PluginFieldCollection
field_collection_
;
static
std
::
vector
<
nvinfer1
::
PluginField
>
plugin_attr_
;
std
::
string
namespace_
;
};
// class SpmmPluginDynamicCreator
REGISTER_TRT_PLUGIN_V2
(
SpmmPluginDynamicCreator
);
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/test_dynamic_engine.cc
0 → 100644
浏览文件 @
20b38cfa
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h"
#endif
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/float16.h"
using
float16
=
phi
::
dtype
::
float16
;
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
TensorRTDynamicEngineTest
:
public
::
testing
::
Test
{
protected:
void
SetUp
()
override
{
ctx_
=
new
platform
::
CUDADeviceContext
(
platform
::
CUDAPlace
(
0
));
ctx_
->
SetAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
platform
::
CUDAPlace
(
0
),
ctx_
->
stream
())
.
get
());
ctx_
->
SetHostAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
paddle
::
platform
::
CPUPlace
())
.
get
());
ctx_
->
SetZeroAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetZeroAllocator
(
platform
::
CUDAPlace
(
0
))
.
get
());
ctx_
->
SetPinnedAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
paddle
::
platform
::
CUDAPinnedPlace
())
.
get
());
ctx_
->
PartialInitWithAllocator
();
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
min_input_shape
=
{
{
"input"
,
{
16
,
32
,
1
,
1
}}};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape
=
{
{
"input"
,
{
16
,
32
,
1
,
1
}}};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
optim_input_shape
=
{
{
"input"
,
{
16
,
32
,
1
,
1
}}};
engine_
=
new
TensorRTEngine
(
16
,
1
<<
10
,
AnalysisConfig
::
Precision
::
kHalf
,
nullptr
,
0
,
min_input_shape
,
max_input_shape
,
optim_input_shape
,
false
,
NaiveLogger
::
Global
());
engine_
->
InitNetwork
();
}
void
TearDown
()
override
{
if
(
engine_
)
{
delete
engine_
;
engine_
=
nullptr
;
}
}
void
PrepareInputOutput
(
const
std
::
vector
<
float16
>
&
input
,
std
::
vector
<
int
>
output_shape
)
{
paddle
::
framework
::
TensorFromVector
(
input
,
*
ctx_
,
&
input_
);
output_
.
Resize
(
phi
::
make_ddim
(
output_shape
));
}
void
GetOutput
(
std
::
vector
<
float
>
*
output
)
{
paddle
::
framework
::
TensorToVector
(
output_
,
*
ctx_
,
output
);
}
protected:
framework
::
Tensor
input_
;
framework
::
Tensor
output_
;
TensorRTEngine
*
engine_
;
platform
::
CUDADeviceContext
*
ctx_
;
};
TEST_F
(
TensorRTDynamicEngineTest
,
test_spmm
)
{
// Weight in CPU memory.
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
float16
raw_weight
[
512
];
for
(
int
i
=
0
;
i
<
128
;
i
++
)
{
if
(
i
%
16
<=
7
)
{
raw_weight
[
4
*
i
]
=
float16
(
1.0
);
raw_weight
[
4
*
i
+
1
]
=
float16
(
0.0
);
raw_weight
[
4
*
i
+
2
]
=
float16
(
0.0
);
raw_weight
[
4
*
i
+
3
]
=
float16
(
4.0
);
}
else
{
raw_weight
[
4
*
i
]
=
float16
(
0.0
);
raw_weight
[
4
*
i
+
1
]
=
float16
(
2.0
);
raw_weight
[
4
*
i
+
2
]
=
float16
(
3.0
);
raw_weight
[
4
*
i
+
3
]
=
float16
(
0.0
);
}
}
float16
raw_bias
[
16
]
=
{
float16
(
0
),
float16
(
1
),
float16
(
0
),
float16
(
2
),
float16
(
0
),
float16
(
3
),
float16
(
0
),
float16
(
4
),
float16
(
0
),
float16
(
5
),
float16
(
0
),
float16
(
6
),
float16
(
0
),
float16
(
7
),
float16
(
0
),
float16
(
8
)};
std
::
vector
<
void
*>
buffers
(
2
);
// TRT binded inputs
TensorRTEngine
::
Weight
weight
(
nvinfer1
::
DataType
::
kHALF
,
raw_weight
,
512
);
TensorRTEngine
::
Weight
bias
(
nvinfer1
::
DataType
::
kHALF
,
raw_bias
,
16
);
std
::
cout
<<
"with_dynamic_shape: "
<<
engine_
->
with_dynamic_shape
()
<<
std
::
endl
;
auto
*
x
=
engine_
->
DeclareInput
(
"input"
,
nvinfer1
::
DataType
::
kHALF
,
nvinfer1
::
Dims4
{
-
1
,
32
,
1
,
1
});
plugin
::
SpmmPluginDynamic
::
Activation
act
=
plugin
::
SpmmPluginDynamic
::
Activation
::
kNone
;
plugin
::
SpmmPluginDynamic
*
plugin
=
new
plugin
::
SpmmPluginDynamic
(
"CustomSpmmPluginDynamic"
,
nvinfer1
::
DataType
::
kHALF
,
16
,
weight
.
get
(),
bias
.
get
(),
act
);
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
emplace_back
(
x
);
auto
fc_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
LOG
(
INFO
)
<<
"create weights"
;
PADDLE_ENFORCE_NOT_NULL
(
fc_layer
,
platform
::
errors
::
InvalidArgument
(
"TRT SPMM layer building failed."
));
engine_
->
DeclareOutput
(
fc_layer
,
0
,
"y"
);
engine_
->
FreezeNetwork
();
ASSERT_EQ
(
engine_
->
engine
()
->
getNbBindings
(),
2
);
std
::
vector
<
float16
>
x_v
(
512
);
for
(
int
i
=
0
;
i
<
128
;
i
++
)
{
x_v
[
4
*
i
]
=
float16
(
1.0
);
x_v
[
4
*
i
+
1
]
=
float16
(
2.0
);
x_v
[
4
*
i
+
2
]
=
float16
(
3.0
);
x_v
[
4
*
i
+
3
]
=
float16
(
4.0
);
}
std
::
vector
<
float
>
y_cpu
;
PrepareInputOutput
(
x_v
,
{
16
,
16
});
auto
*
x_v_gpu_data
=
input_
.
mutable_data
<
float16
>
(
ctx_
->
GetPlace
());
auto
*
y_gpu_data
=
output_
.
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
buffers
[
0
]
=
reinterpret_cast
<
void
*>
(
x_v_gpu_data
);
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
y_gpu_data
);
engine_
->
Execute
(
16
,
&
buffers
,
ctx_
->
stream
());
LOG
(
INFO
)
<<
"to get output"
;
GetOutput
(
&
y_cpu
);
auto
dims
=
engine_
->
GetITensor
(
"y"
)
->
getDimensions
();
ASSERT_EQ
(
dims
.
nbDims
,
4
);
ASSERT_EQ
(
dims
.
d
[
1
],
16
);
ASSERT_EQ
(
y_cpu
[
0
],
136
);
ASSERT_EQ
(
y_cpu
[
1
],
105
);
ASSERT_EQ
(
y_cpu
[
32
],
136
);
ASSERT_EQ
(
y_cpu
[
64
],
136
);
ASSERT_EQ
(
y_cpu
[
96
],
136
);
#endif
return
;
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/platform/dynload/CMakeLists.txt
浏览文件 @
20b38cfa
...
...
@@ -42,6 +42,10 @@ if(TENSORRT_FOUND)
list
(
APPEND CUDA_SRCS tensorrt.cc
)
endif
()
if
(
CUSPARSELT_FOUND
)
list
(
APPEND CUDA_SRCS cusparseLt.cc
)
endif
()
configure_file
(
cupti_lib_path.h.in
${
CMAKE_CURRENT_BINARY_DIR
}
/cupti_lib_path.h
)
if
(
CUPTI_FOUND
)
list
(
APPEND CUDA_SRCS cupti.cc
)
...
...
paddle/fluid/platform/dynload/cusparseLt.cc
0 → 100644
浏览文件 @
20b38cfa
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/cusparseLt.h"
namespace
paddle
{
namespace
platform
{
namespace
dynload
{
#define DEFINE_WRAP(__name) DynLoad__##__name __name
#ifdef CUSPARSELT_ROUTINE_EACH
CUSPARSELT_ROUTINE_EACH
(
DEFINE_WRAP
);
#endif
}
// namespace dynload
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/dynload/cusparseLt.h
0 → 100644
浏览文件 @
20b38cfa
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
#include <cusparseLt.h>
#include <mutex> // NOLINT
#include "paddle/phi/backends/dynload/cusparseLt.h"
namespace
paddle
{
namespace
platform
{
namespace
dynload
{
#define PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP(__name) \
using DynLoad__##__name = phi::dynload::DynLoad__##__name; \
extern DynLoad__##__name __name
#if defined(PADDLE_WITH_CUDA)
#if CUDA_VERSION >= 11020
#define CUSPARSELT_ROUTINE_EACH(__macro) \
__macro(cusparseLtInit); \
__macro(cusparseLtDestroy); \
__macro(cusparseLtDenseDescriptorInit); \
__macro(cusparseLtStructuredDescriptorInit); \
__macro(cusparseLtMatmulDescriptorInit); \
__macro(cusparseLtMatmulDescSetAttribute); \
__macro(cusparseLtMatmulAlgSelectionInit); \
__macro(cusparseLtMatmulAlgSetAttribute); \
__macro(cusparseLtMatmulGetWorkspace); \
__macro(cusparseLtMatmulPlanInit); \
__macro(cusparseLtMatDescriptorDestroy); \
__macro(cusparseLtSpMMACompressedSize2); \
__macro(cusparseLtSpMMACompress2); \
__macro(cusparseLtMatmulSearch); \
__macro(cusparseLtMatmulAlgGetAttribute); \
__macro(cusparseLtMatmulPlanDestroy); \
__macro(cusparseLtMatmul); \
__macro(cusparseGetErrorString);
CUSPARSELT_ROUTINE_EACH
(
PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP
);
#endif
#endif
#undef PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP
}
// namespace dynload
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/dynload/dynamic_loader.cc
浏览文件 @
20b38cfa
...
...
@@ -72,6 +72,10 @@ void* GetCUFFTDsoHandle() { return phi::dynload::GetCUFFTDsoHandle(); }
void
*
GetMKLRTDsoHandle
()
{
return
phi
::
dynload
::
GetMKLRTDsoHandle
();
}
void
*
GetCusparseLtDsoHandle
()
{
return
phi
::
dynload
::
GetCusparseLtDsoHandle
();
}
}
// namespace dynload
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/dynload/dynamic_loader.h
浏览文件 @
20b38cfa
...
...
@@ -46,6 +46,7 @@ void* GetNvtxDsoHandle();
void
*
GetCUFFTDsoHandle
();
void
*
GetMKLRTDsoHandle
();
void
*
GetROCFFTDsoHandle
();
void
*
GetCusparseLtDsoHandle
();
void
SetPaddleLibPath
(
const
std
::
string
&
);
}
// namespace dynload
...
...
paddle/phi/backends/dynload/CMakeLists.txt
浏览文件 @
20b38cfa
...
...
@@ -42,6 +42,10 @@ if(TENSORRT_FOUND)
list
(
APPEND CUDA_SRCS tensorrt.cc
)
endif
()
if
(
CUSPARSELT_FOUND
)
list
(
APPEND CUDA_SRCS cusparseLt.cc
)
endif
()
configure_file
(
cupti_lib_path.h.in
${
CMAKE_CURRENT_BINARY_DIR
}
/cupti_lib_path.h
)
if
(
CUPTI_FOUND
)
list
(
APPEND CUDA_SRCS cupti.cc
)
...
...
paddle/phi/backends/dynload/cusparseLt.cc
0 → 100644
浏览文件 @
20b38cfa
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/dynload/cusparseLt.h"
namespace
phi
{
namespace
dynload
{
std
::
once_flag
cusparselt_dso_flag
;
void
*
cusparselt_dso_handle
=
nullptr
;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
CUSPARSELT_ROUTINE_EACH
(
DEFINE_WRAP
);
}
// namespace dynload
}
// namespace phi
paddle/phi/backends/dynload/cusparseLt.h
0 → 100644
浏览文件 @
20b38cfa
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
#include <cusparseLt.h>
#include <mutex> // NOLINT
#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/port.h"
namespace
phi
{
namespace
dynload
{
extern
std
::
once_flag
cusparselt_dso_flag
;
extern
void
*
cusparselt_dso_handle
;
/**
* The following macro definition can generate structs
* (for each function) to dynamic load cupti routine
* via operator overloading.
*
* note: default dynamic linked libs
*/
#define DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
cusparseStatus_t operator()(Args... args) { \
using cusparseltFunc = decltype(&::__name); \
std::call_once(cusparselt_dso_flag, []() { \
cusparselt_dso_handle = phi::dynload::GetCusparseLtDsoHandle(); \
}); \
static void *p_##__name = dlsym(cusparselt_dso_handle, #__name); \
return reinterpret_cast<cusparseltFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#if defined(PADDLE_WITH_CUDA)
#if CUDA_VERSION >= 11020
#define CUSPARSELT_ROUTINE_EACH(__macro) \
__macro(cusparseLtInit); \
__macro(cusparseLtDestroy); \
__macro(cusparseLtDenseDescriptorInit); \
__macro(cusparseLtStructuredDescriptorInit); \
__macro(cusparseLtMatmulDescriptorInit); \
__macro(cusparseLtMatmulDescSetAttribute); \
__macro(cusparseLtMatmulAlgSelectionInit); \
__macro(cusparseLtMatmulAlgSetAttribute); \
__macro(cusparseLtMatmulGetWorkspace); \
__macro(cusparseLtMatmulPlanInit); \
__macro(cusparseLtMatDescriptorDestroy); \
__macro(cusparseLtSpMMACompressedSize2); \
__macro(cusparseLtSpMMACompress2); \
__macro(cusparseLtMatmulSearch); \
__macro(cusparseLtMatmulAlgGetAttribute); \
__macro(cusparseLtMatmulPlanDestroy); \
__macro(cusparseLtMatmul); \
__macro(cusparseGetErrorString);
CUSPARSELT_ROUTINE_EACH
(
DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP
);
#endif
#endif
#undef DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP
}
// namespace dynload
}
// namespace phi
paddle/phi/backends/dynload/dynamic_loader.cc
浏览文件 @
20b38cfa
...
...
@@ -76,6 +76,8 @@ DEFINE_string(mkl_dir,
DEFINE_string
(
op_dir
,
""
,
"Specify path for loading user-defined op library."
);
DEFINE_string
(
cusparselt_dir
,
""
,
"Specify path for loading libcusparseLt.so."
);
#ifdef PADDLE_WITH_HIP
DEFINE_string
(
miopen_dir
,
...
...
@@ -578,5 +580,18 @@ void* GetMKLRTDsoHandle() {
#endif
}
void
*
GetCusparseLtDsoHandle
()
{
// APIs available after CUDA 11.2
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11020
return
GetDsoHandleFromSearchPath
(
FLAGS_cusparselt_dir
,
"libcusparseLt.so"
);
#else
std
::
string
warning_msg
(
"Your CUDA_VERSION less 11.2, not support cusparseLt. "
"If you want to use cusparseLt, please upgrade CUDA and rebuild "
"PaddlePaddle."
);
return
nullptr
;
#endif
}
}
// namespace dynload
}
// namespace phi
paddle/phi/backends/dynload/dynamic_loader.h
浏览文件 @
20b38cfa
...
...
@@ -45,6 +45,7 @@ void* GetNvtxDsoHandle();
void
*
GetCUFFTDsoHandle
();
void
*
GetMKLRTDsoHandle
();
void
*
GetROCFFTDsoHandle
();
void
*
GetCusparseLtDsoHandle
();
void
SetPaddleLibPath
(
const
std
::
string
&
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录