Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
a1abb7c9
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a1abb7c9
编写于
5月 11, 2022
作者:
W
wenbin
提交者:
GitHub
5月 11, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
swish refactor (#42610)
* swish refactor * bug fix * trt7 non-linear bug fix
上级
29a6b8c9
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
56 addition
and
14 deletion
+56
-14
paddle/fluid/inference/tensorrt/convert/swish_op.cc
paddle/fluid/inference/tensorrt/convert/swish_op.cc
+1
-1
paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.cu
+34
-9
paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h
+21
-4
未找到文件。
paddle/fluid/inference/tensorrt/convert/swish_op.cc
浏览文件 @
a1abb7c9
...
...
@@ -75,7 +75,7 @@ class SwishOpConverter : public OpConverter {
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
plugin
::
SwishPlugin
*
plugin
=
new
plugin
::
SwishPlugin
(
beta
,
with_fp16
);
layer
=
engine_
->
AddPlugin
(
&
input
,
input_num
,
plugin
);
layer
=
engine_
->
AddPlugin
V2Ext
(
&
input
,
input_num
,
plugin
);
}
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
...
...
paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.cu
浏览文件 @
a1abb7c9
...
...
@@ -24,6 +24,16 @@ namespace tensorrt {
namespace
plugin
{
int
SwishPlugin
::
initialize
()
TRT_NOEXCEPT
{
return
0
;
}
void
SwishPlugin
::
terminate
()
TRT_NOEXCEPT
{}
bool
SwishPlugin
::
supportsFormat
(
nvinfer1
::
DataType
type
,
nvinfer1
::
PluginFormat
format
)
const
TRT_NOEXCEPT
{
if
(
with_fp16_
)
{
return
type
==
nvinfer1
::
DataType
::
kFLOAT
||
type
==
nvinfer1
::
DataType
::
kHALF
;
}
return
type
==
nvinfer1
::
DataType
::
kFLOAT
;
}
nvinfer1
::
Dims
SwishPlugin
::
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputDims
,
...
...
@@ -85,17 +95,29 @@ int SwishPlugin::enqueue(int batch_size, const void *const *inputs,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
#endif
// input dims is CHW.
const
auto
&
input_dims
=
this
->
getInputDims
(
0
);
const
float
*
input
=
reinterpret_cast
<
const
float
*>
(
inputs
[
0
]);
float
*
output
=
reinterpret_cast
<
float
*
const
*>
(
outputs
)[
0
];
int
num
=
batch_size
;
for
(
int
i
=
0
;
i
<
input_dims
.
nbDims
;
i
++
)
{
num
*=
input_dims
.
d
[
i
];
}
int
threads
=
1024
;
int
blocks
=
(
num
+
threads
-
1
)
/
threads
;
swish_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
num
,
input
,
output
,
beta_
);
auto
type
=
getDataType
();
if
(
type
==
nvinfer1
::
DataType
::
kFLOAT
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. Swish-->fp32"
;
const
float
*
input
=
reinterpret_cast
<
const
float
*>
(
inputs
[
0
]);
float
*
output
=
reinterpret_cast
<
float
*
const
*>
(
outputs
)[
0
];
swish_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
num
,
input
,
output
,
beta_
);
}
else
if
(
type
==
nvinfer1
::
DataType
::
kHALF
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. Swish-->fp16"
;
const
half
*
input
=
reinterpret_cast
<
const
half
*>
(
inputs
[
0
]);
half
*
output
=
reinterpret_cast
<
half
*
const
*>
(
outputs
)[
0
];
swish_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
num
,
input
,
output
,
(
half
)
beta_
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"The Swish TRT Plugin's input type should be float or half."
));
}
return
cudaGetLastError
()
!=
cudaSuccess
;
}
...
...
@@ -140,12 +162,15 @@ bool SwishPluginDynamic::supportsFormatCombination(
const
nvinfer1
::
PluginTensorDesc
&
in
=
in_out
[
pos
];
if
(
pos
==
0
)
{
if
(
with_fp16_
)
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
||
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
bool
res
=
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
||
in
.
type
==
nvinfer1
::
DataType
::
kHALF
);
// encounter trt crash bug
#if IS_TRT_VERSION_LT(8000)
res
=
res
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
#endif
return
res
;
}
else
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
return
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
;
}
}
const
nvinfer1
::
PluginTensorDesc
&
prev
=
in_out
[
pos
-
1
];
...
...
paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h
浏览文件 @
a1abb7c9
...
...
@@ -26,7 +26,7 @@ namespace inference {
namespace
tensorrt
{
namespace
plugin
{
class
SwishPlugin
:
public
PluginTensorRT
{
class
SwishPlugin
:
public
PluginTensorRT
V2Ext
{
private:
float
beta_
;
...
...
@@ -55,13 +55,24 @@ class SwishPlugin : public PluginTensorRT {
int
initialize
()
TRT_NOEXCEPT
override
;
SwishPlugin
*
clone
()
const
TRT_NOEXCEPT
override
{
return
new
SwishPlugin
(
beta_
,
with_fp16_
);
nvinfer1
::
IPluginV2Ext
*
clone
()
const
TRT_NOEXCEPT
override
{
auto
*
plugin
=
new
SwishPlugin
(
beta_
,
with_fp16_
);
plugin
->
data_format_
=
data_format_
;
plugin
->
data_type_
=
data_type_
;
plugin
->
input_dims_
=
input_dims_
;
return
plugin
;
}
const
char
*
getPluginType
()
const
TRT_NOEXCEPT
override
{
return
"swish_plugin"
;
}
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
TRT_NOEXCEPT
override
{
return
input_types
[
0
];
}
int
getNbOutputs
()
const
TRT_NOEXCEPT
override
{
return
1
;
}
nvinfer1
::
Dims
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputs
,
int
nbInputDims
)
TRT_NOEXCEPT
override
;
...
...
@@ -71,6 +82,12 @@ class SwishPlugin : public PluginTensorRT {
int
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
#endif
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
override
;
void
terminate
()
TRT_NOEXCEPT
override
;
void
destroy
()
TRT_NOEXCEPT
override
{
delete
this
;
}
const
char
*
getPluginVersion
()
const
TRT_NOEXCEPT
override
{
return
"2"
;
}
bool
supportsFormat
(
nvinfer1
::
DataType
type
,
nvinfer1
::
PluginFormat
format
)
const
TRT_NOEXCEPT
override
;
};
class
SwishPluginCreator
:
public
TensorRTPluginCreator
{
...
...
@@ -79,7 +96,7 @@ class SwishPluginCreator : public TensorRTPluginCreator {
return
"swish_plugin"
;
}
const
char
*
getPluginVersion
()
const
TRT_NOEXCEPT
override
{
return
"
1
"
;
}
const
char
*
getPluginVersion
()
const
TRT_NOEXCEPT
override
{
return
"
2
"
;
}
nvinfer1
::
IPluginV2
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录