Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
def04fe7
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
def04fe7
编写于
4月 17, 2020
作者:
P
Pei Yang
提交者:
GitHub
4月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle-TRT] Add hard_sigmoid and hard_swish support(support MobilenetV3) (#23672) (#23908)
上级
2bca3295
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
299 addition
and
2 deletion
+299
-2
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+2
-0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/convert/hard_sigmoid_op.cc
paddle/fluid/inference/tensorrt/convert/hard_sigmoid_op.cc
+55
-0
paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc
paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc
+72
-0
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+2
-0
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.cu
...e/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.cu
+86
-0
paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h
...le/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h
+80
-0
未找到文件。
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
def04fe7
...
@@ -961,6 +961,8 @@ USE_TRT_CONVERTER(batch_norm);
...
@@ -961,6 +961,8 @@ USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER
(
concat
);
USE_TRT_CONVERTER
(
concat
);
USE_TRT_CONVERTER
(
dropout
);
USE_TRT_CONVERTER
(
dropout
);
USE_TRT_CONVERTER
(
pad
);
USE_TRT_CONVERTER
(
pad
);
USE_TRT_CONVERTER
(
hard_sigmoid
);
USE_TRT_CONVERTER
(
hard_swish
);
USE_TRT_CONVERTER
(
split
);
USE_TRT_CONVERTER
(
split
);
USE_TRT_CONVERTER
(
prelu
);
USE_TRT_CONVERTER
(
prelu
);
USE_TRT_CONVERTER
(
conv2d_transpose
);
USE_TRT_CONVERTER
(
conv2d_transpose
);
...
...
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
def04fe7
...
@@ -4,7 +4,7 @@ nv_library(tensorrt_converter
...
@@ -4,7 +4,7 @@ nv_library(tensorrt_converter
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc
hard_sigmoid_op.cc hard_swish_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry
)
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry
)
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
...
...
paddle/fluid/inference/tensorrt/convert/hard_sigmoid_op.cc
0 → 100644
浏览文件 @
def04fe7
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
/*
* HardSigmoidOp, IActivationLayer in TRT. This Layer doesn't has weights.
*/
class
HardSigmoidOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
#if IS_TRT_VERSION_GE(5000)
VLOG
(
3
)
<<
"convert a fluid HardSigmoid op to tensorrt IActivationLayer "
"layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
float
slope
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"slope"
));
float
offset
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"offset"
));
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Activation
,
*
input
,
nvinfer1
::
ActivationType
::
kHARD_SIGMOID
);
layer
->
setAlpha
(
slope
);
layer
->
setBeta
(
offset
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
"hard_sigmoid"
,
{
output_name
},
test_mode
);
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Hard sigmoid TRT converter is only supported on TRT 5 or higher. "
"Please confirm your TRT version is no less than 5.0."
));
#endif
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
hard_sigmoid
,
HardSigmoidOpConverter
);
paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc
0 → 100644
浏览文件 @
def04fe7
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
/*
* HardSwish converter from fluid to tensorRT.
*/
class
HardSwishOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
4
)
<<
"convert fluid HardSwish op to tensorrt HardSwish plugin"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
int
input_num
=
op_desc
.
Input
(
"X"
).
size
();
PADDLE_ENFORCE_EQ
(
input_num
,
1
,
platform
::
errors
::
InvalidArgument
(
"HardSwish op has only 1 input, but got %d"
,
input_num
));
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
// Get output
size_t
output_num
=
op_desc
.
Output
(
"Out"
).
size
();
PADDLE_ENFORCE_EQ
(
output_num
,
1
,
platform
::
errors
::
InvalidArgument
(
"HardSwish op has only 1 output, but got %d"
,
output_num
));
const
float
threshold
=
op_desc
.
HasAttr
(
"threshold"
)
?
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"threshold"
))
:
6.0
f
;
const
float
scale
=
op_desc
.
HasAttr
(
"scale"
)
?
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"scale"
))
:
6.0
f
;
const
float
offset
=
op_desc
.
HasAttr
(
"offset"
)
?
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"offset"
))
:
3.0
f
;
nvinfer1
::
ILayer
*
layer
=
nullptr
;
plugin
::
HardSwishPlugin
*
plugin
=
new
plugin
::
HardSwishPlugin
(
threshold
,
scale
,
offset
);
layer
=
engine_
->
AddPlugin
(
&
input
,
input_num
,
plugin
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
"hard_swish"
,
{
output_name
},
test_mode
);
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
hard_swish
,
HardSwishOpConverter
);
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
def04fe7
...
@@ -23,6 +23,7 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -23,6 +23,7 @@ struct SimpleOpTypeSetTeller : public Teller {
SimpleOpTypeSetTeller
()
{
SimpleOpTypeSetTeller
()
{
#if IS_TRT_VERSION_GE(5130)
#if IS_TRT_VERSION_GE(5130)
teller_set
.
insert
(
"relu6"
);
teller_set
.
insert
(
"relu6"
);
teller_set
.
insert
(
"hard_sigmoid"
);
#endif
#endif
#if IS_TRT_VERSION_GE(6000)
#if IS_TRT_VERSION_GE(6000)
teller_set
.
insert
(
"fused_embedding_eltwise_layernorm"
);
teller_set
.
insert
(
"fused_embedding_eltwise_layernorm"
);
...
@@ -54,6 +55,7 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -54,6 +55,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"relu"
,
"relu"
,
"softmax"
,
"softmax"
,
"sigmoid"
,
"sigmoid"
,
"hard_swish"
,
"depthwise_conv2d"
,
"depthwise_conv2d"
,
"batch_norm"
,
"batch_norm"
,
"concat"
,
"concat"
,
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
def04fe7
...
@@ -3,5 +3,5 @@ nv_library(tensorrt_plugin
...
@@ -3,5 +3,5 @@ nv_library(tensorrt_plugin
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu
hard_swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor
)
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor
)
paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.cu
0 → 100644
浏览文件 @
def04fe7
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include "paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
HardSwishPlugin
*
CreateHardSwishPluginDeserialize
(
const
void
*
buffer
,
size_t
length
)
{
return
new
HardSwishPlugin
(
buffer
,
length
);
}
REGISTER_TRT_PLUGIN
(
"hard_swish_plugin"
,
CreateHardSwishPluginDeserialize
);
nvinfer1
::
Dims
HardSwishPlugin
::
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
in_dims
,
int
nb_inputs
)
{
assert
(
nb_inputs
==
1
);
assert
(
index
<
this
->
getNbOutputs
());
nvinfer1
::
Dims
const
&
input_dims
=
in_dims
[
0
];
nvinfer1
::
Dims
output_dims
=
input_dims
;
return
output_dims
;
}
template
<
typename
T
>
__device__
T
kMax
(
T
a
,
T
b
)
{
return
a
>
b
?
a
:
b
;
}
template
<
typename
T
>
__device__
T
kMin
(
T
a
,
T
b
)
{
return
a
<
b
?
a
:
b
;
}
template
<
typename
T
,
unsigned
TPB
>
__global__
void
hard_swish_kernel
(
float
threshold
,
float
scale
,
float
offset
,
int
n
,
const
T
*
input
,
T
*
output
)
{
const
int
idx
=
blockIdx
.
x
*
TPB
+
threadIdx
.
x
;
if
(
idx
<
n
)
{
const
T
in
=
input
[
idx
];
output
[
idx
]
=
in
/
scale
*
kMin
<
T
>
(
kMax
<
T
>
(
in
+
offset
,
0
),
threshold
);
}
}
int
HardSwishPlugin
::
enqueue
(
int
batch_size
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
,
cudaStream_t
stream
)
{
const
auto
&
input_dims
=
this
->
getInputDims
(
0
);
int
num
=
batch_size
;
for
(
int
i
=
0
;
i
<
input_dims
.
nbDims
;
i
++
)
{
num
*=
input_dims
.
d
[
i
];
}
float
threshold
=
threshold_
;
float
scale
=
scale_
;
float
offset
=
offset_
;
const
int
block_size
=
256
;
const
int
grid_size
=
(
num
+
block_size
-
1
)
/
block_size
;
const
float
*
input
=
static_cast
<
const
float
*>
(
inputs
[
0
]);
float
*
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
hard_swish_kernel
<
float
,
block_size
><<<
grid_size
,
block_size
,
0
,
stream
>>>
(
threshold
,
scale
,
offset
,
num
,
input
,
output
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h
0 → 100644
浏览文件 @
def04fe7
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
class
HardSwishPlugin
:
public
PluginTensorRT
{
public:
HardSwishPlugin
(
const
float
threshold
,
const
float
scale
,
const
float
offset
)
:
threshold_
(
threshold
),
scale_
(
scale
),
offset_
(
offset
)
{}
// It was used for tensorrt deserialization.
// It should not be called by users.
HardSwishPlugin
(
void
const
*
serialData
,
size_t
serialLength
)
{
deserializeBase
(
serialData
,
serialLength
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
threshold_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
scale_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
offset_
);
}
~
HardSwishPlugin
()
{}
HardSwishPlugin
*
clone
()
const
override
{
return
new
HardSwishPlugin
(
threshold_
,
scale_
,
offset_
);
}
const
char
*
getPluginType
()
const
override
{
return
"hard_swish_plugin"
;
}
int
getNbOutputs
()
const
override
{
return
1
;
}
int
initialize
()
override
{
return
0
;
}
nvinfer1
::
Dims
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputs
,
int
nbInputDims
)
override
;
int
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
protected:
float
threshold_
;
float
scale_
;
float
offset_
;
size_t
getSerializationSize
()
override
{
return
getBaseSerializationSize
()
+
SerializedSize
(
threshold_
)
+
SerializedSize
(
scale_
)
+
SerializedSize
(
offset_
)
+
SerializedSize
(
getPluginType
());
}
// TRT will call this func to serialize the configuration of TRT
// It should not be called by users.
void
serialize
(
void
*
buffer
)
override
{
SerializeValue
(
&
buffer
,
getPluginType
());
serializeBase
(
buffer
);
SerializeValue
(
&
buffer
,
threshold_
);
SerializeValue
(
&
buffer
,
scale_
);
SerializeValue
(
&
buffer
,
offset_
);
}
};
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录