Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
8c3decd8
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8c3decd8
编写于
10月 27, 2021
作者:
W
wangxinxin08
提交者:
GitHub
10月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add dcnv2 trt plugin (#36612)
* add dcnv2 plugin
上级
d6b1beb0
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
1204 addition
and
2 deletion
+1204
-2
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+1
-0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+1
-0
paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc
...le/fluid/inference/tensorrt/convert/deformable_conv_op.cc
+111
-0
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+47
-1
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-0
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
...id/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
+618
-0
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h
...uid/inference/tensorrt/plugin/deformable_conv_op_plugin.h
+148
-0
paddle/fluid/inference/tests/infer_ut/test_ppyolov2_r50vd.cc
paddle/fluid/inference/tests/infer_ut/test_ppyolov2_r50vd.cc
+1
-1
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_deformable_conv.py
...nittests/ir/inference/test_trt_convert_deformable_conv.py
+181
-0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_deformable_conv.py
.../tests/unittests/ir/inference/test_trt_deformable_conv.py
+95
-0
未找到文件。
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
8c3decd8
...
...
@@ -1415,6 +1415,7 @@ USE_TRT_CONVERTER(tile);
USE_TRT_CONVERTER
(
conv3d
);
USE_TRT_CONVERTER
(
conv3d_transpose
);
USE_TRT_CONVERTER
(
mish
);
USE_TRT_CONVERTER
(
deformable_conv
);
USE_TRT_CONVERTER
(
pool3d
)
#endif
...
...
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
8c3decd8
...
...
@@ -20,6 +20,7 @@ nv_library(tensorrt_converter
mish_op.cc
nearest_interp_v2_op.cc
pool3d_op.cc
deformable_conv_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry
)
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
...
...
paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc
0 → 100644
浏览文件 @
8c3decd8
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cstdio>
#include <vector>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
namespace
proto
{
class
OpDesc
;
}
// namespace proto
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
DeformableConvOpConverter
:
public
OpConverter
{
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a deformable conv op to tensorrt plugin"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
std
::
string
input_name
=
op_desc
.
Input
(
"Input"
).
front
();
std
::
string
offset_name
=
op_desc
.
Input
(
"Offset"
).
front
();
std
::
string
mask_name
=
op_desc
.
Input
(
"Mask"
).
front
();
std
::
string
filter_name
=
op_desc
.
Input
(
"Filter"
).
front
();
auto
*
input_tensor
=
engine_
->
GetITensor
(
input_name
);
auto
*
offset_tensor
=
engine_
->
GetITensor
(
offset_name
);
auto
*
mask_tensor
=
engine_
->
GetITensor
(
mask_name
);
auto
*
filter_var
=
scope
.
FindVar
(
filter_name
);
auto
*
filter_tensor
=
filter_var
->
GetMutable
<
framework
::
LoDTensor
>
();
float
*
filter_data
=
engine_
->
GetWeightCPUData
(
filter_name
,
filter_tensor
,
false
);
const
int
c_o
=
filter_tensor
->
dims
()[
0
];
const
int
c_i
=
filter_tensor
->
dims
()[
1
];
const
int
k_h
=
filter_tensor
->
dims
()[
2
];
const
int
k_w
=
filter_tensor
->
dims
()[
3
];
std
::
vector
<
int
>
kernel_dims
=
{
c_o
,
c_i
,
k_h
,
k_w
};
auto
strides
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"strides"
));
auto
paddings
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"paddings"
));
auto
dilations
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"dilations"
));
auto
groups
=
BOOST_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"groups"
));
auto
deformable_groups
=
BOOST_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"deformable_groups"
));
auto
im2col_step
=
BOOST_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"im2col_step"
));
nvinfer1
::
Weights
weights
;
weights
.
count
=
filter_tensor
->
numel
();
if
(
engine_
->
WithFp16
())
{
auto
half_filter_data
=
new
half
[
filter_tensor
->
numel
()];
for
(
int
i
=
0
;
i
<
filter_tensor
->
numel
();
i
++
)
{
half_filter_data
[
i
]
=
static_cast
<
half
>
(
filter_data
[
i
]);
}
weights
.
type
=
nvinfer1
::
DataType
::
kHALF
;
weights
.
values
=
half_filter_data
;
}
else
{
weights
.
type
=
nvinfer1
::
DataType
::
kFLOAT
;
weights
.
values
=
filter_data
;
}
auto
*
deformable_conv_plugin
=
new
plugin
::
DeformableConvPlugin
(
engine_
->
WithFp16
()
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
,
weights
,
kernel_dims
,
strides
,
paddings
,
dilations
,
groups
,
deformable_groups
,
im2col_step
);
std
::
vector
<
nvinfer1
::
ITensor
*>
deformable_conv_inputs
;
deformable_conv_inputs
.
push_back
(
input_tensor
);
deformable_conv_inputs
.
push_back
(
offset_tensor
);
deformable_conv_inputs
.
push_back
(
mask_tensor
);
auto
*
deformable_conv_layer
=
engine_
->
network
()
->
addPluginV2
(
deformable_conv_inputs
.
data
(),
deformable_conv_inputs
.
size
(),
*
deformable_conv_plugin
);
std
::
vector
<
std
::
string
>
output_names
;
output_names
.
push_back
(
op_desc
.
Output
(
"Output"
).
front
());
RreplenishLayerAndOutput
(
deformable_conv_layer
,
"deformable_conv"
,
output_names
,
test_mode
);
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
deformable_conv
,
DeformableConvOpConverter
);
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
8c3decd8
...
...
@@ -143,7 +143,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"conv3d_transpose"
,
"mish"
,
"nearest_interp_v2"
,
"pool3d"
};
"pool3d"
,
"deformable_conv"
};
};
bool
OpTeller
::
Tell
(
const
framework
::
ir
::
Node
*
node
,
bool
use_no_calib_int8
,
...
...
@@ -332,6 +333,51 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
#endif
}
if
(
op_type
==
"deformable_conv"
)
{
if
(
with_dynamic_shape
)
{
VLOG
(
3
)
<<
"Deformable conv trt plugin does not support dynamic shape"
;
return
false
;
}
auto
*
block
=
desc
.
Block
();
auto
input_name
=
desc
.
Input
(
"Input"
)[
0
];
auto
*
input_desc
=
block
->
FindVar
(
input_name
);
const
auto
input_shape
=
input_desc
->
GetShape
();
if
(
input_shape
.
size
()
!=
4
)
{
VLOG
(
3
)
<<
"Input of deformable conv should be 4-D Tensor, but got "
<<
input_shape
.
size
();
return
false
;
}
auto
filter_name
=
desc
.
Input
(
"Filter"
)[
0
];
auto
*
filter_desc
=
block
->
FindVar
(
filter_name
);
const
auto
filter_shape
=
filter_desc
->
GetShape
();
int
groups
=
BOOST_GET_CONST
(
int
,
desc
.
GetAttr
(
"groups"
));
if
(
input_shape
[
1
]
!=
filter_shape
[
1
]
*
groups
)
{
VLOG
(
3
)
<<
"The number of input channels should be equal to filter "
<<
"channels * groups. But got input channels "
<<
input_shape
[
1
]
<<
"filter channels "
<<
filter_shape
[
1
];
return
false
;
}
const
std
::
vector
<
int
>
strides
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"strides"
));
if
(
strides
.
size
()
!=
2
)
{
VLOG
(
3
)
<<
"The size of strides should be 2, but got "
<<
strides
.
size
();
return
false
;
}
const
std
::
vector
<
int
>
paddings
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"paddings"
));
if
(
paddings
.
size
()
!=
2
)
{
VLOG
(
3
)
<<
"The size of paddings shoule be 2, but got "
<<
paddings
.
size
();
return
false
;
}
}
if
(
op_type
==
"matmul"
)
{
auto
*
block
=
desc
.
Block
();
if
(
block
==
nullptr
)
{
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
8c3decd8
...
...
@@ -11,6 +11,7 @@ nv_library(tensorrt_plugin
gather_nd_op_plugin.cu
mish_op_plugin.cu
pool3d_op_plugin.cu
deformable_conv_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor
)
nv_test
(
test_split_plugin SRCS test_split_plugin.cc DEPS
...
...
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
0 → 100644
浏览文件 @
8c3decd8
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdio>
#include "paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
static
constexpr
int
kNumCUDAThreads
=
512
;
static
constexpr
int
kNumMaximumNumBlocks
=
4096
;
static
inline
int
NumBlocks
(
const
int
N
)
{
return
std
::
min
((
N
+
kNumCUDAThreads
-
1
)
/
kNumCUDAThreads
,
kNumMaximumNumBlocks
);
}
static
inline
int
ConvOutputSize
(
int
input_size
,
int
filter_size
,
int
dilation
,
int
padding
,
int
stride
)
{
const
int
dkernel
=
dilation
*
(
filter_size
-
1
)
+
1
;
int
output_size
=
(
input_size
+
2
*
padding
-
dkernel
)
/
stride
+
1
;
return
output_size
;
}
nvinfer1
::
Weights
DeformableConvPlugin
::
copyToDevice
(
const
void
*
hostData
,
size_t
count
)
{
int
num_bytes
=
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
?
4
:
2
);
void
*
deviceData
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMalloc
(
&
deviceData
,
count
*
num_bytes
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMemcpy
(
deviceData
,
hostData
,
count
*
num_bytes
,
cudaMemcpyHostToDevice
));
return
nvinfer1
::
Weights
{
data_type_
,
deviceData
,
int64_t
(
count
)};
}
void
DeformableConvPlugin
::
serializeFromDevice
(
void
**
hostBuffer
,
const
nvinfer1
::
Weights
&
deviceWeights
)
const
{
int
num_bytes
=
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
?
4
:
2
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMemcpy
(
static_cast
<
char
*>
(
*
hostBuffer
),
deviceWeights
.
values
,
deviceWeights
.
count
*
num_bytes
,
cudaMemcpyDeviceToHost
));
hostBuffer
+=
deviceWeights
.
count
*
num_bytes
;
}
nvinfer1
::
Weights
DeformableConvPlugin
::
deserializeToDevice
(
const
void
**
hostBuffer
,
size_t
count
)
{
int
num_bytes
=
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
?
4
:
2
);
nvinfer1
::
Weights
w
=
copyToDevice
(
static_cast
<
const
char
*>
(
*
hostBuffer
),
count
);
hostBuffer
+=
count
*
num_bytes
;
return
w
;
}
DeformableConvPlugin
::
DeformableConvPlugin
(
const
nvinfer1
::
DataType
data_type
,
const
nvinfer1
::
Weights
&
weights
,
const
std
::
vector
<
int
>&
kernel_dims
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
int
groups
,
const
int
deformable_groups
,
const
int
im2col_step
)
:
data_type_
(
data_type
),
groups_
(
groups
),
deformable_groups_
(
deformable_groups
),
im2col_step_
(
im2col_step
)
{
weights_
=
copyToDevice
(
weights
.
values
,
weights
.
count
);
kernel_dims_
.
insert
(
kernel_dims_
.
end
(),
kernel_dims
.
cbegin
(),
kernel_dims
.
cend
());
strides_
.
insert
(
strides_
.
end
(),
strides
.
cbegin
(),
strides
.
cend
());
paddings_
.
insert
(
paddings_
.
end
(),
paddings
.
cbegin
(),
paddings
.
cend
());
dilations_
.
insert
(
dilations_
.
end
(),
dilations
.
cbegin
(),
dilations
.
cend
());
PADDLE_ENFORCE_EQ
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
||
data_type_
==
nvinfer1
::
DataType
::
kHALF
,
true
,
platform
::
errors
::
InvalidArgument
(
"The DeformableConv TRT Plugin's input type "
"should be float or half."
));
PADDLE_ENFORCE_EQ
(
paddings_
.
size
(),
strides_
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of paddings (%d) is not equal to the size of strides (%d)."
,
paddings_
.
size
(),
strides_
.
size
()));
}
DeformableConvPlugin
::
DeformableConvPlugin
(
const
nvinfer1
::
DataType
data_type
,
const
nvinfer1
::
Weights
&
weights
,
const
std
::
vector
<
int
>&
kernel_dims
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
int
groups
,
const
int
deformable_groups
,
const
int
im2col_step
,
const
std
::
vector
<
int
>&
input_dim
,
const
std
::
vector
<
int
>&
offset_dim
,
const
std
::
vector
<
int
>&
mask_dim
,
const
std
::
vector
<
int
>&
output_dim
)
:
data_type_
(
data_type
),
groups_
(
groups
),
deformable_groups_
(
deformable_groups
),
im2col_step_
(
im2col_step
)
{
weights_
=
copyToDevice
(
weights
.
values
,
weights
.
count
);
kernel_dims_
.
insert
(
kernel_dims_
.
end
(),
kernel_dims
.
cbegin
(),
kernel_dims
.
cend
());
strides_
.
insert
(
strides_
.
end
(),
strides
.
cbegin
(),
strides
.
cend
());
paddings_
.
insert
(
paddings_
.
end
(),
paddings
.
cbegin
(),
paddings
.
cend
());
dilations_
.
insert
(
dilations_
.
end
(),
dilations
.
cbegin
(),
dilations
.
cend
());
input_dim_
.
insert
(
input_dim_
.
end
(),
input_dim
.
cbegin
(),
input_dim
.
cend
());
offset_dim_
.
insert
(
offset_dim_
.
end
(),
offset_dim
.
cbegin
(),
offset_dim
.
cend
());
mask_dim_
.
insert
(
mask_dim_
.
end
(),
mask_dim
.
cbegin
(),
mask_dim
.
cend
());
output_dim_
.
insert
(
output_dim_
.
end
(),
output_dim
.
cbegin
(),
output_dim
.
cend
());
PADDLE_ENFORCE_EQ
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
||
data_type_
==
nvinfer1
::
DataType
::
kHALF
,
true
,
platform
::
errors
::
InvalidArgument
(
"The DeformableConv TRT Plugin's input type "
"should be float or half."
));
PADDLE_ENFORCE_EQ
(
paddings_
.
size
(),
strides_
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of paddings (%d) is not equal to the size of strides (%d)."
,
paddings_
.
size
(),
strides_
.
size
()));
}
DeformableConvPlugin
::
DeformableConvPlugin
(
const
void
*
data
,
size_t
length
)
{
DeserializeValue
(
&
data
,
&
length
,
&
data_type_
);
DeserializeValue
(
&
data
,
&
length
,
&
strides_
);
DeserializeValue
(
&
data
,
&
length
,
&
paddings_
);
DeserializeValue
(
&
data
,
&
length
,
&
dilations_
);
DeserializeValue
(
&
data
,
&
length
,
&
groups_
);
DeserializeValue
(
&
data
,
&
length
,
&
deformable_groups_
);
DeserializeValue
(
&
data
,
&
length
,
&
im2col_step_
);
DeserializeValue
(
&
data
,
&
length
,
&
kernel_dims_
);
int64_t
count
;
DeserializeValue
(
&
data
,
&
length
,
&
count
);
weights_
=
deserializeToDevice
(
&
data
,
count
);
DeserializeValue
(
&
data
,
&
length
,
&
input_dim_
);
DeserializeValue
(
&
data
,
&
length
,
&
offset_dim_
);
DeserializeValue
(
&
data
,
&
length
,
&
mask_dim_
);
DeserializeValue
(
&
data
,
&
length
,
&
output_dim_
);
}
DeformableConvPlugin
::~
DeformableConvPlugin
()
{
if
(
weights_
.
values
)
{
cudaFree
(
const_cast
<
void
*>
(
weights_
.
values
));
weights_
.
values
=
nullptr
;
}
}
const
char
*
DeformableConvPlugin
::
getPluginType
()
const
TRT_NOEXCEPT
{
return
"deformable_conv_plugin"
;
}
const
char
*
DeformableConvPlugin
::
getPluginVersion
()
const
TRT_NOEXCEPT
{
return
"1"
;
}
int
DeformableConvPlugin
::
getNbOutputs
()
const
TRT_NOEXCEPT
{
return
1
;
}
nvinfer1
::
Dims
DeformableConvPlugin
::
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputs
,
int
nb_input_dims
)
TRT_NOEXCEPT
{
PADDLE_ENFORCE_EQ
(
nb_input_dims
,
3
,
platform
::
errors
::
InvalidArgument
(
"The number of inputs should be equal to 3, but got %d"
,
nb_input_dims
));
nvinfer1
::
Dims
ret
;
ret
.
nbDims
=
inputs
[
0
].
nbDims
;
ret
.
d
[
0
]
=
kernel_dims_
[
0
];
ret
.
d
[
1
]
=
ConvOutputSize
(
inputs
[
0
].
d
[
1
],
kernel_dims_
[
2
],
dilations_
[
0
],
paddings_
[
0
],
strides_
[
0
]);
ret
.
d
[
2
]
=
ConvOutputSize
(
inputs
[
0
].
d
[
2
],
kernel_dims_
[
3
],
dilations_
[
1
],
paddings_
[
1
],
strides_
[
1
]);
return
ret
;
}
bool
DeformableConvPlugin
::
supportsFormat
(
nvinfer1
::
DataType
type
,
nvinfer1
::
TensorFormat
format
)
const
TRT_NOEXCEPT
{
return
((
type
==
data_type_
||
type
==
nvinfer1
::
DataType
::
kINT32
)
&&
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
size_t
DeformableConvPlugin
::
getWorkspaceSize
(
int
max_batch_size
)
const
TRT_NOEXCEPT
{
int
c_i
=
input_dim_
[
0
],
h_i
=
input_dim_
[
1
],
w_i
=
input_dim_
[
2
];
int
k_h
=
kernel_dims_
[
2
],
k_w
=
kernel_dims_
[
3
];
int
c_o
=
output_dim_
[
0
],
h_o
=
output_dim_
[
1
],
w_o
=
output_dim_
[
2
];
int
num_bytes
=
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
?
4
:
2
);
size_t
data_col_size
=
static_cast
<
size_t
>
(
c_i
*
k_h
*
k_w
*
im2col_step_
*
h_o
*
w_o
*
num_bytes
);
return
data_col_size
;
}
int
DeformableConvPlugin
::
enqueue
(
int
batch_size
,
const
void
*
const
*
inputs
,
#if IS_TRT_VERSION_LT(8000)
void
**
outputs
,
void
*
workspace
,
#else
void
*
const
*
outputs
,
void
*
workspace
,
#endif
cudaStream_t
stream
)
TRT_NOEXCEPT
{
if
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
)
{
enqueue_impl
<
float
>
(
batch_size
,
inputs
,
outputs
,
workspace
,
stream
);
}
else
if
(
data_type_
==
nvinfer1
::
DataType
::
kHALF
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
enqueue_impl
<
half
>
(
batch_size
,
inputs
,
outputs
,
workspace
,
stream
);
#else
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Current CUDA arch dose not support fp16. Please use fp32 instead."
));
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"The DeformableConv TRT Plugin's input type should be float or half."
));
}
return
cudaGetLastError
()
!=
cudaSuccess
;
}
template
<
typename
T
>
__device__
T
kFloor
(
T
x
);
template
<
>
__device__
half
kFloor
<
half
>
(
half
x
)
{
return
hfloor
(
x
);
}
template
<
>
__device__
float
kFloor
<
float
>
(
float
x
)
{
return
floor
(
x
);
}
template
<
typename
T
>
__device__
T
DmcnIm2colBilinear
(
const
T
*
bottom_data
,
const
int
data_width
,
const
int
height
,
const
int
width
,
T
h
,
T
w
)
{
int
h_low
=
kFloor
<
T
>
(
h
);
int
w_low
=
kFloor
<
T
>
(
w
);
int
h_high
=
h_low
+
1
;
int
w_high
=
w_low
+
1
;
T
h_low_t
=
h_low
,
w_low_t
=
w_low
,
one
=
1.0
f
;
T
lh
=
h
-
h_low_t
;
T
lw
=
w
-
w_low_t
;
T
hh
=
one
-
lh
,
hw
=
one
-
lw
;
T
v1
=
0
;
if
(
h_low
>=
0
&&
w_low
>=
0
)
v1
=
bottom_data
[
h_low
*
data_width
+
w_low
];
T
v2
=
0
;
if
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
v2
=
bottom_data
[
h_low
*
data_width
+
w_high
];
T
v3
=
0
;
if
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
v3
=
bottom_data
[
h_high
*
data_width
+
w_low
];
T
v4
=
0
;
if
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
v4
=
bottom_data
[
h_high
*
data_width
+
w_high
];
T
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
T
val
=
(
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
);
return
val
;
}
template
<
typename
T
>
__global__
void
ModulatedDeformableIm2colGpuKernel
(
const
int
nthreads
,
const
T
*
data_im
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
channel_per_deformable_group
,
const
int
batch_size
,
const
int
num_channels
,
const
int
deformable_group
,
const
int
height_col
,
const
int
width_col
,
T
*
data_col
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
T
minus_one
=
-
1.0
f
,
height_t
=
height
,
width_t
=
width
;
for
(
size_t
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
const
int
w_col
=
i
%
width_col
;
const
int
h_col
=
(
i
/
width_col
)
%
height_col
;
const
int
b_col
=
(
i
/
width_col
)
/
height_col
%
batch_size
;
const
int
c_im
=
(
i
/
width_col
/
height_col
)
/
batch_size
;
const
int
c_col
=
c_im
*
kernel_h
*
kernel_w
;
const
int
deformable_group_index
=
c_im
/
channel_per_deformable_group
;
const
int
h_in
=
h_col
*
stride_h
-
pad_h
;
const
int
w_in
=
w_col
*
stride_w
-
pad_w
;
T
*
data_col_ptr
=
data_col
+
((
c_col
*
batch_size
+
b_col
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
T
*
data_im_ptr
=
data_im
+
(
b_col
*
num_channels
+
c_im
)
*
height
*
width
;
const
T
*
data_offset_ptr
=
data_offset
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
2
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
T
*
data_mask_ptr
=
data_mask
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
for
(
int
i
=
0
;
i
<
kernel_h
;
++
i
)
{
for
(
int
j
=
0
;
j
<
kernel_w
;
++
j
)
{
const
int
data_offset_h_ptr
=
((
2
*
(
i
*
kernel_w
+
j
))
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
int
data_offset_w_ptr
=
((
2
*
(
i
*
kernel_w
+
j
)
+
1
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
int
data_mask_hw_ptr
=
((
i
*
kernel_w
+
j
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
T
offset_h
=
data_offset_ptr
[
data_offset_h_ptr
];
const
T
offset_w
=
data_offset_ptr
[
data_offset_w_ptr
];
const
T
mask
=
data_mask_ptr
[
data_mask_hw_ptr
];
T
val
=
0
;
T
h_im_t
=
h_in
+
i
*
dilation_h
,
w_im_t
=
w_in
+
j
*
dilation_w
;
const
T
h_im
=
h_im_t
+
offset_h
;
const
T
w_im
=
w_im_t
+
offset_w
;
if
(
h_im
>
minus_one
&&
w_im
>
minus_one
&&
h_im
<
height_t
&&
w_im
<
width_t
)
{
val
=
DmcnIm2colBilinear
<
T
>
(
data_im_ptr
,
width
,
height
,
width
,
h_im
,
w_im
);
}
*
data_col_ptr
=
val
*
mask
;
data_col_ptr
+=
batch_size
*
height_col
*
width_col
;
}
}
}
}
template
<
typename
T
>
void
gemm_impl
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
int
m
,
int
n
,
int
k
,
const
T
*
alpha
,
const
T
*
A
,
int
lda
,
const
T
*
B
,
int
ldb
,
const
T
*
beta
,
T
*
C
,
int
ldc
);
template
<
>
void
gemm_impl
<
float
>
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
int
m
,
int
n
,
int
k
,
const
float
*
alpha
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
const
float
*
beta
,
float
*
C
,
int
ldc
)
{
platform
::
dynload
::
cublasSgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
gemm_impl
<
half
>
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
int
m
,
int
n
,
int
k
,
const
half
*
alpha
,
const
half
*
A
,
int
lda
,
const
half
*
B
,
int
ldb
,
const
half
*
beta
,
half
*
C
,
int
ldc
)
{
platform
::
dynload
::
cublasHgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
typename
T
>
int
DeformableConvPlugin
::
enqueue_impl
(
int
batch_size
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
const
T
*
input
=
reinterpret_cast
<
const
T
*>
(
inputs
[
0
]);
const
T
*
offset
=
reinterpret_cast
<
const
T
*>
(
inputs
[
1
]);
const
T
*
mask
=
reinterpret_cast
<
const
T
*>
(
inputs
[
2
]);
const
T
*
filter
=
reinterpret_cast
<
const
T
*>
(
weights_
.
values
);
T
*
output
=
reinterpret_cast
<
T
*>
(
outputs
[
0
]);
int
c_i
=
input_dim_
[
0
],
h_i
=
input_dim_
[
1
],
w_i
=
input_dim_
[
2
];
int
k_h
=
kernel_dims_
[
2
],
k_w
=
kernel_dims_
[
3
];
int
c_o
=
output_dim_
[
0
],
h_o
=
output_dim_
[
1
],
w_o
=
output_dim_
[
2
];
int
input_stride
=
c_i
*
h_i
*
w_i
;
int
offset_stride
=
offset_dim_
[
0
]
*
offset_dim_
[
1
]
*
offset_dim_
[
2
];
int
mask_stride
=
mask_dim_
[
0
]
*
mask_dim_
[
1
]
*
mask_dim_
[
2
];
int
output_stride
=
c_o
*
h_o
*
w_o
;
int
M
=
c_o
/
groups_
;
int
N
=
im2col_step_
*
h_o
*
w_o
;
int
K
=
c_i
*
k_h
*
k_w
/
groups_
;
// c_i / deformable_groups
int
channel_per_deformable_group
=
c_i
/
deformable_groups_
;
// c_i * im2col_step * h_o * w_o
int
num_kernels
=
c_i
*
im2col_step_
*
h_o
*
w_o
;
int
blocks
=
NumBlocks
(
num_kernels
);
int
threads
=
kNumCUDAThreads
;
T
alpha
=
static_cast
<
T
>
(
1.0
f
);
T
beta
=
static_cast
<
T
>
(
0.0
f
);
for
(
int
i
=
0
;
i
<
batch_size
/
im2col_step_
;
++
i
)
{
const
T
*
data_im
=
input
+
i
*
im2col_step_
*
input_stride
;
const
T
*
data_offset
=
offset
+
i
*
im2col_step_
*
offset_stride
;
const
T
*
data_mask
=
mask
+
i
*
im2col_step_
*
mask_stride
;
T
*
data_col
=
reinterpret_cast
<
T
*>
(
workspace
);
ModulatedDeformableIm2colGpuKernel
<
T
><<<
blocks
,
threads
,
0
,
stream
>>>
(
num_kernels
,
data_im
,
data_offset
,
data_mask
,
h_i
,
w_i
,
k_h
,
k_w
,
paddings_
[
0
],
paddings_
[
1
],
strides_
[
0
],
strides_
[
1
],
dilations_
[
0
],
dilations_
[
1
],
channel_per_deformable_group
,
im2col_step_
,
c_i
,
deformable_groups_
,
h_o
,
w_o
,
data_col
);
for
(
int
g
=
0
;
g
<
groups_
;
++
g
)
{
const
T
*
weight
=
filter
+
g
*
M
*
K
;
const
T
*
col
=
data_col
+
g
*
K
*
N
;
T
*
out
=
output
+
i
*
im2col_step_
*
output_stride
+
g
*
M
*
N
;
gemm_impl
<
T
>
(
cublasHandle_
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
N
,
M
,
K
,
&
alpha
,
col
,
N
,
weight
,
K
,
&
beta
,
out
,
N
);
}
}
return
0
;
}
int
DeformableConvPlugin
::
initialize
()
TRT_NOEXCEPT
{
return
0
;
}
void
DeformableConvPlugin
::
terminate
()
TRT_NOEXCEPT
{}
size_t
DeformableConvPlugin
::
getSerializationSize
()
const
TRT_NOEXCEPT
{
size_t
serialize_size
=
0
;
serialize_size
+=
SerializedSize
(
data_type_
);
serialize_size
+=
SerializedSize
(
strides_
);
serialize_size
+=
SerializedSize
(
paddings_
);
serialize_size
+=
SerializedSize
(
dilations_
);
serialize_size
+=
SerializedSize
(
groups_
);
serialize_size
+=
SerializedSize
(
deformable_groups_
);
serialize_size
+=
SerializedSize
(
im2col_step_
);
serialize_size
+=
SerializedSize
(
kernel_dims_
);
serialize_size
+=
SerializedSize
(
weights_
.
count
);
int
num_bytes
=
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
?
4
:
2
);
serialize_size
+=
weights_
.
count
*
num_bytes
;
serialize_size
+=
SerializedSize
(
input_dim_
);
serialize_size
+=
SerializedSize
(
offset_dim_
);
serialize_size
+=
SerializedSize
(
mask_dim_
);
serialize_size
+=
SerializedSize
(
output_dim_
);
return
serialize_size
;
}
void
DeformableConvPlugin
::
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
{
SerializeValue
(
&
buffer
,
data_type_
);
SerializeValue
(
&
buffer
,
strides_
);
SerializeValue
(
&
buffer
,
paddings_
);
SerializeValue
(
&
buffer
,
dilations_
);
SerializeValue
(
&
buffer
,
groups_
);
SerializeValue
(
&
buffer
,
deformable_groups_
);
SerializeValue
(
&
buffer
,
im2col_step_
);
SerializeValue
(
&
buffer
,
kernel_dims_
);
SerializeValue
(
&
buffer
,
weights_
.
count
);
serializeFromDevice
(
&
buffer
,
weights_
);
SerializeValue
(
&
buffer
,
input_dim_
);
SerializeValue
(
&
buffer
,
offset_dim_
);
SerializeValue
(
&
buffer
,
mask_dim_
);
SerializeValue
(
&
buffer
,
output_dim_
);
}
void
DeformableConvPlugin
::
destroy
()
TRT_NOEXCEPT
{}
void
DeformableConvPlugin
::
setPluginNamespace
(
const
char
*
lib_namespace
)
TRT_NOEXCEPT
{
namespace_
=
std
::
string
(
lib_namespace
);
}
const
char
*
DeformableConvPlugin
::
getPluginNamespace
()
const
TRT_NOEXCEPT
{
return
namespace_
.
c_str
();
}
nvinfer1
::
DataType
DeformableConvPlugin
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_type
,
int
nb_inputs
)
const
TRT_NOEXCEPT
{
return
data_type_
;
}
bool
DeformableConvPlugin
::
isOutputBroadcastAcrossBatch
(
int
output_index
,
const
bool
*
input_is_broadcast
,
int
nb_inputs
)
const
TRT_NOEXCEPT
{
return
false
;
}
bool
DeformableConvPlugin
::
canBroadcastInputAcrossBatch
(
int
input_index
)
const
TRT_NOEXCEPT
{
return
false
;
}
void
DeformableConvPlugin
::
attachToContext
(
cudnnContext
*
cudnnContext
,
cublasContext
*
cublasContext
,
nvinfer1
::
IGpuAllocator
*
gpuAllocator
)
TRT_NOEXCEPT
{
cublasHandle_
=
cublasContext
;
}
void
DeformableConvPlugin
::
configurePlugin
(
const
nvinfer1
::
Dims
*
input_dims
,
int
nb_inputs
,
const
nvinfer1
::
Dims
*
output_dims
,
int
nb_outputs
,
const
nvinfer1
::
DataType
*
input_types
,
const
nvinfer1
::
DataType
*
output_types
,
const
bool
*
input_is_broadcast
,
const
bool
*
output_is_broadcast
,
nvinfer1
::
PluginFormat
float_format
,
int
max_batct_size
)
TRT_NOEXCEPT
{
PADDLE_ENFORCE_EQ
(
nb_inputs
,
3
,
platform
::
errors
::
InvalidArgument
(
"The number of inputs should be equal to 3, but got %d"
,
nb_inputs
));
PADDLE_ENFORCE_EQ
(
nb_outputs
,
1
,
platform
::
errors
::
InvalidArgument
(
"The number of inputs should be equal to 1, but got %d"
,
nb_outputs
));
for
(
int
i
=
0
;
i
<
input_dims
[
0
].
nbDims
;
i
++
)
{
input_dim_
.
push_back
(
input_dims
[
0
].
d
[
i
]);
}
for
(
int
i
=
0
;
i
<
input_dims
[
1
].
nbDims
;
i
++
)
{
offset_dim_
.
push_back
(
input_dims
[
1
].
d
[
i
]);
}
for
(
int
i
=
0
;
i
<
input_dims
[
2
].
nbDims
;
i
++
)
{
mask_dim_
.
push_back
(
input_dims
[
2
].
d
[
i
]);
}
for
(
int
i
=
0
;
i
<
output_dims
[
0
].
nbDims
;
i
++
)
{
output_dim_
.
push_back
(
output_dims
[
0
].
d
[
i
]);
}
}
nvinfer1
::
IPluginV2Ext
*
DeformableConvPlugin
::
clone
()
const
TRT_NOEXCEPT
{
return
new
DeformableConvPlugin
(
data_type_
,
weights_
,
kernel_dims_
,
strides_
,
paddings_
,
dilations_
,
groups_
,
deformable_groups_
,
im2col_step_
,
input_dim_
,
offset_dim_
,
mask_dim_
,
output_dim_
);
}
DeformableConvPluginCreator
::
DeformableConvPluginCreator
()
TRT_NOEXCEPT
{}
void
DeformableConvPluginCreator
::
setPluginNamespace
(
const
char
*
lib_namespace
)
TRT_NOEXCEPT
{
namespace_
=
std
::
string
(
lib_namespace
);
}
const
char
*
DeformableConvPluginCreator
::
getPluginNamespace
()
const
TRT_NOEXCEPT
{
return
namespace_
.
c_str
();
}
const
char
*
DeformableConvPluginCreator
::
getPluginName
()
const
TRT_NOEXCEPT
{
return
"deformable_conv_plugin"
;
}
const
char
*
DeformableConvPluginCreator
::
getPluginVersion
()
const
TRT_NOEXCEPT
{
return
"1"
;
}
const
nvinfer1
::
PluginFieldCollection
*
DeformableConvPluginCreator
::
getFieldNames
()
TRT_NOEXCEPT
{
return
&
field_collection_
;
}
nvinfer1
::
IPluginV2Ext
*
DeformableConvPluginCreator
::
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
TRT_NOEXCEPT
{
const
nvinfer1
::
PluginField
*
fields
=
fc
->
fields
;
nvinfer1
::
DataType
data_type
;
std
::
vector
<
int
>
strides
,
paddings
,
dilations
,
kernel_dims
;
nvinfer1
::
Weights
weights
;
int
groups
=
-
1
;
int
deformable_groups
=
-
1
;
int
im2col_step
=
-
1
;
for
(
int
i
=
0
;
i
<
fc
->
nbFields
;
++
i
)
{
const
std
::
string
field_name
(
fc
->
fields
[
i
].
name
);
if
(
field_name
.
compare
(
"data_type"
)
==
0
)
{
data_type
=
*
static_cast
<
const
nvinfer1
::
DataType
*>
(
fc
->
fields
[
i
].
data
);
}
else
if
(
field_name
.
compare
(
"strides"
))
{
const
int
length
=
fc
->
fields
[
i
].
length
;
const
int
*
data
=
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
strides
.
insert
(
strides
.
end
(),
data
,
data
+
length
);
}
else
if
(
field_name
.
compare
(
"paddings"
))
{
const
int
length
=
fc
->
fields
[
i
].
length
;
const
int
*
data
=
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
paddings
.
insert
(
paddings
.
end
(),
data
,
data
+
length
);
}
else
if
(
field_name
.
compare
(
"dilations"
))
{
const
int
length
=
fc
->
fields
[
i
].
length
;
const
int
*
data
=
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
dilations
.
insert
(
dilations
.
end
(),
data
,
data
+
length
);
}
else
if
(
field_name
.
compare
(
"groups"
))
{
groups
=
*
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
}
else
if
(
field_name
.
compare
(
"deformable_groups"
))
{
deformable_groups
=
*
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
}
else
if
(
field_name
.
compare
(
"im2col_step"
))
{
im2col_step
=
*
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
}
else
if
(
field_name
.
compare
(
"kernel_dims"
))
{
const
int
length
=
fc
->
fields
[
i
].
length
;
const
int
*
data
=
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
kernel_dims
.
insert
(
kernel_dims
.
end
(),
data
,
data
+
length
);
}
else
if
(
field_name
.
compare
(
"weights"
))
{
weights
.
count
=
fc
->
fields
[
i
].
length
;
weights
.
values
=
fc
->
fields
[
i
].
data
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unknown plugin field name [%s] in the DeformableConv TRT Plugin."
,
field_name
));
}
}
weights
.
type
=
data_type
;
return
new
DeformableConvPlugin
(
data_type
,
weights
,
kernel_dims
,
strides
,
paddings
,
dilations
,
groups
,
deformable_groups
,
im2col_step
);
}
nvinfer1
::
IPluginV2Ext
*
DeformableConvPluginCreator
::
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
TRT_NOEXCEPT
{
auto
plugin
=
new
DeformableConvPlugin
(
serial_data
,
serial_length
);
plugin
->
setPluginNamespace
(
namespace_
.
c_str
());
return
plugin
;
}
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h
0 → 100644
浏览文件 @
8c3decd8
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstdio>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
class
DeformableConvPlugin
:
public
nvinfer1
::
IPluginV2Ext
{
public:
explicit
DeformableConvPlugin
(
const
nvinfer1
::
DataType
data_type
,
const
nvinfer1
::
Weights
&
weights
,
const
std
::
vector
<
int
>&
kernel_dims
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
int
groups
,
const
int
deformable_groups
,
const
int
im2col_step
);
explicit
DeformableConvPlugin
(
const
nvinfer1
::
DataType
data_type
,
const
nvinfer1
::
Weights
&
weights
,
const
std
::
vector
<
int
>&
kernel_dims
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
int
groups
,
const
int
deformable_groups
,
const
int
im2col_step
,
const
std
::
vector
<
int
>&
input_dim
,
const
std
::
vector
<
int
>&
offset_dim
,
const
std
::
vector
<
int
>&
mask_dim
,
const
std
::
vector
<
int
>&
output_dim
);
DeformableConvPlugin
(
const
void
*
data
,
size_t
length
);
~
DeformableConvPlugin
()
override
;
const
char
*
getPluginType
()
const
TRT_NOEXCEPT
override
;
const
char
*
getPluginVersion
()
const
TRT_NOEXCEPT
override
;
int
getNbOutputs
()
const
TRT_NOEXCEPT
override
;
nvinfer1
::
Dims
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputs
,
int
nb_input_dims
)
TRT_NOEXCEPT
override
;
bool
supportsFormat
(
nvinfer1
::
DataType
type
,
nvinfer1
::
TensorFormat
format
)
const
TRT_NOEXCEPT
override
;
size_t
getWorkspaceSize
(
int
max_batch_size
)
const
TRT_NOEXCEPT
override
;
#if IS_TRT_VERSION_LT(8000)
int
enqueue
(
int
batch_size
,
const
void
*
const
*
inputs
,
void
**
outputs
,
#else
int
enqueue
(
int
batch_size
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
#endif
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
override
;
int
initialize
()
TRT_NOEXCEPT
override
;
void
terminate
()
TRT_NOEXCEPT
override
;
size_t
getSerializationSize
()
const
TRT_NOEXCEPT
override
;
void
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
override
;
void
destroy
()
TRT_NOEXCEPT
override
;
void
setPluginNamespace
(
const
char
*
lib_namespace
)
TRT_NOEXCEPT
override
;
const
char
*
getPluginNamespace
()
const
TRT_NOEXCEPT
override
;
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_type
,
int
nb_inputs
)
const
TRT_NOEXCEPT
override
;
bool
isOutputBroadcastAcrossBatch
(
int
output_index
,
const
bool
*
input_is_broadcast
,
int
nb_inputs
)
const
TRT_NOEXCEPT
override
;
bool
canBroadcastInputAcrossBatch
(
int
input_index
)
const
TRT_NOEXCEPT
override
;
void
attachToContext
(
cudnnContext
*
cudnnContext
,
cublasContext
*
cublasContext
,
nvinfer1
::
IGpuAllocator
*
gpuAllocator
)
TRT_NOEXCEPT
override
;
void
configurePlugin
(
const
nvinfer1
::
Dims
*
input_dims
,
int
nb_inputs
,
const
nvinfer1
::
Dims
*
output_dims
,
int
nb_outputs
,
const
nvinfer1
::
DataType
*
input_types
,
const
nvinfer1
::
DataType
*
output_types
,
const
bool
*
input_is_broadcast
,
const
bool
*
output_is_broadcast
,
nvinfer1
::
PluginFormat
float_format
,
int
max_batct_size
)
TRT_NOEXCEPT
override
;
nvinfer1
::
IPluginV2Ext
*
clone
()
const
TRT_NOEXCEPT
override
;
private:
template
<
typename
T
>
int
enqueue_impl
(
int
batch_size
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
);
nvinfer1
::
Weights
copyToDevice
(
const
void
*
hostData
,
size_t
count
);
void
serializeFromDevice
(
void
**
hostBuffer
,
const
nvinfer1
::
Weights
&
deviceWeights
)
const
;
nvinfer1
::
Weights
deserializeToDevice
(
const
void
**
hostBuffer
,
size_t
count
);
nvinfer1
::
DataType
data_type_
;
nvinfer1
::
Weights
weights_
;
std
::
vector
<
int
>
kernel_dims_
;
std
::
vector
<
int
>
strides_
;
std
::
vector
<
int
>
paddings_
;
std
::
vector
<
int
>
dilations_
;
int
groups_
;
int
deformable_groups_
;
int
im2col_step_
;
std
::
string
namespace_
;
std
::
vector
<
int
>
input_dim_
;
std
::
vector
<
int
>
offset_dim_
;
std
::
vector
<
int
>
mask_dim_
;
std
::
vector
<
int
>
output_dim_
;
cublasHandle_t
cublasHandle_
;
};
class
DeformableConvPluginCreator
:
public
nvinfer1
::
IPluginCreator
{
public:
DeformableConvPluginCreator
();
~
DeformableConvPluginCreator
()
override
=
default
;
void
setPluginNamespace
(
const
char
*
lib_namespace
)
TRT_NOEXCEPT
override
;
const
char
*
getPluginNamespace
()
const
TRT_NOEXCEPT
override
;
const
char
*
getPluginName
()
const
TRT_NOEXCEPT
override
;
const
char
*
getPluginVersion
()
const
TRT_NOEXCEPT
override
;
const
nvinfer1
::
PluginFieldCollection
*
getFieldNames
()
TRT_NOEXCEPT
override
;
nvinfer1
::
IPluginV2Ext
*
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
TRT_NOEXCEPT
override
;
nvinfer1
::
IPluginV2Ext
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
TRT_NOEXCEPT
override
;
private:
std
::
string
namespace_
;
nvinfer1
::
PluginFieldCollection
field_collection_
;
};
REGISTER_TRT_PLUGIN_V2
(
DeformableConvPluginCreator
);
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tests/infer_ut/test_ppyolov2_r50vd.cc
浏览文件 @
8c3decd8
...
...
@@ -73,7 +73,7 @@ TEST(tensorrt_tester_ppyolov2_r50vd, multi_thread2_trt_fp32_bz1) {
FLAGS_modeldir
+
"/model.pdiparams"
);
config
.
EnableUseGpu
(
100
,
0
);
config
.
EnableTensorRtEngine
(
1
<<
2
0
,
2
,
10
,
paddle_infer
::
PrecisionType
::
kFloat32
,
false
,
false
);
1
<<
2
8
,
2
,
10
,
paddle_infer
::
PrecisionType
::
kFloat32
,
false
,
false
);
LOG
(
INFO
)
<<
config
.
Summary
();
// get groudtruth by disbale ir
paddle_infer
::
services
::
PredictorPool
pred_pool_no_ir
(
config_no_ir
,
1
);
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_deformable_conv.py
0 → 100644
浏览文件 @
8c3decd8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
trt_layer_auto_scan_test
import
TrtLayerAutoScanTest
,
SkipReasons
from
program_config
import
TensorConfig
,
ProgramConfig
import
numpy
as
np
import
paddle.inference
as
paddle_infer
from
functools
import
partial
from
typing
import
Optional
,
List
,
Callable
,
Dict
,
Any
,
Set
import
unittest
class
TrtConvertDeformableConvTest
(
TrtLayerAutoScanTest
):
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
inputs
=
program_config
.
inputs
weights
=
program_config
.
weights
attrs
=
[
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
]
if
inputs
[
'input_data'
].
shape
[
1
]
!=
weights
[
'filter_data'
].
shape
[
1
]
*
attrs
[
0
][
'groups'
]:
return
False
return
True
def
sample_program_configs
(
self
):
def
compute_output_size
(
input_size
:
List
[
int
],
kernel_sizes
:
List
[
int
],
attrs
:
List
[
Dict
[
str
,
Any
]]):
strides
=
attrs
[
0
][
'strides'
]
paddings
=
attrs
[
0
][
'paddings'
]
dilations
=
attrs
[
0
][
'dilations'
]
output_size
=
[]
for
i
,
k
,
s
,
p
,
d
in
zip
(
input_size
,
kernel_sizes
,
strides
,
paddings
,
dilations
):
k
=
d
*
(
k
-
1
)
+
1
output_size
.
append
((
i
+
2
*
p
-
k
)
//
s
+
1
)
return
output_size
def
generate_input1
(
batch
:
int
,
input_size
:
List
[
int
],
kernel_sizes
:
List
[
int
],
attrs
:
List
[
Dict
[
str
,
Any
]]):
return
np
.
random
.
random
([
batch
,
3
]
+
input_size
).
astype
(
np
.
float32
)
def
generate_offset1
(
batch
:
int
,
input_size
:
List
[
int
],
kernel_sizes
:
List
[
int
],
attrs
:
List
[
Dict
[
str
,
Any
]]):
output_size
=
compute_output_size
(
input_size
,
kernel_sizes
,
attrs
)
return
np
.
random
.
random
([
batch
,
2
*
np
.
prod
(
kernel_sizes
)]
+
output_size
).
astype
(
np
.
float32
)
def
generate_mask1
(
batch
:
int
,
input_size
:
List
[
int
],
kernel_sizes
:
List
[
int
],
attrs
:
List
[
Dict
[
str
,
Any
]]):
output_size
=
compute_output_size
(
input_size
,
kernel_sizes
,
attrs
)
return
np
.
random
.
random
([
batch
,
np
.
prod
(
kernel_sizes
)]
+
output_size
).
astype
(
np
.
float32
)
def
generate_filter1
(
batch
:
int
,
input_size
:
List
[
int
],
kernel_sizes
:
List
[
int
],
attrs
:
List
[
Dict
[
str
,
Any
]]):
return
np
.
random
.
random
([
6
,
3
]
+
kernel_sizes
).
astype
(
np
.
float32
)
for
batch
in
[
1
,
]:
for
input_size
in
[[
32
,
32
]]:
for
kernel_sizes
in
[[
3
,
3
]]:
for
strides
in
[[
1
,
1
],
[
2
,
2
]]:
for
paddings
in
[[
1
,
1
],
[
0
,
2
]]:
for
groups
in
[
1
,
]:
for
dilations
in
[[
1
,
1
],
[
2
,
2
]]:
dics
=
[{
"strides"
:
strides
,
"paddings"
:
paddings
,
"groups"
:
groups
,
"dilations"
:
dilations
,
"deformable_groups"
:
1
,
"im2col_step"
:
1
}]
ops_config
=
[{
"op_type"
:
"deformable_conv"
,
"op_inputs"
:
{
"Input"
:
[
"input_data"
],
"Offset"
:
[
"offset_data"
],
"Mask"
:
[
"mask_data"
],
"Filter"
:
[
"filter_data"
]
},
"op_outputs"
:
{
"Output"
:
[
"output_data"
]
},
"op_attrs"
:
dics
[
0
]
}]
ops
=
self
.
generate_op_config
(
ops_config
)
program_config
=
ProgramConfig
(
ops
=
ops
,
weights
=
{
"filter_data"
:
TensorConfig
(
data_gen
=
partial
(
generate_filter1
,
batch
,
input_size
,
kernel_sizes
,
dics
))
},
inputs
=
{
"input_data"
:
TensorConfig
(
data_gen
=
partial
(
generate_input1
,
batch
,
input_size
,
kernel_sizes
,
dics
)),
"offset_data"
:
TensorConfig
(
data_gen
=
partial
(
generate_offset1
,
batch
,
input_size
,
kernel_sizes
,
dics
)),
"mask_data"
:
TensorConfig
(
data_gen
=
partial
(
generate_mask1
,
batch
,
input_size
,
kernel_sizes
,
dics
))
},
outputs
=
[
"output_data"
])
yield
program_config
def
sample_predictor_configs
(
self
,
program_config
)
->
(
paddle_infer
.
Config
,
List
[
int
],
float
):
def
clear_dynamic_shape
():
self
.
dynamic_shape
.
min_input_shape
=
{}
self
.
dynamic_shape
.
max_input_shape
=
{}
self
.
dynamic_shape
.
opt_input_shape
=
{}
def
generate_trt_nodes_num
(
attrs
,
dynamic_shape
):
# TODO: This is just the example, need to be fixed.
if
len
(
attrs
[
0
][
'paddings'
])
==
4
:
return
1
,
2
else
:
return
1
,
2
attrs
=
[
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
]
# for static_shape
clear_dynamic_shape
()
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Float32
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
attrs
,
False
),
1e-5
def
add_skip_trt_case
(
self
):
def
teller1
(
program_config
,
predictor_config
):
if
len
(
program_config
.
ops
[
0
].
attrs
[
"strides"
])
!=
2
:
return
False
return
True
self
.
add_skip_case
(
teller1
,
SkipReasons
.
TRT_NOT_IMPLEMENTED
,
"In deformable conv, length of Attr(strides) should be 2."
)
def
test
(
self
):
self
.
trt_param
.
workspace_size
=
1
<<
28
self
.
add_skip_trt_case
()
self
.
run_test
()
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/ir/inference/test_trt_deformable_conv.py
0 → 100644
浏览文件 @
8c3decd8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
inference_pass_test
import
InferencePassTest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.core
import
PassVersionChecker
from
paddle.fluid.core
import
AnalysisConfig
class
TRTDeformableConvTest
(
InferencePassTest
):
def
setUp
(
self
):
self
.
set_params
()
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
input
=
fluid
.
data
(
name
=
'input'
,
shape
=
self
.
input_size
,
dtype
=
self
.
dtype
)
offset
=
fluid
.
data
(
name
=
'offset'
,
shape
=
self
.
offset_size
,
dtype
=
self
.
dtype
)
mask
=
fluid
.
data
(
name
=
'mask'
,
shape
=
self
.
mask_size
,
dtype
=
self
.
dtype
)
output
=
fluid
.
layers
.
deformable_conv
(
input
,
offset
,
mask
,
self
.
num_filters
,
self
.
filter_size
,
stride
=
self
.
stride
,
padding
=
self
.
padding
,
dilation
=
self
.
dilations
,
groups
=
self
.
groups
,
deformable_groups
=
self
.
deformable_groups
,
im2col_step
=
self
.
im2col_step
)
self
.
feeds
=
{
'input'
:
np
.
random
.
random
(
self
.
input_size
).
astype
(
self
.
dtype
),
'offset'
:
np
.
random
.
random
(
self
.
offset_size
).
astype
(
self
.
dtype
),
'mask'
:
np
.
random
.
random
(
self
.
mask_size
).
astype
(
self
.
dtype
)
}
self
.
enable_trt
=
True
dtype
=
AnalysisConfig
.
Precision
.
Float32
if
self
.
dtype
==
'float16'
:
dtype
=
AnalysisConfig
.
Precision
.
Half
self
.
trt_parameters
=
TRTDeformableConvTest
.
TensorRTParam
(
1
<<
30
,
self
.
bs
,
0
,
dtype
,
False
,
False
)
self
.
fetch_list
=
[
output
]
def
set_params
(
self
):
self
.
groups
=
1
self
.
padding
=
[
1
,
1
]
self
.
dilations
=
[
1
,
1
]
self
.
stride
=
[
1
,
1
]
self
.
im2col_step
=
1
self
.
deformable_groups
=
1
self
.
bs
=
2
self
.
input_size
=
[
self
.
bs
,
8
,
4
,
4
]
self
.
num_filters
=
8
self
.
filter_size
=
3
offset_c
=
2
*
self
.
deformable_groups
*
self
.
filter_size
*
self
.
filter_size
mask_c
=
self
.
deformable_groups
*
self
.
filter_size
*
self
.
filter_size
self
.
offset_size
=
[
self
.
input_size
[
0
],
offset_c
,
self
.
input_size
[
2
],
self
.
input_size
[
3
]
]
self
.
mask_size
=
[
self
.
input_size
[
0
],
mask_c
,
self
.
input_size
[
2
],
self
.
input_size
[
3
]
]
self
.
dtype
=
'float32'
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
use_gpu
=
True
self
.
check_output_with_option
(
use_gpu
)
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
'tensorrt_subgraph_pass'
))
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录