Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
86bf8274
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
86bf8274
编写于
3月 16, 2023
作者:
X
xjmxyt
提交者:
GitHub
3月 16, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Deformable Conv Dynamic Shape Support (#50698)
* add dynamic support * add more test * fix bug * change test * change test
上级
290aa368
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
696 addition
and
31 deletion
+696
-31
paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc
...le/fluid/inference/tensorrt/convert/deformable_conv_op.cc
+57
-27
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+0
-4
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
...id/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
+487
-0
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h
...uid/inference/tensorrt/plugin/deformable_conv_op_plugin.h
+130
-0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_deformable_conv.py
...nittests/ir/inference/test_trt_convert_deformable_conv.py
+22
-0
未找到文件。
paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc
浏览文件 @
86bf8274
...
...
@@ -86,33 +86,63 @@ class DeformableConvOpConverter : public OpConverter {
}
else
{
weights
=
engine_
->
GetFp32TrtWeight
(
filter_name
,
*
filter_tensor
).
get
();
}
auto
*
deformable_conv_plugin
=
new
plugin
::
DeformableConvPlugin
(
with_fp16
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
,
weights
,
kernel_dims
,
strides
,
paddings
,
dilations
,
groups
,
deformable_groups
,
im2col_step
,
with_fp16
);
std
::
vector
<
nvinfer1
::
ITensor
*>
deformable_conv_inputs
;
deformable_conv_inputs
.
push_back
(
input_tensor
);
deformable_conv_inputs
.
push_back
(
offset_tensor
);
deformable_conv_inputs
.
push_back
(
mask_tensor
);
auto
*
deformable_conv_layer
=
engine_
->
network
()
->
addPluginV2
(
deformable_conv_inputs
.
data
(),
deformable_conv_inputs
.
size
(),
*
deformable_conv_plugin
);
std
::
vector
<
std
::
string
>
output_names
;
output_names
.
push_back
(
op_desc
.
Output
(
"Output"
).
front
());
RreplenishLayerAndOutput
(
deformable_conv_layer
,
"deformable_conv"
,
output_names
,
test_mode
);
if
(
!
engine_
->
with_dynamic_shape
())
{
auto
*
deformable_conv_plugin
=
new
plugin
::
DeformableConvPlugin
(
with_fp16
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
,
weights
,
kernel_dims
,
strides
,
paddings
,
dilations
,
groups
,
deformable_groups
,
im2col_step
,
with_fp16
);
std
::
vector
<
nvinfer1
::
ITensor
*>
deformable_conv_inputs
;
deformable_conv_inputs
.
push_back
(
input_tensor
);
deformable_conv_inputs
.
push_back
(
offset_tensor
);
deformable_conv_inputs
.
push_back
(
mask_tensor
);
auto
*
deformable_conv_layer
=
engine_
->
network
()
->
addPluginV2
(
deformable_conv_inputs
.
data
(),
deformable_conv_inputs
.
size
(),
*
deformable_conv_plugin
);
std
::
vector
<
std
::
string
>
output_names
;
output_names
.
push_back
(
op_desc
.
Output
(
"Output"
).
front
());
RreplenishLayerAndOutput
(
deformable_conv_layer
,
"deformable_conv"
,
output_names
,
test_mode
);
}
else
{
auto
*
deformable_conv_plugin
=
new
plugin
::
DeformableConvPluginDynamic
(
with_fp16
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
,
weights
,
kernel_dims
,
strides
,
paddings
,
dilations
,
groups
,
deformable_groups
,
im2col_step
,
with_fp16
);
std
::
vector
<
nvinfer1
::
ITensor
*>
deformable_conv_inputs
;
deformable_conv_inputs
.
push_back
(
input_tensor
);
deformable_conv_inputs
.
push_back
(
offset_tensor
);
deformable_conv_inputs
.
push_back
(
mask_tensor
);
auto
*
deformable_conv_layer
=
engine_
->
network
()
->
addPluginV2
(
deformable_conv_inputs
.
data
(),
deformable_conv_inputs
.
size
(),
*
deformable_conv_plugin
);
std
::
vector
<
std
::
string
>
output_names
;
output_names
.
push_back
(
op_desc
.
Output
(
"Output"
).
front
());
RreplenishLayerAndOutput
(
deformable_conv_layer
,
"deformable_conv"
,
output_names
,
test_mode
);
}
}
};
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
86bf8274
...
...
@@ -306,10 +306,6 @@ struct SimpleOpTypeSetTeller : public Teller {
}
if
(
op_type
==
"deformable_conv"
)
{
if
(
with_dynamic_shape
)
{
VLOG
(
3
)
<<
"Deformable conv trt plugin does not support dynamic shape"
;
return
false
;
}
if
(
!
desc
.
HasAttr
(
"groups"
)
||
!
desc
.
HasAttr
(
"strides"
)
||
!
desc
.
HasAttr
(
"paddings"
))
return
false
;
...
...
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
浏览文件 @
86bf8274
...
...
@@ -18,7 +18,13 @@ limitations under the License. */
#include <algorithm>
#include <cstdio>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace
paddle
{
namespace
inference
{
...
...
@@ -545,6 +551,47 @@ __global__ void ModulatedDeformableIm2colGpuKernel<half>(
#endif
}
template
<
typename
T
>
struct
CUDATypeTraits
;
template
<
>
struct
CUDATypeTraits
<
half
>
{
typedef
platform
::
float16
TYPE
;
};
template
<
>
struct
CUDATypeTraits
<
float
>
{
typedef
float
TYPE
;
};
template
<
typename
T
>
void
gemm_impl_new
(
int
m
,
int
n
,
int
k
,
const
T
*
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
*
beta
,
T
*
C
)
{
auto
*
device_ctx
=
static_cast
<
phi
::
GPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CUDAPlace
(
0
)));
const
phi
::
GPUContext
&
dev_ctx
=
*
device_ctx
;
typedef
typename
CUDATypeTraits
<
T
>::
TYPE
run_type
;
auto
blas
=
phi
::
funcs
::
GetBlas
<
phi
::
GPUContext
,
run_type
>
(
dev_ctx
);
// note: here calls GEMM like cblas, so do not use like cblas
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
n
,
m
,
k
,
static_cast
<
run_type
>
(
*
alpha
),
reinterpret_cast
<
run_type
*>
(
const_cast
<
T
*>
(
B
)),
reinterpret_cast
<
run_type
*>
(
const_cast
<
T
*>
(
A
)),
static_cast
<
run_type
>
(
*
beta
),
reinterpret_cast
<
run_type
*>
(
C
));
}
template
<
typename
T
>
void
gemm_impl
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
...
...
@@ -919,6 +966,446 @@ nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::deserializePlugin(
return
plugin
;
}
#if IS_TRT_VERSION_GE(6000)
DeformableConvPluginDynamic
::
DeformableConvPluginDynamic
(
const
nvinfer1
::
DataType
data_type
,
const
nvinfer1
::
Weights
&
weights
,
const
std
::
vector
<
int
>&
kernel_dims
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
int
groups
,
const
int
deformable_groups
,
const
int
im2col_step
,
const
bool
with_fp16
)
:
data_type_
(
data_type
),
groups_
(
groups
),
deformable_groups_
(
deformable_groups
),
im2col_step_
(
im2col_step
),
with_fp16_
(
with_fp16
)
{
weights_
=
copyToDevice
(
weights
.
values
,
weights
.
count
);
kernel_dims_
.
insert
(
kernel_dims_
.
end
(),
kernel_dims
.
cbegin
(),
kernel_dims
.
cend
());
strides_
.
insert
(
strides_
.
end
(),
strides
.
cbegin
(),
strides
.
cend
());
paddings_
.
insert
(
paddings_
.
end
(),
paddings
.
cbegin
(),
paddings
.
cend
());
dilations_
.
insert
(
dilations_
.
end
(),
dilations
.
cbegin
(),
dilations
.
cend
());
PADDLE_ENFORCE_EQ
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
||
data_type_
==
nvinfer1
::
DataType
::
kHALF
,
true
,
platform
::
errors
::
InvalidArgument
(
"The DeformableConv TRT Plugin's input type "
"should be float or half."
));
PADDLE_ENFORCE_EQ
(
paddings_
.
size
(),
strides_
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of paddings (%d) is not equal to the size of strides (%d)."
,
paddings_
.
size
(),
strides_
.
size
()));
}
DeformableConvPluginDynamic
::
DeformableConvPluginDynamic
(
const
void
*
data
,
size_t
length
)
{
DeserializeValue
(
&
data
,
&
length
,
&
data_type_
);
DeserializeValue
(
&
data
,
&
length
,
&
strides_
);
DeserializeValue
(
&
data
,
&
length
,
&
paddings_
);
DeserializeValue
(
&
data
,
&
length
,
&
dilations_
);
DeserializeValue
(
&
data
,
&
length
,
&
groups_
);
DeserializeValue
(
&
data
,
&
length
,
&
deformable_groups_
);
DeserializeValue
(
&
data
,
&
length
,
&
im2col_step_
);
DeserializeValue
(
&
data
,
&
length
,
&
kernel_dims_
);
int64_t
count
;
DeserializeValue
(
&
data
,
&
length
,
&
count
);
weights_
=
deserializeToDevice
(
&
data
,
count
);
DeserializeValue
(
&
data
,
&
length
,
&
with_fp16_
);
}
DeformableConvPluginDynamic
::~
DeformableConvPluginDynamic
()
{
if
(
weights_
.
values
)
{
cudaFree
(
const_cast
<
void
*>
(
weights_
.
values
));
weights_
.
values
=
nullptr
;
}
}
nvinfer1
::
Weights
DeformableConvPluginDynamic
::
copyToDevice
(
const
void
*
hostData
,
size_t
count
)
{
int
num_bytes
=
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
?
4
:
2
);
void
*
deviceData
;
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMalloc
(
&
deviceData
,
count
*
num_bytes
));
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMemcpy
(
deviceData
,
hostData
,
count
*
num_bytes
,
cudaMemcpyHostToDevice
));
return
nvinfer1
::
Weights
{
data_type_
,
deviceData
,
int64_t
(
count
)};
}
void
DeformableConvPluginDynamic
::
serializeFromDevice
(
void
**
hostBuffer
,
const
nvinfer1
::
Weights
&
deviceWeights
)
const
{
int
num_bytes
=
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
?
4
:
2
);
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMemcpy
(
static_cast
<
char
*>
(
*
hostBuffer
),
deviceWeights
.
values
,
deviceWeights
.
count
*
num_bytes
,
cudaMemcpyDeviceToHost
));
*
hostBuffer
=
reinterpret_cast
<
char
*>
(
*
hostBuffer
)
+
deviceWeights
.
count
*
num_bytes
;
}
nvinfer1
::
Weights
DeformableConvPluginDynamic
::
deserializeToDevice
(
const
void
**
hostBuffer
,
size_t
count
)
{
int
num_bytes
=
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
?
4
:
2
);
nvinfer1
::
Weights
w
=
copyToDevice
(
static_cast
<
const
char
*>
(
*
hostBuffer
),
count
);
*
hostBuffer
=
reinterpret_cast
<
const
char
*>
(
*
hostBuffer
)
+
count
*
num_bytes
;
return
w
;
}
int
DeformableConvPluginDynamic
::
initialize
()
TRT_NOEXCEPT
{
return
0
;
}
size_t
DeformableConvPluginDynamic
::
getSerializationSize
()
const
TRT_NOEXCEPT
{
size_t
serialize_size
=
0
;
serialize_size
+=
SerializedSize
(
data_type_
);
serialize_size
+=
SerializedSize
(
strides_
);
serialize_size
+=
SerializedSize
(
paddings_
);
serialize_size
+=
SerializedSize
(
dilations_
);
serialize_size
+=
SerializedSize
(
groups_
);
serialize_size
+=
SerializedSize
(
deformable_groups_
);
serialize_size
+=
SerializedSize
(
im2col_step_
);
serialize_size
+=
SerializedSize
(
kernel_dims_
);
serialize_size
+=
SerializedSize
(
weights_
.
count
);
int
num_bytes
=
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
?
4
:
2
);
serialize_size
+=
weights_
.
count
*
num_bytes
;
serialize_size
+=
SerializedSize
(
with_fp16_
);
return
serialize_size
;
}
void
DeformableConvPluginDynamic
::
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
{
SerializeValue
(
&
buffer
,
data_type_
);
SerializeValue
(
&
buffer
,
strides_
);
SerializeValue
(
&
buffer
,
paddings_
);
SerializeValue
(
&
buffer
,
dilations_
);
SerializeValue
(
&
buffer
,
groups_
);
SerializeValue
(
&
buffer
,
deformable_groups_
);
SerializeValue
(
&
buffer
,
im2col_step_
);
SerializeValue
(
&
buffer
,
kernel_dims_
);
SerializeValue
(
&
buffer
,
weights_
.
count
);
serializeFromDevice
(
&
buffer
,
weights_
);
SerializeValue
(
&
buffer
,
with_fp16_
);
}
size_t
DeformableConvPluginDynamic
::
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
TRT_NOEXCEPT
{
PADDLE_ENFORCE_EQ
(
nbInputs
,
3
,
platform
::
errors
::
InvalidArgument
(
"The number of inputs should be equal to 3, but got %d"
,
nbInputs
));
PADDLE_ENFORCE_EQ
(
nbOutputs
,
1
,
platform
::
errors
::
InvalidArgument
(
"The number of inputs should be equal to 1, but got %d"
,
nbOutputs
));
int
c_i
=
inputs
[
0
].
dims
.
d
[
1
],
h_i
=
inputs
[
0
].
dims
.
d
[
2
],
w_i
=
inputs
[
0
].
dims
.
d
[
3
];
int
k_h
=
kernel_dims_
[
2
],
k_w
=
kernel_dims_
[
3
];
int
c_o
=
outputs
[
0
].
dims
.
d
[
1
],
h_o
=
outputs
[
0
].
dims
.
d
[
2
],
w_o
=
outputs
[
0
].
dims
.
d
[
3
];
int
num_bytes
=
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
?
4
:
2
);
size_t
data_col_size
=
static_cast
<
size_t
>
(
c_i
*
k_h
*
k_w
*
im2col_step_
*
h_o
*
w_o
*
num_bytes
);
return
data_col_size
;
}
nvinfer1
::
DimsExprs
DeformableConvPluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputDims
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
TRT_NOEXCEPT
{
PADDLE_ENFORCE_EQ
(
nb_inputs
,
3
,
platform
::
errors
::
InvalidArgument
(
"The number of inputs should be equal to 3, but got %d"
,
nb_inputs
));
nvinfer1
::
DimsExprs
ret
;
ret
.
nbDims
=
inputDims
[
0
].
nbDims
;
ret
.
d
[
0
]
=
inputDims
[
0
].
d
[
0
];
auto
ConvOutputSizeDynamic
=
[
&
](
const
nvinfer1
::
IDimensionExpr
*
input_size
,
int
filter_size
,
int
dilation
,
int
padding
,
int
stride
)
->
const
nvinfer1
::
IDimensionExpr
*
{
auto
dkernel
=
dilation
*
(
filter_size
-
1
)
+
1
;
return
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kFLOOR_DIV
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
input_size
,
*
expr_builder
.
constant
(
2
*
padding
-
dkernel
)),
*
expr_builder
.
constant
(
stride
)),
*
expr_builder
.
constant
(
1
));
};
ret
.
d
[
1
]
=
expr_builder
.
constant
(
kernel_dims_
[
0
]);
ret
.
d
[
2
]
=
ConvOutputSizeDynamic
(
inputDims
[
0
].
d
[
2
],
kernel_dims_
[
2
],
dilations_
[
0
],
paddings_
[
0
],
strides_
[
0
]);
ret
.
d
[
3
]
=
ConvOutputSizeDynamic
(
inputDims
[
0
].
d
[
3
],
kernel_dims_
[
3
],
dilations_
[
1
],
paddings_
[
1
],
strides_
[
1
]);
return
ret
;
}
bool
DeformableConvPluginDynamic
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
nb_outputs
)
TRT_NOEXCEPT
{
PADDLE_ENFORCE_NOT_NULL
(
in_out
,
platform
::
errors
::
InvalidArgument
(
"The input of groupnorm plugin shoule not be nullptr."
));
PADDLE_ENFORCE_LT
(
pos
,
nb_inputs
+
nb_outputs
,
platform
::
errors
::
InvalidArgument
(
"The pos(%d) should be less than the "
"num(%d) of the input and the output."
,
pos
,
nb_inputs
+
nb_outputs
));
const
nvinfer1
::
PluginTensorDesc
&
in
=
in_out
[
pos
];
if
(
pos
==
0
)
{
if
(
with_fp16_
)
{
return
((
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
((
in
.
format
==
nvinfer1
::
PluginFormat
::
kLINEAR
)
||
in
.
format
==
nvinfer1
::
PluginFormat
::
kHWC8
));
}
else
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
}
const
nvinfer1
::
PluginTensorDesc
&
prev
=
in_out
[
pos
-
1
];
// output
return
in
.
type
==
prev
.
type
&&
in
.
format
==
prev
.
format
;
}
nvinfer1
::
DataType
DeformableConvPluginDynamic
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
TRT_NOEXCEPT
{
PADDLE_ENFORCE_EQ
(
index
,
0
,
platform
::
errors
::
InvalidArgument
(
"The Elementwise Plugin only has one input, so the "
"index value should be 0, but get %d."
,
index
));
return
input_types
[
0
];
}
void
DeformableConvPluginDynamic
::
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
TRT_NOEXCEPT
{
PADDLE_ENFORCE_EQ
(
nbInputs
,
3
,
platform
::
errors
::
InvalidArgument
(
"The number of inputs should be equal to 3, but got %d"
,
nbInputs
));
PADDLE_ENFORCE_EQ
(
nbOutputs
,
1
,
platform
::
errors
::
InvalidArgument
(
"The number of inputs should be equal to 1, but got %d"
,
nbOutputs
));
}
int
DeformableConvPluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
if
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
)
{
enqueue_impl
<
float
>
(
input_desc
,
output_desc
,
inputs
,
outputs
,
workspace
,
stream
);
}
else
if
(
data_type_
==
nvinfer1
::
DataType
::
kHALF
)
{
#if TRT_PLUGIN_FP16_AVALIABLE
enqueue_impl
<
half
>
(
input_desc
,
output_desc
,
inputs
,
outputs
,
workspace
,
stream
);
#else
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Current CUDA arch dose not support fp16. Please use fp32 instead."
));
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"The DeformableConv TRT Plugin's input type should be float or half."
));
}
return
cudaGetLastError
()
!=
cudaSuccess
;
}
template
<
typename
T
>
int
DeformableConvPluginDynamic
::
enqueue_impl
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
const
auto
&
input_dims
=
input_desc
[
0
].
dims
;
const
auto
&
offset_dims
=
input_desc
[
1
].
dims
;
const
auto
&
mask_dims
=
input_desc
[
2
].
dims
;
const
auto
&
output_dims
=
output_desc
[
0
].
dims
;
int
batch_size
=
input_dims
.
d
[
0
];
const
T
*
input
=
reinterpret_cast
<
const
T
*>
(
inputs
[
0
]);
const
T
*
offset
=
reinterpret_cast
<
const
T
*>
(
inputs
[
1
]);
const
T
*
mask
=
reinterpret_cast
<
const
T
*>
(
inputs
[
2
]);
const
T
*
filter
=
reinterpret_cast
<
const
T
*>
(
weights_
.
values
);
T
*
output
=
reinterpret_cast
<
T
*>
(
outputs
[
0
]);
int
c_i
=
input_dims
.
d
[
1
],
h_i
=
input_dims
.
d
[
2
],
w_i
=
input_dims
.
d
[
3
];
int
k_h
=
kernel_dims_
[
2
],
k_w
=
kernel_dims_
[
3
];
int
c_o
=
output_dims
.
d
[
1
],
h_o
=
output_dims
.
d
[
2
],
w_o
=
output_dims
.
d
[
3
];
int
input_stride
=
c_i
*
h_i
*
w_i
;
int
offset_stride
=
offset_dims
.
d
[
1
]
*
offset_dims
.
d
[
2
]
*
offset_dims
.
d
[
3
];
int
mask_stride
=
mask_dims
.
d
[
1
]
*
mask_dims
.
d
[
2
]
*
mask_dims
.
d
[
3
];
int
output_stride
=
c_o
*
h_o
*
w_o
;
int
M
=
c_o
/
groups_
;
int
N
=
im2col_step_
*
h_o
*
w_o
;
int
K
=
c_i
*
k_h
*
k_w
/
groups_
;
// c_i / deformable_groups
int
channel_per_deformable_group
=
c_i
/
deformable_groups_
;
// c_i * im2col_step * h_o * w_o
int
num_kernels
=
c_i
*
im2col_step_
*
h_o
*
w_o
;
int
blocks
=
NumBlocks
(
num_kernels
);
int
threads
=
kNumCUDAThreads
;
const
T
alpha
=
static_cast
<
T
>
(
1.0
f
);
const
T
beta
=
static_cast
<
T
>
(
0.0
f
);
for
(
int
i
=
0
;
i
<
batch_size
/
im2col_step_
;
++
i
)
{
const
T
*
data_im
=
input
+
i
*
im2col_step_
*
input_stride
;
const
T
*
data_offset
=
offset
+
i
*
im2col_step_
*
offset_stride
;
const
T
*
data_mask
=
mask
+
i
*
im2col_step_
*
mask_stride
;
T
*
data_col
=
reinterpret_cast
<
T
*>
(
workspace
);
ModulatedDeformableIm2colGpuKernel
<
T
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
num_kernels
,
data_im
,
data_offset
,
data_mask
,
h_i
,
w_i
,
k_h
,
k_w
,
paddings_
[
0
],
paddings_
[
1
],
strides_
[
0
],
strides_
[
1
],
dilations_
[
0
],
dilations_
[
1
],
channel_per_deformable_group
,
im2col_step_
,
c_i
,
deformable_groups_
,
h_o
,
w_o
,
data_col
);
for
(
int
g
=
0
;
g
<
groups_
;
++
g
)
{
const
T
*
weight
=
filter
+
g
*
M
*
K
;
const
T
*
col
=
data_col
+
g
*
K
*
N
;
T
*
out
=
output
+
i
*
im2col_step_
*
output_stride
+
g
*
M
*
N
;
gemm_impl_new
<
T
>
(
N
,
M
,
K
,
&
alpha
,
col
,
weight
,
&
beta
,
out
);
}
}
return
0
;
}
nvinfer1
::
IPluginV2Ext
*
DeformableConvPluginDynamicCreator
::
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
TRT_NOEXCEPT
{
return
new
DeformableConvPluginDynamic
(
serial_data
,
serial_length
);
}
nvinfer1
::
IPluginV2Ext
*
DeformableConvPluginDynamicCreator
::
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
TRT_NOEXCEPT
{
const
nvinfer1
::
PluginField
*
fields
=
fc
->
fields
;
nvinfer1
::
DataType
data_type
;
std
::
vector
<
int
>
strides
,
paddings
,
dilations
,
kernel_dims
;
nvinfer1
::
Weights
weights
;
int
groups
=
-
1
;
int
deformable_groups
=
-
1
;
int
im2col_step
=
-
1
;
bool
with_fp16
=
false
;
for
(
int
i
=
0
;
i
<
fc
->
nbFields
;
++
i
)
{
const
std
::
string
field_name
(
fc
->
fields
[
i
].
name
);
if
(
field_name
.
compare
(
"data_type"
)
==
0
)
{
data_type
=
*
static_cast
<
const
nvinfer1
::
DataType
*>
(
fc
->
fields
[
i
].
data
);
}
else
if
(
field_name
.
compare
(
"strides"
))
{
const
int
length
=
fc
->
fields
[
i
].
length
;
const
int
*
data
=
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
strides
.
insert
(
strides
.
end
(),
data
,
data
+
length
);
}
else
if
(
field_name
.
compare
(
"paddings"
))
{
const
int
length
=
fc
->
fields
[
i
].
length
;
const
int
*
data
=
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
paddings
.
insert
(
paddings
.
end
(),
data
,
data
+
length
);
}
else
if
(
field_name
.
compare
(
"dilations"
))
{
const
int
length
=
fc
->
fields
[
i
].
length
;
const
int
*
data
=
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
dilations
.
insert
(
dilations
.
end
(),
data
,
data
+
length
);
}
else
if
(
field_name
.
compare
(
"groups"
))
{
groups
=
*
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
}
else
if
(
field_name
.
compare
(
"deformable_groups"
))
{
deformable_groups
=
*
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
}
else
if
(
field_name
.
compare
(
"im2col_step"
))
{
im2col_step
=
*
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
}
else
if
(
field_name
.
compare
(
"kernel_dims"
))
{
const
int
length
=
fc
->
fields
[
i
].
length
;
const
int
*
data
=
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
kernel_dims
.
insert
(
kernel_dims
.
end
(),
data
,
data
+
length
);
}
else
if
(
field_name
.
compare
(
"weights"
))
{
weights
.
count
=
fc
->
fields
[
i
].
length
;
weights
.
values
=
fc
->
fields
[
i
].
data
;
}
else
if
(
field_name
.
compare
(
"with_fp16"
))
{
with_fp16
=
*
static_cast
<
const
bool
*>
(
fc
->
fields
[
i
].
data
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unknown plugin field name [%s] in the DeformableConv TRT Plugin."
,
field_name
));
}
}
weights
.
type
=
data_type
;
return
new
DeformableConvPlugin
(
data_type
,
weights
,
kernel_dims
,
strides
,
paddings
,
dilations
,
groups
,
deformable_groups
,
im2col_step
,
with_fp16
);
}
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
...
...
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h
浏览文件 @
86bf8274
...
...
@@ -169,6 +169,136 @@ class DeformableConvPluginCreator : public nvinfer1::IPluginCreator {
REGISTER_TRT_PLUGIN_V2
(
DeformableConvPluginCreator
);
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
class
DeformableConvPluginDynamic
:
public
DynamicPluginTensorRT
{
public:
explicit
DeformableConvPluginDynamic
(
const
nvinfer1
::
DataType
data_type
,
const
nvinfer1
::
Weights
&
weights
,
const
std
::
vector
<
int
>&
kernel_dims
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
int
groups
,
const
int
deformable_groups
,
const
int
im2col_step
,
const
bool
with_fp16
);
DeformableConvPluginDynamic
(
const
void
*
data
,
size_t
length
);
~
DeformableConvPluginDynamic
()
override
;
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
TRT_NOEXCEPT
override
{
return
new
DeformableConvPluginDynamic
(
data_type_
,
weights_
,
kernel_dims_
,
strides_
,
paddings_
,
dilations_
,
groups_
,
deformable_groups_
,
im2col_step_
,
with_fp16_
);
}
const
char
*
getPluginType
()
const
TRT_NOEXCEPT
override
{
return
"deformable_conv_plugin_dynamic"
;
}
int
getNbOutputs
()
const
TRT_NOEXCEPT
override
{
return
1
;
}
int
initialize
()
TRT_NOEXCEPT
override
;
size_t
getSerializationSize
()
const
TRT_NOEXCEPT
override
;
void
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
override
;
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
// NOLINT
TRT_NOEXCEPT
override
;
bool
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
TRT_NOEXCEPT
override
;
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
TRT_NOEXCEPT
override
;
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
TRT_NOEXCEPT
override
;
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
override
;
void
destroy
()
TRT_NOEXCEPT
override
{
delete
this
;
}
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
TRT_NOEXCEPT
override
;
private:
nvinfer1
::
Weights
copyToDevice
(
const
void
*
hostData
,
size_t
count
);
void
serializeFromDevice
(
void
**
hostBuffer
,
const
nvinfer1
::
Weights
&
deviceWeights
)
const
;
nvinfer1
::
Weights
deserializeToDevice
(
const
void
**
hostBuffer
,
size_t
count
);
template
<
typename
T
>
int
enqueue_impl
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
);
bool
with_fp16_
;
nvinfer1
::
DataType
data_type_
;
nvinfer1
::
Weights
weights_
;
std
::
vector
<
int
>
kernel_dims_
;
std
::
vector
<
int
>
strides_
;
std
::
vector
<
int
>
paddings_
;
std
::
vector
<
int
>
dilations_
;
int
groups_
;
int
deformable_groups_
;
int
im2col_step_
;
std
::
string
namespace_
;
cublasHandle_t
cublasHandle_
;
};
class
DeformableConvPluginDynamicCreator
:
public
TensorRTPluginCreator
{
public:
void
setPluginNamespace
(
const
char
*
lib_namespace
)
TRT_NOEXCEPT
override
{
namespace_
=
lib_namespace
;
}
const
char
*
getPluginNamespace
()
const
TRT_NOEXCEPT
override
{
return
namespace_
.
c_str
();
}
const
char
*
getPluginName
()
const
TRT_NOEXCEPT
override
{
return
"deformable_conv_plugin_dynamic"
;
}
const
char
*
getPluginVersion
()
const
TRT_NOEXCEPT
override
{
return
"1"
;
}
const
nvinfer1
::
PluginFieldCollection
*
getFieldNames
()
TRT_NOEXCEPT
override
{
return
&
field_collection_
;
}
nvinfer1
::
IPluginV2Ext
*
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
TRT_NOEXCEPT
override
;
nvinfer1
::
IPluginV2Ext
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
TRT_NOEXCEPT
override
;
private:
std
::
string
namespace_
;
nvinfer1
::
PluginFieldCollection
field_collection_
;
};
REGISTER_TRT_PLUGIN_V2
(
DeformableConvPluginDynamicCreator
);
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_deformable_conv.py
浏览文件 @
86bf8274
...
...
@@ -183,6 +183,23 @@ class TrtConvertDeformableConvTest(TrtLayerAutoScanTest):
def
sample_predictor_configs
(
self
,
program_config
)
->
(
paddle_infer
.
Config
,
List
[
int
],
float
):
def
generate_dynamic_shape
(
attrs
):
self
.
dynamic_shape
.
min_input_shape
=
{
"input_data"
:
[
1
,
3
,
32
,
32
],
"offset_data"
:
[
1
,
18
,
14
,
14
],
"mask_data"
:
[
1
,
9
,
14
,
14
],
}
self
.
dynamic_shape
.
max_input_shape
=
{
"input_data"
:
[
1
,
3
,
32
,
32
],
"offset_data"
:
[
1
,
18
,
32
,
32
],
"mask_data"
:
[
1
,
9
,
32
,
32
],
}
self
.
dynamic_shape
.
opt_input_shape
=
{
"input_data"
:
[
1
,
3
,
32
,
32
],
"offset_data"
:
[
1
,
18
,
14
,
16
],
"mask_data"
:
[
1
,
9
,
14
,
16
],
}
def
clear_dynamic_shape
():
self
.
dynamic_shape
.
min_input_shape
=
{}
self
.
dynamic_shape
.
max_input_shape
=
{}
...
...
@@ -205,6 +222,11 @@ class TrtConvertDeformableConvTest(TrtLayerAutoScanTest):
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
attrs
,
False
),
1e-5
generate_dynamic_shape
(
attrs
)
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Float32
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
attrs
,
True
),
(
1e-5
,
1e-5
)
def
test
(
self
):
self
.
trt_param
.
workspace_size
=
1
<<
28
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录