Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
229bae81
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
229bae81
编写于
10月 26, 2021
作者:
F
feng_shuai
提交者:
GitHub
10月 26, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Pool3d 2.0 (#36545)
上级
cea1ba88
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
1248 addition
and
1 deletion
+1248
-1
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+1
-0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+1
-0
paddle/fluid/inference/tensorrt/convert/pool3d_op.cc
paddle/fluid/inference/tensorrt/convert/pool3d_op.cc
+228
-0
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+2
-1
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-0
paddle/fluid/inference/tensorrt/plugin/pool3d_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/pool3d_op_plugin.cu
+375
-0
paddle/fluid/inference/tensorrt/plugin/pool3d_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/pool3d_op_plugin.h
+244
-0
paddle/fluid/operators/math/pooling.cu
paddle/fluid/operators/math/pooling.cu
+48
-0
paddle/fluid/operators/math/pooling.h
paddle/fluid/operators/math/pooling.h
+14
-0
python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
.../paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
+2
-0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_pool3d_op.py
.../fluid/tests/unittests/ir/inference/test_trt_pool3d_op.py
+332
-0
未找到文件。
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
229bae81
...
...
@@ -1415,6 +1415,7 @@ USE_TRT_CONVERTER(tile);
USE_TRT_CONVERTER
(
conv3d
);
USE_TRT_CONVERTER
(
conv3d_transpose
);
USE_TRT_CONVERTER
(
mish
);
USE_TRT_CONVERTER
(
pool3d
)
#endif
namespace
paddle_infer
{
...
...
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
229bae81
...
...
@@ -19,6 +19,7 @@ nv_library(tensorrt_converter
conv3d_op.cc
mish_op.cc
nearest_interp_v2_op.cc
pool3d_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry
)
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
...
...
paddle/fluid/inference/tensorrt/convert/pool3d_op.cc
0 → 100644
浏览文件 @
229bae81
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/pool3d_op_plugin.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
namespace
proto
{
class
OpDesc
;
}
// namespace proto
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
inline
void
DealCeilMode
(
const
nvinfer1
::
Dims
&
input_shape
,
std
::
vector
<
int
>
ksize
,
std
::
vector
<
int
>
strides
,
std
::
vector
<
int
>
paddings
,
nvinfer1
::
DimsCHW
*
pre_pad
,
nvinfer1
::
DimsCHW
*
post_pad
,
int
input_dims
)
{
int
input_depth
=
input_shape
.
d
[
input_dims
-
3
];
int
input_height
=
input_shape
.
d
[
input_dims
-
2
];
int
input_width
=
input_shape
.
d
[
input_dims
-
1
];
int
floor_d_output_size
=
(
input_depth
-
ksize
[
0
]
+
2
*
paddings
[
0
])
/
strides
[
0
]
+
1
;
int
ceil_d_output_size
=
(
input_depth
-
ksize
[
0
]
+
2
*
paddings
[
0
]
+
strides
[
0
]
-
1
)
/
strides
[
0
]
+
1
;
int
floor_h_output_size
=
(
input_height
-
ksize
[
1
]
+
2
*
paddings
[
1
])
/
strides
[
1
]
+
1
;
int
ceil_h_output_size
=
(
input_height
-
ksize
[
1
]
+
2
*
paddings
[
1
]
+
strides
[
1
]
-
1
)
/
strides
[
1
]
+
1
;
int
floor_w_output_size
=
(
input_width
-
ksize
[
2
]
+
2
*
paddings
[
2
])
/
strides
[
2
]
+
1
;
int
ceil_w_output_size
=
(
input_width
-
ksize
[
2
]
+
2
*
paddings
[
2
]
+
strides
[
2
]
-
1
)
/
strides
[
2
]
+
1
;
if
(
floor_d_output_size
!=
ceil_d_output_size
)
{
post_pad
->
c
()
=
strides
[
0
]
-
1
;
}
if
(
floor_h_output_size
!=
ceil_h_output_size
)
{
post_pad
->
h
()
=
strides
[
1
]
-
1
;
}
if
(
floor_w_output_size
!=
ceil_w_output_size
)
{
post_pad
->
w
()
=
strides
[
2
]
-
1
;
}
}
class
Pool3dOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
4
)
<<
"convert a fluid pool3d op to tensorrt pool3d layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
nvinfer1
::
Dims
input_shape
=
input1
->
getDimensions
();
int
input_dims
=
input_shape
.
nbDims
;
bool
global_pooling
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"global_pooling"
));
std
::
string
pool_type
=
BOOST_GET_CONST
(
std
::
string
,
op_desc
.
GetAttr
(
"pooling_type"
));
std
::
vector
<
int
>
ksize
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"ksize"
));
std
::
vector
<
int
>
strides
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"strides"
));
std
::
vector
<
int
>
paddings
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"paddings"
));
bool
exclusive
=
op_desc
.
HasAttr
(
"exclusive"
)
?
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"exclusive"
))
:
true
;
bool
ceil_mode
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"ceil_mode"
));
bool
adaptive
=
false
;
if
(
op_desc
.
HasAttr
(
"adaptive"
))
adaptive
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"adaptive"
));
std
::
string
padding_algorithm
=
"EXPLICIT"
;
if
(
op_desc
.
HasAttr
(
"padding_algorithm"
))
padding_algorithm
=
BOOST_GET_CONST
(
std
::
string
,
op_desc
.
GetAttr
(
"padding_algorithm"
));
if
(
padding_algorithm
==
"VALID"
||
padding_algorithm
==
"SAME"
)
{
std
::
fill
(
paddings
.
begin
(),
paddings
.
end
(),
0
);
}
nvinfer1
::
PoolingType
nv_pool_type
=
nvinfer1
::
PoolingType
::
kMAX
;
nvinfer1
::
ReduceOperation
reduce_operation
=
nvinfer1
::
ReduceOperation
::
kMAX
;
plugin
::
Pool3DPlugin
::
Pool3DType
plugin_pool_type
=
plugin
::
Pool3DPlugin
::
Pool3DType
::
max
;
if
(
pool_type
==
"max"
)
{
nv_pool_type
=
nvinfer1
::
PoolingType
::
kMAX
;
reduce_operation
=
nvinfer1
::
ReduceOperation
::
kMAX
;
plugin_pool_type
=
plugin
::
Pool3DPlugin
::
Pool3DType
::
max
;
}
else
if
(
pool_type
==
"avg"
)
{
nv_pool_type
=
nvinfer1
::
PoolingType
::
kAVERAGE
;
reduce_operation
=
nvinfer1
::
ReduceOperation
::
kAVG
;
plugin_pool_type
=
plugin
::
Pool3DPlugin
::
Pool3DType
::
avg
;
}
nvinfer1
::
DimsCHW
nv_ksize
(
ksize
[
0
],
ksize
[
1
],
ksize
[
2
]);
nvinfer1
::
DimsCHW
nv_strides
(
strides
[
0
],
strides
[
1
],
strides
[
2
]);
nvinfer1
::
DimsCHW
nv_paddings
(
paddings
[
0
],
paddings
[
1
],
paddings
[
2
]);
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
op_desc
.
HasAttr
(
"enable_int8"
))
{
CHECK
(
op_desc
.
HasAttr
(
"X_scale"
));
float
input_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"X_scale"
));
engine_
->
SetTensorDynamicRange
(
input1
,
input_scale
);
}
if
(
engine_
->
with_dynamic_shape
())
{
if
(
!
adaptive
&&
!
global_pooling
&&
!
ceil_mode
)
{
auto
*
pool_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
PoolingNd
,
*
input1
,
nv_pool_type
,
nv_ksize
);
pool_layer
->
setStrideNd
(
nv_strides
);
pool_layer
->
setPaddingNd
(
nv_paddings
);
pool_layer
->
setAverageCountExcludesPadding
(
exclusive
);
layer
=
pool_layer
;
}
else
if
(
global_pooling
)
{
auto
*
reduce_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Reduce
,
*
input1
,
reduce_operation
,
28
,
true
);
layer
=
reduce_layer
;
}
else
{
plugin
::
Pool3DPluginDynamic
*
plugin
=
new
plugin
::
Pool3DPluginDynamic
(
ceil_mode
,
pool_type
,
adaptive
,
ksize
,
strides
,
paddings
,
global_pooling
);
layer
=
engine_
->
AddDynamicPlugin
(
&
input1
,
1
,
plugin
);
}
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
layer
->
setName
((
"pool3d (Output: "
+
output_name
+
")"
).
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
engine_
->
DeclareOutput
(
output_name
);
}
return
;
}
if
(
global_pooling
==
true
)
{
auto
*
reduce_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Reduce
,
*
input1
,
reduce_operation
,
14
,
true
);
layer
=
reduce_layer
;
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
layer
->
setName
((
"pool3d (Output: "
+
output_name
+
")"
).
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
engine_
->
DeclareOutput
(
output_name
);
}
return
;
}
if
(
!
adaptive
)
{
if
(
!
ceil_mode
)
{
auto
*
pool_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
PoolingNd
,
*
input1
,
nv_pool_type
,
nv_ksize
);
PADDLE_ENFORCE_NOT_NULL
(
pool_layer
,
platform
::
errors
::
Fatal
(
"trt pool layer in converter could not be created."
));
pool_layer
->
setStrideNd
(
nv_strides
);
pool_layer
->
setPaddingNd
(
nv_paddings
);
pool_layer
->
setAverageCountExcludesPadding
(
exclusive
);
layer
=
pool_layer
;
}
else
{
std
::
vector
<
int
>
input_shape_v
;
for
(
int
i
=
0
;
i
<
input_dims
;
i
++
)
{
input_shape_v
.
push_back
(
input_shape
.
d
[
i
]);
}
plugin
::
Pool3DPlugin
*
plugin
=
new
plugin
::
Pool3DPlugin
(
ceil_mode
,
plugin_pool_type
,
adaptive
,
ksize
,
strides
,
paddings
,
input_shape_v
);
auto
*
pool_layer
=
engine_
->
AddPluginV2Ext
(
&
input1
,
1
,
plugin
);
PADDLE_ENFORCE_NOT_NULL
(
pool_layer
,
platform
::
errors
::
Fatal
(
"trt pool3d plugin layer in converter could not be created."
));
layer
=
pool_layer
;
}
}
else
{
// Average pooling needs to exclude the padding pixels from the average
// mean.
// It is not supported well by TRT, we use a plugin here.
std
::
vector
<
int
>
input_shape_v
;
for
(
int
i
=
0
;
i
<
input_dims
;
i
++
)
{
input_shape_v
.
push_back
(
input_shape
.
d
[
i
]);
}
plugin
::
Pool3DPlugin
*
plugin
=
new
plugin
::
Pool3DPlugin
(
ceil_mode
,
plugin_pool_type
,
adaptive
,
ksize
,
strides
,
paddings
,
input_shape_v
);
auto
*
pool_layer
=
engine_
->
AddPluginV2Ext
(
&
input1
,
1
,
plugin
);
PADDLE_ENFORCE_NOT_NULL
(
pool_layer
,
platform
::
errors
::
Fatal
(
"trt pool3d plugin layer in converter could not be created."
));
layer
=
pool_layer
;
}
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
"pool3d"
,
{
output_name
},
test_mode
);
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
USE_OP
(
pool3d
);
REGISTER_TRT_OP_CONVERTER
(
pool3d
,
Pool3dOpConverter
);
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
229bae81
...
...
@@ -142,7 +142,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"conv3d"
,
"conv3d_transpose"
,
"mish"
,
"nearest_interp_v2"
};
"nearest_interp_v2"
,
"pool3d"
};
};
bool
OpTeller
::
Tell
(
const
framework
::
ir
::
Node
*
node
,
bool
use_no_calib_int8
,
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
229bae81
...
...
@@ -10,6 +10,7 @@ nv_library(tensorrt_plugin
roi_align_op_plugin.cu
gather_nd_op_plugin.cu
mish_op_plugin.cu
pool3d_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor
)
nv_test
(
test_split_plugin SRCS test_split_plugin.cc DEPS
...
...
paddle/fluid/inference/tensorrt/plugin/pool3d_op_plugin.cu
0 → 100644
浏览文件 @
229bae81
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, softwarepool
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/pool3d_op_plugin.h"
#include "paddle/fluid/operators/math/pooling.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
size_t
Pool3DPlugin
::
getSerializationSize
()
const
TRT_NOEXCEPT
{
return
getBaseSerializationSize
()
+
SerializedSize
(
ceil_mode_
)
+
SerializedSize
(
pool3d_type_
)
+
SerializedSize
(
adaptive_
)
+
SerializedSize
(
ksize_
)
+
SerializedSize
(
strides_
)
+
SerializedSize
(
paddings_
)
+
SerializedSize
(
input_shape_
)
+
SerializedSize
(
output_shape_
);
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
void
Pool3DPlugin
::
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
{
serializeBase
(
buffer
);
SerializeValue
(
&
buffer
,
ceil_mode_
);
SerializeValue
(
&
buffer
,
pool3d_type_
);
SerializeValue
(
&
buffer
,
adaptive_
);
SerializeValue
(
&
buffer
,
ksize_
);
SerializeValue
(
&
buffer
,
strides_
);
SerializeValue
(
&
buffer
,
paddings_
);
SerializeValue
(
&
buffer
,
input_shape_
);
SerializeValue
(
&
buffer
,
output_shape_
);
}
Pool3DPlugin
*
Pool3DPlugin
::
clone
()
const
TRT_NOEXCEPT
{
return
new
Pool3DPlugin
(
ceil_mode_
,
pool3d_type_
,
adaptive_
,
ksize_
,
strides_
,
paddings_
,
input_shape_
);
}
const
char
*
Pool3DPlugin
::
getPluginType
()
const
TRT_NOEXCEPT
{
return
"pool3d_plugin"
;
}
int
Pool3DPlugin
::
getNbOutputs
()
const
TRT_NOEXCEPT
{
return
1
;
}
int
Pool3DPlugin
::
initialize
()
TRT_NOEXCEPT
{
return
0
;
}
nvinfer1
::
DataType
Pool3DPlugin
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
TRT_NOEXCEPT
{
return
input_types
[
0
];
}
void
Pool3DPlugin
::
destroy
()
TRT_NOEXCEPT
{
delete
this
;
}
nvinfer1
::
Dims
Pool3DPlugin
::
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputDims
,
int
nbInputs
)
TRT_NOEXCEPT
{
PADDLE_ENFORCE_EQ
(
nbInputs
,
1
,
platform
::
errors
::
InvalidArgument
(
"The Pool3D Plugin only has one input, so the nbInputs "
"value should be 1, but get %d."
,
nbInputs
));
PADDLE_ENFORCE_EQ
(
index
,
0
,
platform
::
errors
::
InvalidArgument
(
"The Pool3D Plugin only has one input, so "
"the index value should be 0, but get %d."
,
index
));
PADDLE_ENFORCE_EQ
(
inputDims
[
0
].
nbDims
,
4
,
platform
::
errors
::
InvalidArgument
(
"The Pool3D Plugin only has four Dimensions, so the "
"nbDims value should be 4, but get %d."
,
inputDims
[
0
].
nbDims
));
nvinfer1
::
Dims
const
&
input_dims
=
inputDims
[
0
];
nvinfer1
::
Dims
output_dims
=
input_dims
;
output_dims
.
d
[
1
]
=
output_shape_
[
1
];
output_dims
.
d
[
2
]
=
output_shape_
[
2
];
output_dims
.
d
[
3
]
=
output_shape_
[
3
];
return
output_dims
;
}
int
Pool3DPlugin
::
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
#if IS_TRT_VERSION_LT(8000)
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
#else
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
#endif
int
input_size
=
0
;
float
const
*
idata
=
reinterpret_cast
<
float
const
*>
(
inputs
[
0
]);
float
*
const
*
odatas
=
reinterpret_cast
<
float
*
const
*>
(
outputs
);
std
::
vector
<
int
>
input_shape
=
input_shape_
;
std
::
vector
<
int
>
output_shape
=
output_shape_
;
input_shape
.
insert
(
input_shape
.
begin
(),
batchSize
);
output_shape
.
insert
(
output_shape
.
begin
(),
batchSize
);
if
(
pool3d_type_
==
Pool3DType
::
max
)
{
paddle
::
operators
::
math
::
MaxPool
<
float
>
pool_process
;
paddle
::
operators
::
math
::
Pool3dDirectCUDAFunctor
<
paddle
::
operators
::
math
::
MaxPool
<
float
>
,
float
>
pool3d_forward
;
pool3d_forward
(
idata
,
input_shape
,
output_shape
,
ksize_
,
strides_
,
paddings_
,
true
,
adaptive_
,
odatas
[
0
],
stream
,
pool_process
);
}
else
if
(
pool3d_type_
==
Pool3DType
::
avg
)
{
paddle
::
operators
::
math
::
AvgPool
<
float
>
pool_process
;
paddle
::
operators
::
math
::
Pool3dDirectCUDAFunctor
<
paddle
::
operators
::
math
::
AvgPool
<
float
>
,
float
>
pool3d_forward
;
pool3d_forward
(
idata
,
input_shape
,
output_shape
,
ksize_
,
strides_
,
paddings_
,
true
,
adaptive_
,
odatas
[
0
],
stream
,
pool_process
);
}
return
cudaGetLastError
()
!=
cudaSuccess
;
}
// Dynamic Plugin below.
Pool3DPluginDynamic
::
Pool3DPluginDynamic
(
void
const
*
serialData
,
size_t
serialLength
)
{
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
ceil_mode_
);
const
char
*
pool3d_type
;
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
pool3d_type
);
pool3d_type_
=
std
::
string
(
pool3d_type
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
adaptive_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
ksize_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
strides_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
paddings_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
is_global_
);
}
nvinfer1
::
IPluginV2DynamicExt
*
Pool3DPluginDynamic
::
clone
()
const
TRT_NOEXCEPT
{
return
new
Pool3DPluginDynamic
(
ceil_mode_
,
pool3d_type_
,
adaptive_
,
ksize_
,
strides_
,
paddings_
,
is_global_
);
}
const
char
*
Pool3DPluginDynamic
::
getPluginType
()
const
TRT_NOEXCEPT
{
return
"pool3d_plugin_dynamic"
;
}
int
Pool3DPluginDynamic
::
getNbOutputs
()
const
TRT_NOEXCEPT
{
return
1
;
}
int
Pool3DPluginDynamic
::
initialize
()
TRT_NOEXCEPT
{
return
0
;
}
void
Pool3DPluginDynamic
::
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
TRT_NOEXCEPT
{}
size_t
Pool3DPluginDynamic
::
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
TRT_NOEXCEPT
{
return
0
;
}
size_t
Pool3DPluginDynamic
::
getSerializationSize
()
const
TRT_NOEXCEPT
{
return
SerializedSize
(
ceil_mode_
)
+
SerializedSize
(
pool3d_type_
.
c_str
())
+
SerializedSize
(
adaptive_
)
+
SerializedSize
(
ksize_
)
+
SerializedSize
(
strides_
)
+
SerializedSize
(
paddings_
)
+
SerializedSize
(
is_global_
);
}
void
Pool3DPluginDynamic
::
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
{
SerializeValue
(
&
buffer
,
ceil_mode_
);
SerializeValue
(
&
buffer
,
pool3d_type_
.
c_str
());
SerializeValue
(
&
buffer
,
adaptive_
);
SerializeValue
(
&
buffer
,
ksize_
);
SerializeValue
(
&
buffer
,
strides_
);
SerializeValue
(
&
buffer
,
paddings_
);
SerializeValue
(
&
buffer
,
is_global_
);
}
nvinfer1
::
DimsExprs
Pool3DPluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
TRT_NOEXCEPT
{
PADDLE_ENFORCE_EQ
(
nb_inputs
,
1
,
platform
::
errors
::
InvalidArgument
(
"The Split plugin should be only one input."
));
PADDLE_ENFORCE_EQ
(
inputs
[
0
].
d
[
1
]
->
isConstant
(),
true
,
platform
::
errors
::
InvalidArgument
(
"The channel dimension should be "
"static, but we found it's dynamic."
));
nvinfer1
::
DimsExprs
output
(
inputs
[
0
]);
if
(
is_global_
)
{
output
.
d
[
2
]
=
expr_builder
.
constant
(
1
);
output
.
d
[
3
]
=
expr_builder
.
constant
(
1
);
output
.
d
[
4
]
=
expr_builder
.
constant
(
1
);
return
output
;
}
if
(
adaptive_
)
{
output
.
d
[
2
]
=
expr_builder
.
constant
(
ksize_
[
0
]);
output
.
d
[
3
]
=
expr_builder
.
constant
(
ksize_
[
1
]);
output
.
d
[
4
]
=
expr_builder
.
constant
(
ksize_
[
2
]);
return
output
;
}
auto
stri_0
=
expr_builder
.
constant
(
strides_
[
0
]);
auto
stri_1
=
expr_builder
.
constant
(
strides_
[
1
]);
auto
stri_2
=
expr_builder
.
constant
(
strides_
[
2
]);
auto
one_value
=
expr_builder
.
constant
(
1
);
auto
v0_tmp
=
expr_builder
.
constant
(
-
ksize_
[
0
]
+
2
*
paddings_
[
0
]);
auto
v1_tmp
=
expr_builder
.
constant
(
-
ksize_
[
1
]
+
2
*
paddings_
[
1
]);
auto
v2_tmp
=
expr_builder
.
constant
(
-
ksize_
[
2
]
+
2
*
paddings_
[
2
]);
auto
ceil_tmp
=
expr_builder
.
constant
(
-
ksize_
[
0
]
+
2
*
paddings_
[
0
]
+
strides_
[
0
]
-
1
);
auto
ceil1_tmp
=
expr_builder
.
constant
(
-
ksize_
[
1
]
+
2
*
paddings_
[
1
]
+
strides_
[
1
]
-
1
);
auto
ceil2_tmp
=
expr_builder
.
constant
(
-
ksize_
[
2
]
+
2
*
paddings_
[
2
]
+
strides_
[
2
]
-
1
);
if
(
!
ceil_mode_
)
{
output
.
d
[
2
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kFLOOR_DIV
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
inputs
[
0
].
d
[
2
],
*
v0_tmp
),
*
stri_0
),
*
one_value
);
output
.
d
[
3
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kFLOOR_DIV
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
inputs
[
0
].
d
[
3
],
*
v1_tmp
),
*
stri_1
),
*
one_value
);
output
.
d
[
4
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kFLOOR_DIV
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
inputs
[
0
].
d
[
4
],
*
v2_tmp
),
*
stri_2
),
*
one_value
);
}
else
{
output
.
d
[
2
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kFLOOR_DIV
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
inputs
[
0
].
d
[
2
],
*
ceil_tmp
),
*
stri_0
),
*
one_value
);
output
.
d
[
3
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kFLOOR_DIV
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
inputs
[
0
].
d
[
3
],
*
ceil1_tmp
),
*
stri_1
),
*
one_value
);
output
.
d
[
4
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kFLOOR_DIV
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
inputs
[
0
].
d
[
4
],
*
ceil2_tmp
),
*
stri_2
),
*
one_value
);
}
return
output
;
}
bool
Pool3DPluginDynamic
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
nb_outputs
)
TRT_NOEXCEPT
{
PADDLE_ENFORCE_NOT_NULL
(
in_out
,
platform
::
errors
::
InvalidArgument
(
"The input of swish plugin shoule not be nullptr."
));
PADDLE_ENFORCE_LT
(
pos
,
nb_inputs
+
nb_outputs
,
platform
::
errors
::
InvalidArgument
(
"The pos(%d) should be less than the "
"num(%d) of the input and the output."
,
pos
,
nb_inputs
+
nb_outputs
));
(
in_out
&&
pos
<
(
nb_inputs
+
nb_outputs
));
return
((
in_out
[
pos
].
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
in_out
[
pos
].
format
==
nvinfer1
::
PluginFormat
::
kLINEAR
);
}
nvinfer1
::
DataType
Pool3DPluginDynamic
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
TRT_NOEXCEPT
{
PADDLE_ENFORCE_EQ
(
index
,
0
,
platform
::
errors
::
InvalidArgument
(
"The Pool3D Plugin only has one input, so the "
"index value should be 0, but get %d."
,
index
));
PADDLE_ENFORCE_EQ
((
input_types
[
0
]
==
nvinfer1
::
DataType
::
kFLOAT
),
true
,
platform
::
errors
::
InvalidArgument
(
"The input type should be half or float"
));
return
input_types
[
0
];
}
int
Pool3DPluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
auto
input_dims
=
input_desc
[
0
].
dims
;
int
n
=
input_dims
.
d
[
0
];
int
c
=
input_dims
.
d
[
1
];
int
d
=
input_dims
.
d
[
2
];
int
h
=
input_dims
.
d
[
3
];
int
w
=
input_dims
.
d
[
4
];
const
float
*
input
=
static_cast
<
const
float
*>
(
inputs
[
0
]);
float
*
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
std
::
vector
<
int
>
input_shape
,
output_shape
;
for
(
int
i
=
0
;
i
<
input_dims
.
nbDims
;
i
++
)
input_shape
.
push_back
(
input_dims
.
d
[
i
]);
output_shape
=
input_shape
;
std
::
vector
<
int
>
ksize
=
ksize_
;
std
::
vector
<
int
>
paddings
=
paddings_
;
if
(
is_global_
)
{
ksize
[
0
]
=
d
;
ksize
[
1
]
=
h
;
ksize
[
2
]
=
w
;
paddings
[
0
]
=
0
;
paddings
[
1
]
=
0
;
paddings
[
2
]
=
0
;
output_shape
[
2
]
=
1
;
output_shape
[
3
]
=
1
;
output_shape
[
4
]
=
1
;
}
else
{
auto
data_dim
=
CalcOutputSize
({
d
,
h
,
w
},
ceil_mode_
,
adaptive_
,
ksize_
,
strides_
,
paddings_
);
output_shape
[
2
]
=
data_dim
[
0
];
output_shape
[
3
]
=
data_dim
[
1
];
output_shape
[
4
]
=
data_dim
[
2
];
}
if
(
pool3d_type_
==
"max"
)
{
paddle
::
operators
::
math
::
MaxPool
<
float
>
pool_process
;
paddle
::
operators
::
math
::
Pool3dDirectCUDAFunctor
<
paddle
::
operators
::
math
::
MaxPool
<
float
>
,
float
>
pool3d_forward
;
pool3d_forward
(
input
,
input_shape
,
output_shape
,
ksize
,
strides_
,
paddings
,
true
,
adaptive_
,
output
,
stream
,
pool_process
);
}
else
if
(
pool3d_type_
==
"avg"
)
{
paddle
::
operators
::
math
::
AvgPool
<
float
>
pool_process
;
paddle
::
operators
::
math
::
Pool3dDirectCUDAFunctor
<
paddle
::
operators
::
math
::
AvgPool
<
float
>
,
float
>
pool3d_forward
;
pool3d_forward
(
input
,
input_shape
,
output_shape
,
ksize
,
strides_
,
paddings
,
true
,
adaptive_
,
output
,
stream
,
pool_process
);
}
return
cudaGetLastError
()
!=
cudaSuccess
;
}
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/pool3d_op_plugin.h
0 → 100644
浏览文件 @
229bae81
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
static
std
::
vector
<
int
>
CalcOutputSize
(
const
std
::
vector
<
int
>&
input_shape
,
const
bool
&
ceil_mode
,
const
bool
&
adaptive
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
)
{
std
::
vector
<
int
>
output_shape
=
input_shape
;
if
(
adaptive
)
{
output_shape
[
0
]
=
ksize
[
0
];
output_shape
[
1
]
=
ksize
[
1
];
output_shape
[
2
]
=
ksize
[
2
];
}
else
{
int
output_d
=
(
input_shape
[
0
]
-
ksize
[
0
]
+
2
*
paddings
[
0
])
/
strides
[
0
]
+
1
;
int
output_h
=
(
input_shape
[
1
]
-
ksize
[
1
]
+
2
*
paddings
[
1
])
/
strides
[
1
]
+
1
;
int
output_w
=
(
input_shape
[
2
]
-
ksize
[
2
]
+
2
*
paddings
[
2
])
/
strides
[
2
]
+
1
;
if
(
ceil_mode
)
{
output_d
=
(
input_shape
[
0
]
-
ksize
[
0
]
+
2
*
paddings
[
0
]
+
strides
[
0
]
-
1
)
/
strides
[
0
]
+
1
;
output_h
=
(
input_shape
[
1
]
-
ksize
[
1
]
+
2
*
paddings
[
1
]
+
strides
[
1
]
-
1
)
/
strides
[
1
]
+
1
;
output_w
=
(
input_shape
[
2
]
-
ksize
[
2
]
+
2
*
paddings
[
2
]
+
strides
[
2
]
-
1
)
/
strides
[
2
]
+
1
;
}
output_shape
[
0
]
=
output_d
;
output_shape
[
1
]
=
output_h
;
output_shape
[
2
]
=
output_w
;
}
return
output_shape
;
}
class
Pool3DPlugin
:
public
PluginTensorRTV2Ext
{
public:
size_t
getSerializationSize
()
const
TRT_NOEXCEPT
override
;
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
void
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
override
;
enum
class
Pool3DType
{
max
=
0
,
avg
,
};
Pool3DPlugin
()
{}
Pool3DPlugin
(
bool
ceil_mode
,
Pool3DType
pool3d_type
,
bool
adaptive
,
std
::
vector
<
int
>
ksize
,
std
::
vector
<
int
>
strides
,
std
::
vector
<
int
>
paddings
,
std
::
vector
<
int
>
input_shape
)
:
ceil_mode_
(
ceil_mode
),
pool3d_type_
(
pool3d_type
),
adaptive_
(
adaptive
),
ksize_
(
ksize
),
strides_
(
strides
),
paddings_
(
paddings
),
input_shape_
(
input_shape
)
{
output_shape_
=
input_shape_
;
std
::
vector
<
int
>
output_shape
=
CalcOutputSize
({
input_shape_
[
1
],
input_shape_
[
2
],
input_shape_
[
3
]},
ceil_mode_
,
adaptive_
,
ksize_
,
strides_
,
paddings_
);
output_shape_
[
1
]
=
output_shape
[
0
];
output_shape_
[
2
]
=
output_shape
[
1
];
output_shape_
[
3
]
=
output_shape
[
2
];
}
// It was used for tensorrt deserialization.
// It should not be called by users.
Pool3DPlugin
(
void
const
*
serialData
,
size_t
serialLength
)
{
deserializeBase
(
serialData
,
serialLength
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
ceil_mode_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
pool3d_type_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
adaptive_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
ksize_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
strides_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
paddings_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
input_shape_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
output_shape_
);
}
Pool3DPlugin
*
clone
()
const
TRT_NOEXCEPT
override
;
const
char
*
getPluginType
()
const
TRT_NOEXCEPT
override
;
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
TRT_NOEXCEPT
override
;
int
getNbOutputs
()
const
TRT_NOEXCEPT
override
;
nvinfer1
::
Dims
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputs
,
int
nbInputDims
)
TRT_NOEXCEPT
override
;
int
initialize
()
TRT_NOEXCEPT
override
;
void
destroy
()
TRT_NOEXCEPT
override
;
#if IS_TRT_VERSION_LT(8000)
int
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
#else
int
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
#endif
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
override
;
private:
bool
ceil_mode_
;
Pool3DType
pool3d_type_
;
bool
adaptive_
;
std
::
vector
<
int
>
ksize_
;
std
::
vector
<
int
>
strides_
;
std
::
vector
<
int
>
paddings_
;
std
::
vector
<
int
>
input_shape_
;
std
::
vector
<
int
>
output_shape_
;
};
class
Pool3DPluginCreator
:
public
TensorRTPluginCreator
{
public:
const
char
*
getPluginName
()
const
TRT_NOEXCEPT
override
{
return
"pool3d_plugin"
;
}
const
char
*
getPluginVersion
()
const
TRT_NOEXCEPT
override
{
return
"1"
;
}
nvinfer1
::
IPluginV2
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
TRT_NOEXCEPT
override
{
return
new
Pool3DPlugin
(
serial_data
,
serial_length
);
}
};
REGISTER_TRT_PLUGIN_V2
(
Pool3DPluginCreator
);
class
Pool3DPluginDynamic
:
public
DynamicPluginTensorRT
{
public:
Pool3DPluginDynamic
()
{}
Pool3DPluginDynamic
(
const
bool
&
ceil_mode
,
const
std
::
string
&
pool3d_type
,
const
bool
&
adaptive
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
bool
&
is_global
)
:
ceil_mode_
(
ceil_mode
),
pool3d_type_
(
pool3d_type
),
adaptive_
(
adaptive
),
ksize_
(
ksize
),
strides_
(
strides
),
paddings_
(
paddings
),
is_global_
(
is_global
)
{}
Pool3DPluginDynamic
(
void
const
*
serialData
,
size_t
serialLength
);
~
Pool3DPluginDynamic
()
{}
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
TRT_NOEXCEPT
override
;
const
char
*
getPluginType
()
const
TRT_NOEXCEPT
override
;
int
getNbOutputs
()
const
TRT_NOEXCEPT
override
;
int
initialize
()
TRT_NOEXCEPT
override
;
size_t
getSerializationSize
()
const
TRT_NOEXCEPT
override
;
void
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
override
;
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
TRT_NOEXCEPT
override
;
bool
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
TRT_NOEXCEPT
override
;
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
TRT_NOEXCEPT
override
;
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
TRT_NOEXCEPT
override
;
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
override
;
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
TRT_NOEXCEPT
override
;
void
destroy
()
TRT_NOEXCEPT
override
{
delete
this
;
}
private:
bool
ceil_mode_
;
std
::
string
pool3d_type_
;
bool
adaptive_
;
std
::
vector
<
int
>
ksize_
;
std
::
vector
<
int
>
strides_
;
std
::
vector
<
int
>
paddings_
;
bool
is_global_
;
};
class
Pool3DPluginDynamicCreator
:
public
TensorRTPluginCreator
{
public:
const
char
*
getPluginName
()
const
TRT_NOEXCEPT
override
{
return
"pool3d_plugin_dynamic"
;
}
const
char
*
getPluginVersion
()
const
TRT_NOEXCEPT
override
{
return
"1"
;
}
nvinfer1
::
IPluginV2
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
TRT_NOEXCEPT
override
{
return
new
Pool3DPluginDynamic
(
serial_data
,
serial_length
);
}
};
REGISTER_TRT_PLUGIN_V2
(
Pool3DPluginDynamicCreator
);
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/operators/math/pooling.cu
浏览文件 @
229bae81
...
...
@@ -979,6 +979,49 @@ __global__ void KernelMaxPool3DGrad(
}
}
template
<
typename
PoolProcess
,
typename
T
>
void
Pool3dDirectCUDAFunctor
<
PoolProcess
,
T
>::
operator
()(
const
T
*
input
,
const
std
::
vector
<
int
>&
input_shape
,
const
std
::
vector
<
int
>&
output_shape
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
bool
exclusive
,
bool
adaptive
,
T
*
output
,
gpuStream_t
stream
,
PoolProcess
pool_compute
)
{
const
int
batch_size
=
input_shape
[
0
];
const
int
input_channels
=
input_shape
[
1
];
const
int
input_depth
=
input_shape
[
2
];
const
int
input_height
=
input_shape
[
3
];
const
int
input_width
=
input_shape
[
4
];
const
int
output_channels
=
output_shape
[
1
];
const
int
output_depth
=
output_shape
[
2
];
const
int
output_height
=
output_shape
[
3
];
const
int
output_width
=
output_shape
[
4
];
const
int
ksize_depth
=
ksize
[
0
];
const
int
ksize_height
=
ksize
[
1
];
const
int
ksize_width
=
ksize
[
2
];
const
int
stride_depth
=
strides
[
0
];
const
int
stride_height
=
strides
[
1
];
const
int
stride_width
=
strides
[
2
];
const
int
padding_depth
=
paddings
[
0
];
const
int
padding_height
=
paddings
[
1
];
const
int
padding_width
=
paddings
[
2
];
int
nthreads
=
batch_size
*
output_channels
*
output_depth
*
output_height
*
output_width
;
int
thread_num
=
1024
;
#ifdef WITH_NV_JETSON
thread_num
=
512
;
#endif
int
blocks
=
(
nthreads
+
thread_num
-
1
)
/
thread_num
;
dim3
threads
(
thread_num
,
1
);
dim3
grid
(
blocks
,
1
);
KernelPool3D
<
PoolProcess
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
nthreads
,
input
,
input_channels
,
input_depth
,
input_height
,
input_width
,
output_depth
,
output_height
,
output_width
,
ksize_depth
,
ksize_height
,
ksize_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
,
pool_compute
,
exclusive
,
adaptive
,
output
);
}
/*
* Tensors are in NCDHW or NDHWC format.
* Ksize, strides, paddings are three elements. These three elements represent
...
...
@@ -1315,6 +1358,11 @@ class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
}
};
template
class
Pool3dDirectCUDAFunctor
<
paddle
::
operators
::
math
::
MaxPool
<
float
>,
float
>
;
template
class
Pool3dDirectCUDAFunctor
<
paddle
::
operators
::
math
::
AvgPool
<
float
>,
float
>
;
template
class
MaxPool3dGradFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
MaxPool3dGradFunctor
<
platform
::
CUDADeviceContext
,
double
>;
template
class
MaxPool3dGradFunctor
<
platform
::
CUDADeviceContext
,
...
...
paddle/fluid/operators/math/pooling.h
浏览文件 @
229bae81
...
...
@@ -187,6 +187,20 @@ class MaxPool2dGradFunctor {
const
std
::
string
data_format
,
framework
::
Tensor
*
input_grad
);
};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template
<
typename
PoolProcess
,
typename
T
>
class
Pool3dDirectCUDAFunctor
{
public:
void
operator
()(
const
T
*
input
,
const
std
::
vector
<
int
>&
input_shape
,
const
std
::
vector
<
int
>&
output_shape
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
bool
exclusive
,
bool
adaptive
,
T
*
output
,
gpuStream_t
stream
,
PoolProcess
pool_compute
);
};
#endif
template
<
typename
DeviceContext
,
typename
PoolProcess
,
typename
T
>
class
Pool3dFunctor
{
public:
...
...
python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
浏览文件 @
229bae81
...
...
@@ -58,8 +58,10 @@ set_tests_properties(test_trt_conv_pass PROPERTIES TIMEOUT 120)
set_tests_properties
(
test_trt_dynamic_shape PROPERTIES TIMEOUT 120
)
if
(
WITH_NV_JETSON
)
set_tests_properties
(
test_trt_pool_op PROPERTIES ENVIRONMENT FLAGS_fraction_of_gpu_memory_to_use=0.1 TIMEOUT 450
)
set_tests_properties
(
test_trt_pool3d_op PROPERTIES ENVIRONMENT FLAGS_fraction_of_gpu_memory_to_use=0.1 TIMEOUT 450
)
else
()
set_tests_properties
(
test_trt_pool_op PROPERTIES ENVIRONMENT FLAGS_fraction_of_gpu_memory_to_use=0.1 TIMEOUT 45
)
set_tests_properties
(
test_trt_pool3d_op PROPERTIES ENVIRONMENT FLAGS_fraction_of_gpu_memory_to_use=0.1 TIMEOUT 45
)
endif
()
set_tests_properties
(
test_trt_reduce_mean_op PROPERTIES TIMEOUT 60
)
set_tests_properties
(
test_trt_tile_op PROPERTIES TIMEOUT 60
)
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_pool3d_op.py
0 → 100644
浏览文件 @
229bae81
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
shutil
import
unittest
import
itertools
import
numpy
as
np
from
inference_pass_test
import
InferencePassTest
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.core
import
PassVersionChecker
from
paddle.fluid.core
import
AnalysisConfig
class
TensorRTPool3dTest
(
InferencePassTest
):
def
setUp
(
self
):
self
.
bs
=
1
self
.
channel
=
3
self
.
depth
=
8
self
.
height
=
8
self
.
width
=
8
self
.
pool_size
=
2
self
.
pool_type
=
'max'
self
.
pool_stride
=
1
self
.
pool_padding
=
0
self
.
global_pooling
=
False
self
.
ceil_mode
=
False
self
.
exclusive
=
False
self
.
enable_trt
=
True
self
.
serialize
=
False
self
.
precision
=
AnalysisConfig
.
Precision
.
Float32
self
.
feeds
=
{
'data'
:
np
.
random
.
random
(
[
self
.
bs
,
self
.
channel
,
self
.
depth
,
self
.
height
,
self
.
width
]).
astype
(
'float32'
),
}
def
set_extra_config
(
self
):
pass
def
build_network
(
self
):
self
.
set_extra_config
()
self
.
trt_parameters
=
TensorRTPool3dTest
.
TensorRTParam
(
1
<<
30
,
self
.
bs
,
0
,
self
.
precision
,
self
.
serialize
,
False
)
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
'data'
,
shape
=
[
-
1
,
self
.
channel
,
self
.
depth
,
self
.
height
,
self
.
width
],
dtype
=
'float32'
)
pool_out
=
fluid
.
layers
.
pool3d
(
input
=
data
,
pool_size
=
self
.
pool_size
,
pool_type
=
self
.
pool_type
,
pool_stride
=
self
.
pool_stride
,
pool_padding
=
self
.
pool_padding
,
global_pooling
=
self
.
global_pooling
,
ceil_mode
=
self
.
ceil_mode
,
exclusive
=
self
.
exclusive
)
#out = fluid.layers.batch_norm(pool_out, is_test=True)
self
.
fetch_list
=
[
pool_out
]
def
check_output
(
self
):
if
os
.
path
.
exists
(
self
.
path
+
"_opt_cache"
):
shutil
.
rmtree
(
self
.
path
+
"_opt_cache"
)
if
core
.
is_compiled_with_cuda
():
use_gpu
=
True
self
.
check_output_with_option
(
use_gpu
)
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
'tensorrt_subgraph_pass'
))
def
run_test
(
self
):
self
.
build_network
()
self
.
check_output
()
def
test
(
self
):
precision_options
=
[
AnalysisConfig
.
Precision
.
Float32
,
AnalysisConfig
.
Precision
.
Half
]
serialize_options
=
[
False
,
True
]
dynamic_shape_profile
=
InferencePassTest
.
DynamicShapeParam
({
'data'
:
[
self
.
bs
,
self
.
channel
,
self
.
depth
//
2
,
self
.
height
//
2
,
self
.
width
//
2
]
},
{
'data'
:
[
self
.
bs
,
self
.
channel
,
self
.
depth
,
self
.
height
,
self
.
width
]
},
{
'data'
:
[
self
.
bs
,
self
.
channel
,
self
.
depth
,
self
.
height
,
self
.
width
]
},
False
)
dynamic_shape_options
=
[
None
,
dynamic_shape_profile
]
for
precision
,
serialize
,
dynamic_shape
in
itertools
.
product
(
precision_options
,
serialize_options
,
dynamic_shape_options
):
is_dynamic
=
True
if
dynamic_shape_options
is
not
None
else
False
with
self
.
subTest
(
'Precision: {}, Serialize: {}, Dynamic: {}'
.
format
(
precision
,
serialize
,
is_dynamic
)):
self
.
precision
=
precision
self
.
serialize
=
serialize
self
.
dynamic_shape_params
=
dynamic_shape
self
.
run_test
()
class
TensorRTAvgPool3dTest
(
TensorRTPool3dTest
):
def
set_extra_config
(
self
):
self
.
pool_size
=
2
self
.
pool_type
=
'avg'
self
.
pool_stride
=
1
self
.
pool_padding
=
0
self
.
global_pooling
=
False
self
.
ceil_mode
=
False
self
.
exclusive
=
False
class
TensorRTGlobalPool3dTest
(
TensorRTPool3dTest
):
def
set_extra_config
(
self
):
self
.
pool_size
=
2
self
.
pool_type
=
'max'
self
.
pool_stride
=
1
self
.
pool_padding
=
0
self
.
global_pooling
=
True
self
.
ceil_mode
=
False
self
.
exclusive
=
False
class
TensorRTCeilPool3dTest
(
TensorRTPool3dTest
):
def
set_extra_config
(
self
):
self
.
pool_size
=
2
self
.
pool_type
=
'max'
self
.
pool_stride
=
1
self
.
pool_padding
=
0
self
.
global_pooling
=
False
self
.
ceil_mode
=
True
self
.
exclusive
=
False
class
TensorRTExclusivePool3dTest
(
TensorRTPool3dTest
):
def
set_extra_config
(
self
):
self
.
pool_size
=
2
self
.
pool_type
=
'max'
self
.
pool_stride
=
1
self
.
pool_padding
=
0
self
.
global_pooling
=
False
self
.
ceil_mode
=
False
self
.
exclusive
=
True
class
TensorRTSamePaddingPool3dTest
(
InferencePassTest
):
def
set_extra_config
(
self
):
self
.
pool_size
=
2
self
.
pool_type
=
'max'
self
.
pool_stride
=
1
self
.
pool_padding
=
'SAME'
self
.
global_pooling
=
False
self
.
ceil_mode
=
False
self
.
exclusive
=
False
class
TensorRTValidPaddingPool3dTest
(
InferencePassTest
):
def
set_extra_config
(
self
):
self
.
pool_size
=
2
self
.
pool_type
=
'max'
self
.
pool_stride
=
1
self
.
pool_padding
=
'VALID'
self
.
global_pooling
=
False
self
.
ceil_mode
=
False
self
.
exclusive
=
False
class
TensorRTAdaptiveAvgPool3DTest
(
InferencePassTest
):
def
setUp
(
self
):
self
.
bs
=
1
self
.
channel
=
3
self
.
depth
=
8
self
.
height
=
8
self
.
width
=
8
self
.
enable_trt
=
True
self
.
serialize
=
False
self
.
precision
=
AnalysisConfig
.
Precision
.
Float32
self
.
feeds
=
{
'data'
:
np
.
random
.
random
(
[
self
.
bs
,
self
.
channel
,
self
.
depth
,
self
.
height
,
self
.
width
]).
astype
(
'float32'
),
}
def
build_network
(
self
):
self
.
trt_parameters
=
TensorRTPool3dTest
.
TensorRTParam
(
1
<<
30
,
self
.
bs
,
0
,
self
.
precision
,
self
.
serialize
,
False
)
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
'data'
,
shape
=
[
-
1
,
self
.
channel
,
self
.
depth
,
self
.
height
,
self
.
width
],
dtype
=
'float32'
)
pool_out
=
paddle
.
nn
.
functional
.
adaptive_avg_pool3d
(
x
=
data
,
output_size
=
[
3
,
3
,
3
])
#out = fluid.layers.batch_norm(pool_out, is_test=True)
self
.
fetch_list
=
[
pool_out
]
def
check_output
(
self
):
if
os
.
path
.
exists
(
self
.
path
+
"_opt_cache"
):
shutil
.
rmtree
(
self
.
path
+
"_opt_cache"
)
if
core
.
is_compiled_with_cuda
():
use_gpu
=
True
self
.
check_output_with_option
(
use_gpu
)
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
'tensorrt_subgraph_pass'
))
def
run_test
(
self
):
self
.
build_network
()
self
.
check_output
()
def
test
(
self
):
precision_options
=
[
AnalysisConfig
.
Precision
.
Float32
,
AnalysisConfig
.
Precision
.
Half
]
serialize_options
=
[
False
,
True
]
dynamic_shape_profile
=
InferencePassTest
.
DynamicShapeParam
({
'data'
:
[
self
.
bs
,
self
.
channel
,
self
.
depth
//
2
,
self
.
height
//
2
,
self
.
width
//
2
]
},
{
'data'
:
[
self
.
bs
,
self
.
channel
,
self
.
depth
,
self
.
height
,
self
.
width
]
},
{
'data'
:
[
self
.
bs
,
self
.
channel
,
self
.
depth
,
self
.
height
,
self
.
width
]
},
False
)
dynamic_shape_options
=
[
None
,
dynamic_shape_profile
]
for
precision
,
serialize
,
dynamic_shape
in
itertools
.
product
(
precision_options
,
serialize_options
,
dynamic_shape_options
):
is_dynamic
=
True
if
dynamic_shape_options
is
not
None
else
False
with
self
.
subTest
(
'Precision: {}, Serialize: {}, Dynamic: {}'
.
format
(
precision
,
serialize
,
is_dynamic
)):
self
.
precision
=
precision
self
.
serialize
=
serialize
self
.
dynamic_shape_params
=
dynamic_shape
self
.
run_test
()
class
TensorRTAdaptiveMaxPool3DTest
(
InferencePassTest
):
def
setUp
(
self
):
self
.
bs
=
1
self
.
channel
=
3
self
.
depth
=
8
self
.
height
=
8
self
.
width
=
8
self
.
enable_trt
=
True
self
.
serialize
=
False
self
.
precision
=
AnalysisConfig
.
Precision
.
Float32
self
.
feeds
=
{
'data'
:
np
.
random
.
random
(
[
self
.
bs
,
self
.
channel
,
self
.
depth
,
self
.
height
,
self
.
width
]).
astype
(
'float32'
),
}
def
build_network
(
self
):
self
.
trt_parameters
=
TensorRTPool3dTest
.
TensorRTParam
(
1
<<
30
,
self
.
bs
,
0
,
self
.
precision
,
self
.
serialize
,
False
)
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
'data'
,
shape
=
[
-
1
,
self
.
channel
,
self
.
depth
,
self
.
height
,
self
.
width
],
dtype
=
'float32'
)
pool_out
=
paddle
.
nn
.
functional
.
adaptive_max_pool3d
(
x
=
data
,
output_size
=
[
3
,
3
,
3
])
#out = fluid.layers.batch_norm(pool_out, is_test=True)
self
.
fetch_list
=
[
pool_out
]
def
check_output
(
self
):
if
os
.
path
.
exists
(
self
.
path
+
"_opt_cache"
):
shutil
.
rmtree
(
self
.
path
+
"_opt_cache"
)
if
core
.
is_compiled_with_cuda
():
use_gpu
=
True
self
.
check_output_with_option
(
use_gpu
)
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
'tensorrt_subgraph_pass'
))
def
run_test
(
self
):
self
.
build_network
()
self
.
check_output
()
def
test
(
self
):
precision_options
=
[
AnalysisConfig
.
Precision
.
Float32
,
AnalysisConfig
.
Precision
.
Half
]
serialize_options
=
[
False
,
True
]
dynamic_shape_profile
=
InferencePassTest
.
DynamicShapeParam
({
'data'
:
[
self
.
bs
,
self
.
channel
,
self
.
depth
//
2
,
self
.
height
//
2
,
self
.
width
//
2
]
},
{
'data'
:
[
self
.
bs
,
self
.
channel
,
self
.
depth
,
self
.
height
,
self
.
width
]
},
{
'data'
:
[
self
.
bs
,
self
.
channel
,
self
.
depth
,
self
.
height
,
self
.
width
]
},
False
)
dynamic_shape_options
=
[
None
,
dynamic_shape_profile
]
for
precision
,
serialize
,
dynamic_shape
in
itertools
.
product
(
precision_options
,
serialize_options
,
dynamic_shape_options
):
is_dynamic
=
True
if
dynamic_shape_options
is
not
None
else
False
with
self
.
subTest
(
'Precision: {}, Serialize: {}, Dynamic: {}'
.
format
(
precision
,
serialize
,
is_dynamic
)):
self
.
precision
=
precision
self
.
serialize
=
serialize
self
.
dynamic_shape_params
=
dynamic_shape
self
.
run_test
()
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录