Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e91141fb
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
e91141fb
编写于
11月 23, 2021
作者:
W
wangxinxin08
提交者:
GitHub
11月 23, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix problem of dcnv2 trt (#37345)
* modify code about fp16 of dcnv2 trt
上级
586bafbd
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
203 addition
and
50 deletion
+203
-50
paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc
...le/fluid/inference/tensorrt/convert/deformable_conv_op.cc
+4
-4
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
...id/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
+188
-40
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h
...uid/inference/tensorrt/plugin/deformable_conv_op_plugin.h
+11
-6
未找到文件。
paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc
浏览文件 @
e91141fb
...
...
@@ -70,7 +70,8 @@ class DeformableConvOpConverter : public OpConverter {
nvinfer1
::
Weights
weights
;
weights
.
count
=
filter_tensor
->
numel
();
if
(
engine_
->
WithFp16
())
{
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
if
(
with_fp16
)
{
auto
half_filter_data
=
new
half
[
filter_tensor
->
numel
()];
for
(
int
i
=
0
;
i
<
filter_tensor
->
numel
();
i
++
)
{
half_filter_data
[
i
]
=
static_cast
<
half
>
(
filter_data
[
i
]);
...
...
@@ -82,10 +83,9 @@ class DeformableConvOpConverter : public OpConverter {
weights
.
values
=
filter_data
;
}
auto
*
deformable_conv_plugin
=
new
plugin
::
DeformableConvPlugin
(
engine_
->
WithFp16
()
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
,
with_fp16
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
,
weights
,
kernel_dims
,
strides
,
paddings
,
dilations
,
groups
,
deformable_groups
,
im2col_step
);
deformable_groups
,
im2col_step
,
with_fp16
);
std
::
vector
<
nvinfer1
::
ITensor
*>
deformable_conv_inputs
;
deformable_conv_inputs
.
push_back
(
input_tensor
);
...
...
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
浏览文件 @
e91141fb
...
...
@@ -71,11 +71,13 @@ DeformableConvPlugin::DeformableConvPlugin(
const
nvinfer1
::
DataType
data_type
,
const
nvinfer1
::
Weights
&
weights
,
const
std
::
vector
<
int
>&
kernel_dims
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
int
groups
,
const
int
deformable_groups
,
const
int
im2col_step
)
const
int
groups
,
const
int
deformable_groups
,
const
int
im2col_step
,
const
bool
with_fp16
)
:
data_type_
(
data_type
),
groups_
(
groups
),
deformable_groups_
(
deformable_groups
),
im2col_step_
(
im2col_step
)
{
im2col_step_
(
im2col_step
),
with_fp16_
(
with_fp16
)
{
weights_
=
copyToDevice
(
weights
.
values
,
weights
.
count
);
kernel_dims_
.
insert
(
kernel_dims_
.
end
(),
kernel_dims
.
cbegin
(),
kernel_dims
.
cend
());
...
...
@@ -101,11 +103,13 @@ DeformableConvPlugin::DeformableConvPlugin(
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
int
groups
,
const
int
deformable_groups
,
const
int
im2col_step
,
const
std
::
vector
<
int
>&
input_dim
,
const
std
::
vector
<
int
>&
offset_dim
,
const
std
::
vector
<
int
>&
mask_dim
,
const
std
::
vector
<
int
>&
output_dim
)
const
std
::
vector
<
int
>&
mask_dim
,
const
std
::
vector
<
int
>&
output_dim
,
const
bool
with_fp16
)
:
data_type_
(
data_type
),
groups_
(
groups
),
deformable_groups_
(
deformable_groups
),
im2col_step_
(
im2col_step
)
{
im2col_step_
(
im2col_step
),
with_fp16_
(
with_fp16
)
{
weights_
=
copyToDevice
(
weights
.
values
,
weights
.
count
);
kernel_dims_
.
insert
(
kernel_dims_
.
end
(),
kernel_dims
.
cbegin
(),
kernel_dims
.
cend
());
...
...
@@ -145,6 +149,7 @@ DeformableConvPlugin::DeformableConvPlugin(const void* data, size_t length) {
DeserializeValue
(
&
data
,
&
length
,
&
offset_dim_
);
DeserializeValue
(
&
data
,
&
length
,
&
mask_dim_
);
DeserializeValue
(
&
data
,
&
length
,
&
output_dim_
);
DeserializeValue
(
&
data
,
&
length
,
&
with_fp16_
);
}
DeformableConvPlugin
::~
DeformableConvPlugin
()
{
...
...
@@ -182,8 +187,19 @@ nvinfer1::Dims DeformableConvPlugin::getOutputDimensions(
bool
DeformableConvPlugin
::
supportsFormat
(
nvinfer1
::
DataType
type
,
nvinfer1
::
TensorFormat
format
)
const
TRT_NOEXCEPT
{
return
((
type
==
data_type_
||
type
==
nvinfer1
::
DataType
::
kINT32
)
&&
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
if
(
with_fp16_
)
{
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return
(
type
==
nvinfer1
::
DataType
::
kFLOAT
||
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
#else
return
(
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
#endif
}
else
{
return
(
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
}
size_t
DeformableConvPlugin
::
getWorkspaceSize
(
int
max_batch_size
)
const
...
...
@@ -207,7 +223,7 @@ int DeformableConvPlugin::enqueue(int batch_size, const void* const* inputs,
if
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
)
{
enqueue_impl
<
float
>
(
batch_size
,
inputs
,
outputs
,
workspace
,
stream
);
}
else
if
(
data_type_
==
nvinfer1
::
DataType
::
kHALF
)
{
#if
CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
#if
TRT_PLUGIN_FP16_AVALIABLE
enqueue_impl
<
half
>
(
batch_size
,
inputs
,
outputs
,
workspace
,
stream
);
#else
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
...
...
@@ -225,7 +241,9 @@ __device__ T kFloor(T x);
template
<
>
__device__
half
kFloor
<
half
>
(
half
x
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
return
hfloor
(
x
);
#endif
}
template
<
>
...
...
@@ -235,35 +253,75 @@ __device__ float kFloor<float>(float x) {
template
<
typename
T
>
__device__
T
DmcnIm2colBilinear
(
const
T
*
bottom_data
,
const
int
data_width
,
const
int
height
,
const
int
width
,
T
h
,
T
w
)
{
int
h_low
=
kFloor
<
T
>
(
h
);
int
w_low
=
kFloor
<
T
>
(
w
);
const
int
height
,
const
int
width
,
T
h
,
T
w
);
template
<
>
__device__
float
DmcnIm2colBilinear
<
float
>
(
const
float
*
bottom_data
,
const
int
data_width
,
const
int
height
,
const
int
width
,
float
h
,
float
w
)
{
int
h_low
=
kFloor
<
float
>
(
h
);
int
w_low
=
kFloor
<
float
>
(
w
);
int
h_high
=
h_low
+
1
;
int
w_high
=
w_low
+
1
;
T
h_low_t
=
h_low
,
w_low_t
=
w_low
,
one
=
1.0
f
;
T
lh
=
h
-
h_low_t
;
T
lw
=
w
-
w_low_t
;
T
hh
=
one
-
lh
,
hw
=
one
-
lw
;
float
h_low_t
=
h_low
,
w_low_t
=
w_low
,
one
=
1.0
f
;
float
lh
=
h
-
h_low_t
;
float
lw
=
w
-
w_low_t
;
float
hh
=
one
-
lh
,
hw
=
one
-
lw
;
T
v1
=
0
;
float
v1
=
0
;
if
(
h_low
>=
0
&&
w_low
>=
0
)
v1
=
bottom_data
[
h_low
*
data_width
+
w_low
];
T
v2
=
0
;
float
v2
=
0
;
if
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
v2
=
bottom_data
[
h_low
*
data_width
+
w_high
];
T
v3
=
0
;
float
v3
=
0
;
if
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
v3
=
bottom_data
[
h_high
*
data_width
+
w_low
];
T
v4
=
0
;
float
v4
=
0
;
if
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
v4
=
bottom_data
[
h_high
*
data_width
+
w_high
];
T
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
float
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
T
val
=
(
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
);
float
val
=
(
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
);
return
val
;
}
template
<
>
__device__
half
DmcnIm2colBilinear
<
half
>
(
const
half
*
bottom_data
,
const
int
data_width
,
const
int
height
,
const
int
width
,
half
h
,
half
w
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int
h_low
=
kFloor
<
half
>
(
h
);
int
w_low
=
kFloor
<
half
>
(
w
);
int
h_high
=
h_low
+
1
;
int
w_high
=
w_low
+
1
;
half
h_low_t
=
h_low
,
w_low_t
=
w_low
,
one
=
1.0
f
;
half
lh
=
h
-
h_low_t
;
half
lw
=
w
-
w_low_t
;
half
hh
=
one
-
lh
,
hw
=
one
-
lw
;
half
v1
=
0
;
if
(
h_low
>=
0
&&
w_low
>=
0
)
v1
=
bottom_data
[
h_low
*
data_width
+
w_low
];
half
v2
=
0
;
if
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
v2
=
bottom_data
[
h_low
*
data_width
+
w_high
];
half
v3
=
0
;
if
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
v3
=
bottom_data
[
h_high
*
data_width
+
w_low
];
half
v4
=
0
;
if
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
v4
=
bottom_data
[
h_high
*
data_width
+
w_high
];
half
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
half
val
=
(
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
);
return
val
;
#endif
}
template
<
typename
T
>
__global__
void
ModulatedDeformableIm2colGpuKernel
(
const
int
nthreads
,
const
T
*
data_im
,
const
T
*
data_offset
,
...
...
@@ -272,11 +330,21 @@ __global__ void ModulatedDeformableIm2colGpuKernel(
const
int
stride_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
channel_per_deformable_group
,
const
int
batch_size
,
const
int
num_channels
,
const
int
deformable_group
,
const
int
height_col
,
const
int
width_col
,
T
*
data_col
)
{
const
int
width_col
,
T
*
data_col
);
template
<
>
__global__
void
ModulatedDeformableIm2colGpuKernel
<
float
>
(
const
int
nthreads
,
const
float
*
data_im
,
const
float
*
data_offset
,
const
float
*
data_mask
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
channel_per_deformable_group
,
const
int
batch_size
,
const
int
num_channels
,
const
int
deformable_group
,
const
int
height_col
,
const
int
width_col
,
float
*
data_col
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
T
minus_one
=
-
1.0
f
,
height_t
=
height
,
width_t
=
width
;
float
minus_one
=
-
1.0
f
,
height_t
=
height
,
width_t
=
width
;
for
(
size_t
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
const
int
w_col
=
i
%
width_col
;
const
int
h_col
=
(
i
/
width_col
)
%
height_col
;
...
...
@@ -289,16 +357,16 @@ __global__ void ModulatedDeformableIm2colGpuKernel(
const
int
h_in
=
h_col
*
stride_h
-
pad_h
;
const
int
w_in
=
w_col
*
stride_w
-
pad_w
;
T
*
data_col_ptr
=
float
*
data_col_ptr
=
data_col
+
((
c_col
*
batch_size
+
b_col
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
T
*
data_im_ptr
=
const
float
*
data_im_ptr
=
data_im
+
(
b_col
*
num_channels
+
c_im
)
*
height
*
width
;
const
T
*
data_offset_ptr
=
const
float
*
data_offset_ptr
=
data_offset
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
2
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
T
*
data_mask_ptr
=
const
float
*
data_mask_ptr
=
data_mask
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
...
...
@@ -313,17 +381,17 @@ __global__ void ModulatedDeformableIm2colGpuKernel(
const
int
data_mask_hw_ptr
=
((
i
*
kernel_w
+
j
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
T
offset_h
=
data_offset_ptr
[
data_offset_h_ptr
];
const
T
offset_w
=
data_offset_ptr
[
data_offset_w_ptr
];
const
T
mask
=
data_mask_ptr
[
data_mask_hw_ptr
];
T
val
=
0
;
T
h_im_t
=
h_in
+
i
*
dilation_h
,
w_im_t
=
w_in
+
j
*
dilation_w
;
const
T
h_im
=
h_im_t
+
offset_h
;
const
T
w_im
=
w_im_t
+
offset_w
;
const
float
offset_h
=
data_offset_ptr
[
data_offset_h_ptr
];
const
float
offset_w
=
data_offset_ptr
[
data_offset_w_ptr
];
const
float
mask
=
data_mask_ptr
[
data_mask_hw_ptr
];
float
val
=
0
;
float
h_im_t
=
h_in
+
i
*
dilation_h
,
w_im_t
=
w_in
+
j
*
dilation_w
;
const
float
h_im
=
h_im_t
+
offset_h
;
const
float
w_im
=
w_im_t
+
offset_w
;
if
(
h_im
>
minus_one
&&
w_im
>
minus_one
&&
h_im
<
height_t
&&
w_im
<
width_t
)
{
val
=
DmcnIm2colBilinear
<
T
>
(
data_im_ptr
,
width
,
height
,
width
,
h_im
,
w_im
);
val
=
DmcnIm2colBilinear
<
float
>
(
data_im_ptr
,
width
,
height
,
width
,
h_im
,
w_im
);
}
*
data_col_ptr
=
val
*
mask
;
data_col_ptr
+=
batch_size
*
height_col
*
width_col
;
...
...
@@ -332,6 +400,76 @@ __global__ void ModulatedDeformableIm2colGpuKernel(
}
}
template
<
>
__global__
void
ModulatedDeformableIm2colGpuKernel
<
half
>
(
const
int
nthreads
,
const
half
*
data_im
,
const
half
*
data_offset
,
const
half
*
data_mask
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
channel_per_deformable_group
,
const
int
batch_size
,
const
int
num_channels
,
const
int
deformable_group
,
const
int
height_col
,
const
int
width_col
,
half
*
data_col
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
half
minus_one
=
-
1.0
f
,
height_t
=
height
,
width_t
=
width
;
for
(
size_t
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
const
int
w_col
=
i
%
width_col
;
const
int
h_col
=
(
i
/
width_col
)
%
height_col
;
const
int
b_col
=
(
i
/
width_col
)
/
height_col
%
batch_size
;
const
int
c_im
=
(
i
/
width_col
/
height_col
)
/
batch_size
;
const
int
c_col
=
c_im
*
kernel_h
*
kernel_w
;
const
int
deformable_group_index
=
c_im
/
channel_per_deformable_group
;
const
int
h_in
=
h_col
*
stride_h
-
pad_h
;
const
int
w_in
=
w_col
*
stride_w
-
pad_w
;
half
*
data_col_ptr
=
data_col
+
((
c_col
*
batch_size
+
b_col
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
half
*
data_im_ptr
=
data_im
+
(
b_col
*
num_channels
+
c_im
)
*
height
*
width
;
const
half
*
data_offset_ptr
=
data_offset
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
2
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
half
*
data_mask_ptr
=
data_mask
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
for
(
int
i
=
0
;
i
<
kernel_h
;
++
i
)
{
for
(
int
j
=
0
;
j
<
kernel_w
;
++
j
)
{
const
int
data_offset_h_ptr
=
((
2
*
(
i
*
kernel_w
+
j
))
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
int
data_offset_w_ptr
=
((
2
*
(
i
*
kernel_w
+
j
)
+
1
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
int
data_mask_hw_ptr
=
((
i
*
kernel_w
+
j
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
half
offset_h
=
data_offset_ptr
[
data_offset_h_ptr
];
const
half
offset_w
=
data_offset_ptr
[
data_offset_w_ptr
];
const
half
mask
=
data_mask_ptr
[
data_mask_hw_ptr
];
half
val
=
0
;
half
h_im_t
=
h_in
+
i
*
dilation_h
,
w_im_t
=
w_in
+
j
*
dilation_w
;
const
half
h_im
=
h_im_t
+
offset_h
;
const
half
w_im
=
w_im_t
+
offset_w
;
if
(
h_im
>
minus_one
&&
w_im
>
minus_one
&&
h_im
<
height_t
&&
w_im
<
width_t
)
{
val
=
DmcnIm2colBilinear
<
half
>
(
data_im_ptr
,
width
,
height
,
width
,
h_im
,
w_im
);
}
*
data_col_ptr
=
val
*
mask
;
data_col_ptr
+=
batch_size
*
height_col
*
width_col
;
}
}
}
#endif
}
template
<
typename
T
>
void
gemm_impl
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
int
m
,
int
n
,
int
k
,
const
T
*
alpha
,
...
...
@@ -353,8 +491,13 @@ void gemm_impl<half>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t
transb
,
int
m
,
int
n
,
int
k
,
const
half
*
alpha
,
const
half
*
A
,
int
lda
,
const
half
*
B
,
int
ldb
,
const
half
*
beta
,
half
*
C
,
int
ldc
)
{
#if TRT_PLUGIN_FP16_AVALIABLE
platform
::
dynload
::
cublasHgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
#else
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Current CUDA arch dose not support fp16. Please use fp32 instead."
));
#endif
}
template
<
typename
T
>
...
...
@@ -436,6 +579,7 @@ size_t DeformableConvPlugin::getSerializationSize() const TRT_NOEXCEPT {
serialize_size
+=
SerializedSize
(
offset_dim_
);
serialize_size
+=
SerializedSize
(
mask_dim_
);
serialize_size
+=
SerializedSize
(
output_dim_
);
serialize_size
+=
SerializedSize
(
with_fp16_
);
return
serialize_size
;
}
...
...
@@ -454,6 +598,7 @@ void DeformableConvPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
SerializeValue
(
&
buffer
,
offset_dim_
);
SerializeValue
(
&
buffer
,
mask_dim_
);
SerializeValue
(
&
buffer
,
output_dim_
);
SerializeValue
(
&
buffer
,
with_fp16_
);
}
void
DeformableConvPlugin
::
destroy
()
TRT_NOEXCEPT
{}
...
...
@@ -521,10 +666,10 @@ void DeformableConvPlugin::configurePlugin(
}
nvinfer1
::
IPluginV2Ext
*
DeformableConvPlugin
::
clone
()
const
TRT_NOEXCEPT
{
return
new
DeformableConvPlugin
(
data_type_
,
weights_
,
kernel_dims_
,
strides_
,
paddings_
,
dilations_
,
group
s_
,
deformable_groups_
,
im2col_step_
,
inpu
t_dim_
,
offset_dim_
,
mask_dim_
,
output_dim
_
);
return
new
DeformableConvPlugin
(
data_type_
,
weights_
,
kernel_dims_
,
strides_
,
paddings_
,
dilation
s_
,
groups_
,
deformable_groups_
,
im2col_step_
,
input_dim_
,
offse
t_dim_
,
mask_dim_
,
output_dim_
,
with_fp16
_
);
}
void
DeformableConvPluginCreator
::
setPluginNamespace
(
const
char
*
lib_namespace
)
...
...
@@ -560,6 +705,7 @@ nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::createPlugin(
int
groups
=
-
1
;
int
deformable_groups
=
-
1
;
int
im2col_step
=
-
1
;
bool
with_fp16
=
false
;
for
(
int
i
=
0
;
i
<
fc
->
nbFields
;
++
i
)
{
const
std
::
string
field_name
(
fc
->
fields
[
i
].
name
);
...
...
@@ -590,6 +736,8 @@ nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::createPlugin(
}
else
if
(
field_name
.
compare
(
"weights"
))
{
weights
.
count
=
fc
->
fields
[
i
].
length
;
weights
.
values
=
fc
->
fields
[
i
].
data
;
}
else
if
(
field_name
.
compare
(
"with_fp16"
))
{
with_fp16
=
*
static_cast
<
const
bool
*>
(
fc
->
fields
[
i
].
data
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unknown plugin field name [%s] in the DeformableConv TRT Plugin."
,
...
...
@@ -599,7 +747,7 @@ nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::createPlugin(
weights
.
type
=
data_type
;
return
new
DeformableConvPlugin
(
data_type
,
weights
,
kernel_dims
,
strides
,
paddings
,
dilations
,
groups
,
deformable_groups
,
im2col_step
);
deformable_groups
,
im2col_step
,
with_fp16
);
}
nvinfer1
::
IPluginV2Ext
*
DeformableConvPluginCreator
::
deserializePlugin
(
...
...
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h
浏览文件 @
e91141fb
...
...
@@ -30,18 +30,22 @@ namespace plugin {
class
DeformableConvPlugin
:
public
nvinfer1
::
IPluginV2Ext
{
public:
explicit
DeformableConvPlugin
(
const
nvinfer1
::
DataType
data_type
,
const
nvinfer1
::
Weights
&
weights
,
const
std
::
vector
<
int
>&
kernel_dims
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
int
groups
,
const
int
deformable_groups
,
const
int
im2col_step
);
explicit
DeformableConvPlugin
(
const
nvinfer1
::
DataType
data_type
,
const
nvinfer1
::
Weights
&
weights
,
const
std
::
vector
<
int
>&
kernel_dims
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
int
groups
,
const
int
deformable_groups
,
const
int
im2col_step
,
const
bool
with_fp16
);
explicit
DeformableConvPlugin
(
const
nvinfer1
::
DataType
data_type
,
const
nvinfer1
::
Weights
&
weights
,
const
std
::
vector
<
int
>&
kernel_dims
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
int
groups
,
const
int
deformable_groups
,
const
int
im2col_step
,
const
std
::
vector
<
int
>&
input_dim
,
const
std
::
vector
<
int
>&
offset_dim
,
const
std
::
vector
<
int
>&
mask_dim
,
const
std
::
vector
<
int
>&
output_dim
);
const
std
::
vector
<
int
>&
mask_dim
,
const
std
::
vector
<
int
>&
output_dim
,
const
bool
with_fp16
);
DeformableConvPlugin
(
const
void
*
data
,
size_t
length
);
~
DeformableConvPlugin
()
override
;
...
...
@@ -98,6 +102,7 @@ class DeformableConvPlugin : public nvinfer1::IPluginV2Ext {
const
nvinfer1
::
Weights
&
deviceWeights
)
const
;
nvinfer1
::
Weights
deserializeToDevice
(
const
void
**
hostBuffer
,
size_t
count
);
bool
with_fp16_
;
nvinfer1
::
DataType
data_type_
;
nvinfer1
::
Weights
weights_
;
std
::
vector
<
int
>
kernel_dims_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录