Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ad349e77
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ad349e77
编写于
11月 21, 2018
作者:
Z
Zhaolong Xing
提交者:
GitHub
11月 21, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14452 from NHZlX/fix_avg_pool_trt_bug
fix avg pool trt bug
上级
1d9b2a45
e62872df
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
324 addition
and
61 deletion
+324
-61
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
+89
-53
paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc
paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc
+9
-7
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-0
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu
+64
-0
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h
+111
-0
paddle/fluid/operators/math/pooling.cu
paddle/fluid/operators/math/pooling.cu
+36
-0
paddle/fluid/operators/math/pooling.h
paddle/fluid/operators/math/pooling.h
+13
-0
未找到文件。
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
ad349e77
...
...
@@ -18,7 +18,7 @@ nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
nv_test
(
test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc
DEPS
${
FLUID_CORE_MODULES
}
${
GLOB_OPERATOR_DEPS
}
tensorrt_engine conv_op conv_transpose_op SERIAL
)
nv_test
(
test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
DEPS
${
FLUID_CORE_MODULES
}
${
GLOB_OPERATOR_DEPS
}
tensorrt_engine pool_op SERIAL
)
DEPS
${
FLUID_CORE_MODULES
}
${
GLOB_OPERATOR_DEPS
}
tensorrt_engine pool_op
tensorrt_plugin
SERIAL
)
nv_test
(
test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
DEPS
${
FLUID_CORE_MODULES
}
${
GLOB_OPERATOR_DEPS
}
tensorrt_engine tensorrt_plugin
elementwise_add_op elementwise_mul_op SERIAL
)
...
...
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
浏览文件 @
ad349e77
...
...
@@ -13,25 +13,57 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
void
DealCeilMode
(
const
nvinfer1
::
Dims
&
input_shape
,
std
::
vector
<
int
>
ksize
,
std
::
vector
<
int
>
strides
,
std
::
vector
<
int
>
paddings
,
nvinfer1
::
DimsHW
*
pre_pad
,
nvinfer1
::
DimsHW
*
post_pad
,
int
input_dims
)
{
int
input_height
=
input_shape
.
d
[
input_dims
-
2
];
int
input_width
=
input_shape
.
d
[
input_dims
-
1
];
int
floor_h_output_size
=
(
input_height
-
ksize
[
0
]
+
2
*
paddings
[
0
])
/
strides
[
0
]
+
1
;
int
ceil_h_output_size
=
(
input_height
-
ksize
[
0
]
+
2
*
paddings
[
0
]
+
strides
[
0
]
-
1
)
/
strides
[
0
]
+
1
;
int
floor_w_output_size
=
(
input_width
-
ksize
[
1
]
+
2
*
paddings
[
1
])
/
strides
[
1
]
+
1
;
int
ceil_w_output_size
=
(
input_width
-
ksize
[
1
]
+
2
*
paddings
[
1
]
+
strides
[
1
]
-
1
)
/
strides
[
1
]
+
1
;
if
(
floor_h_output_size
!=
ceil_h_output_size
)
{
post_pad
->
h
()
=
strides
[
0
]
-
1
;
}
if
(
floor_w_output_size
!=
ceil_w_output_size
)
{
post_pad
->
w
()
=
strides
[
1
]
-
1
;
}
}
/*
* Pool2dOp, IPoolingLayer in TRT. This Layer doesn't has weights.
*/
class
Pool2dOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
40
)
<<
"convert a fluid pool2d op to tensorrt pool2d layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"X"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Output
(
"Out"
).
size
(),
1
);
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
nvinfer1
::
Dims
input_shape
=
input1
->
getDimensions
();
int
input_dims
=
input_shape
.
nbDims
;
PADDLE_ENFORCE_EQ
(
input_dims
,
3UL
);
bool
global_pooling
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"global_pooling"
));
std
::
string
pool_type
=
...
...
@@ -44,23 +76,6 @@ class Pool2dOpConverter : public OpConverter {
boost
::
get
<
std
::
vector
<
int
>>
(
op_desc
.
GetAttr
(
"paddings"
));
bool
ceil_mode
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"ceil_mode"
));
nvinfer1
::
Dims
input_shape
=
input1
->
getDimensions
();
int
nbDims
=
input_shape
.
nbDims
;
nvinfer1
::
DimsHW
nv_ksize
(
ksize
[
0
],
ksize
[
1
]);
nvinfer1
::
DimsHW
nv_strides
(
strides
[
0
],
strides
[
1
]);
nvinfer1
::
DimsHW
nv_paddings
(
paddings
[
0
],
paddings
[
1
]);
if
(
global_pooling
==
true
)
{
nv_ksize
.
d
[
0
]
=
input_shape
.
d
[
nbDims
-
2
];
nv_ksize
.
d
[
1
]
=
input_shape
.
d
[
nbDims
-
1
];
nv_strides
.
h
()
=
1
;
nv_strides
.
w
()
=
1
;
nv_paddings
.
h
()
=
0
;
nv_paddings
.
w
()
=
0
;
}
PADDLE_ENFORCE_EQ
(
input1
->
getDimensions
().
nbDims
,
3UL
);
nvinfer1
::
PoolingType
nv_pool_type
=
nvinfer1
::
PoolingType
::
kMAX
;
if
(
pool_type
==
"max"
)
{
nv_pool_type
=
nvinfer1
::
PoolingType
::
kMAX
;
...
...
@@ -70,42 +85,63 @@ class Pool2dOpConverter : public OpConverter {
PADDLE_THROW
(
"TensorRT unsupported pooling type!"
);
}
if
(
ceil_mode
)
{
nvinfer1
::
DimsHW
pre_pad
(
0
,
0
);
nvinfer1
::
DimsHW
post_pad
(
0
,
0
);
int
input_height
=
input_shape
.
d
[
nbDims
-
2
];
int
input_width
=
input_shape
.
d
[
nbDims
-
1
];
int
floor_h_output_size
=
(
input_height
-
ksize
[
0
]
+
2
*
paddings
[
0
])
/
strides
[
0
]
+
1
;
int
ceil_h_output_size
=
(
input_height
-
ksize
[
0
]
+
2
*
paddings
[
0
]
+
strides
[
0
]
-
1
)
/
strides
[
0
]
+
1
;
int
floor_w_output_size
=
(
input_width
-
ksize
[
1
]
+
2
*
paddings
[
1
])
/
strides
[
1
]
+
1
;
int
ceil_w_output_size
=
(
input_width
-
ksize
[
1
]
+
2
*
paddings
[
1
]
+
strides
[
1
]
-
1
)
/
strides
[
1
]
+
1
;
if
(
floor_h_output_size
!=
ceil_h_output_size
)
{
post_pad
.
h
()
=
strides
[
0
]
-
1
;
nvinfer1
::
DimsHW
nv_ksize
(
ksize
[
0
],
ksize
[
1
]);
nvinfer1
::
DimsHW
nv_strides
(
strides
[
0
],
strides
[
1
]);
nvinfer1
::
DimsHW
nv_paddings
(
paddings
[
0
],
paddings
[
1
]);
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
global_pooling
==
true
)
{
nv_ksize
.
d
[
0
]
=
input_shape
.
d
[
input_dims
-
2
];
nv_ksize
.
d
[
1
]
=
input_shape
.
d
[
input_dims
-
1
];
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Pooling
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
nv_pool_type
,
nv_ksize
);
PADDLE_ENFORCE_NOT_NULL
(
layer
,
"pool layer could not be created."
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
layer
->
setName
((
"pool2d (Output: "
+
output_name
+
")"
).
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
engine_
->
DeclareOutput
(
output_name
);
}
return
;
}
if
(
floor_w_output_size
!=
ceil_w_output_size
)
{
post_pad
.
w
()
=
strides
[
1
]
-
1
;
if
(
pool_type
==
"max"
)
{
nvinfer1
::
DimsHW
pre_pad
(
paddings
[
0
],
paddings
[
1
]);
nvinfer1
::
DimsHW
post_pad
(
paddings
[
0
],
paddings
[
1
]);
if
(
ceil_mode
)
{
// If ceil mode is true, we will pad the appropriate size to the input.
DealCeilMode
(
input_shape
,
ksize
,
strides
,
paddings
,
&
pre_pad
,
&
post_pad
,
input_dims
);
auto
*
pad_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Padding
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
pre_pad
,
post_pad
);
PADDLE_ENFORCE_NOT_NULL
(
pad_layer
,
"pad layer in poolOp converter could not be created."
);
input1
=
pad_layer
->
getOutput
(
0
);
}
auto
*
pool_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Pooling
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
nv_pool_type
,
nv_ksize
);
PADDLE_ENFORCE_NOT_NULL
(
pool_layer
,
"pool layer could not be created."
);
pool_layer
->
setStride
(
nv_strides
);
pool_layer
->
setPadding
(
nv_paddings
);
layer
=
pool_layer
;
}
else
{
// Average pooling needs to exclude the padding pixels from the average
// mean.
// It is not supported well by TRT, we use a plugin here.
std
::
vector
<
int
>
input_shape_v
;
for
(
int
i
=
0
;
i
<
input_dims
;
i
++
)
{
input_shape_v
.
push_back
(
input_shape
.
d
[
i
]);
}
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Padding
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
pre_pad
,
post_pad
);
input1
=
layer
->
getOutput
(
0
)
;
plugin
::
AvgPoolPlugin
*
plugin
=
new
plugin
::
AvgPoolPlugin
(
ceil_mode
,
ksize
,
strides
,
paddings
,
input_shape_v
);
auto
*
avg_pool_layer
=
engine_
->
AddPlugin
(
&
input1
,
1
,
plugin
);
layer
=
avg_pool_layer
;
}
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Pooling
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
nv_pool_type
,
nv_ksize
);
PADDLE_ENFORCE_NOT_NULL
(
layer
,
"pool layer could not be created."
);
layer
->
setStride
(
nv_strides
);
layer
->
setPadding
(
nv_paddings
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
layer
->
setName
((
"pool2d (Output: "
+
output_name
+
")"
).
c_str
());
...
...
paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc
浏览文件 @
ad349e77
...
...
@@ -20,20 +20,21 @@ namespace paddle {
namespace
inference
{
namespace
tensorrt
{
void
test_pool2d
(
bool
global_pooling
,
bool
ceil_mode
)
{
void
test_pool2d
(
bool
global_pooling
,
bool
ceil_mode
,
std
::
string
pool_type
=
"max"
)
{
framework
::
Scope
scope
;
std
::
unordered_set
<
std
::
string
>
parameters
;
TRTConvertValidation
validator
(
5
,
parameters
,
scope
,
1
<<
15
);
// The ITensor's Dims should not contain the batch size.
// So, the ITensor's Dims of input and output should be C * H * W.
validator
.
DeclInputVar
(
"pool2d-X"
,
nvinfer1
::
Dims3
(
3
,
13
,
14
));
validator
.
DeclInputVar
(
"pool2d-X"
,
nvinfer1
::
Dims3
(
3
,
6
,
7
));
if
(
global_pooling
)
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
1
,
1
));
else
if
(
ceil_mode
)
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
6
,
7
));
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
3
,
4
));
else
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
6
,
6
));
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
3
,
3
));
// Prepare Op description
framework
::
OpDesc
desc
;
...
...
@@ -41,10 +42,10 @@ void test_pool2d(bool global_pooling, bool ceil_mode) {
desc
.
SetInput
(
"X"
,
{
"pool2d-X"
});
desc
.
SetOutput
(
"Out"
,
{
"pool2d-Out"
});
std
::
vector
<
int
>
ksize
({
3
,
3
});
std
::
vector
<
int
>
ksize
({
2
,
2
});
std
::
vector
<
int
>
strides
({
2
,
2
});
std
::
vector
<
int
>
paddings
({
0
,
0
});
std
::
string
pooling_t
=
"max"
;
std
::
string
pooling_t
=
pool_type
;
desc
.
SetAttr
(
"pooling_type"
,
pooling_t
);
desc
.
SetAttr
(
"ksize"
,
ksize
);
...
...
@@ -63,7 +64,8 @@ void test_pool2d(bool global_pooling, bool ceil_mode) {
TEST
(
Pool2dOpConverter
,
normal
)
{
test_pool2d
(
false
,
false
);
}
TEST
(
Pool2dOpConverter
,
test_global_pooling
)
{
test_pool2d
(
true
,
false
);
}
TEST
(
Pool2dOpConverter
,
test_ceil_mode
)
{
test_pool2d
(
false
,
true
);
}
TEST
(
Pool2dOpConverter
,
max_ceil_test
)
{
test_pool2d
(
false
,
true
);
}
TEST
(
Pool2dOpConverter
,
avg_ceil_test
)
{
test_pool2d
(
false
,
true
,
"avg"
);
}
}
// namespace tensorrt
}
// namespace inference
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
ad349e77
nv_library
(
tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu
avg_pool_op_plugin.cu
DEPS enforce tensorrt_engine
)
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu
0 → 100644
浏览文件 @
ad349e77
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h"
#include "paddle/fluid/operators/math/pooling.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
nvinfer1
::
Dims
AvgPoolPlugin
::
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputDims
,
int
nbInputs
)
{
assert
(
nbInputs
==
1
);
assert
(
index
==
0
);
assert
(
inputDims
[
0
].
nbDims
==
3
);
nvinfer1
::
Dims
const
&
input_dims
=
inputDims
[
0
];
nvinfer1
::
Dims
output_dims
=
input_dims
;
output_dims
.
d
[
1
]
=
output_shape_
[
1
];
output_dims
.
d
[
2
]
=
output_shape_
[
2
];
return
output_dims
;
}
int
AvgPoolPlugin
::
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
const
&
input_dims
=
this
->
getInputDims
(
0
);
int
input_size
=
0
;
float
const
*
idata
=
reinterpret_cast
<
float
const
*>
(
inputs
[
0
]);
float
**
odatas
=
reinterpret_cast
<
float
**>
(
outputs
);
paddle
::
operators
::
math
::
AvgPool
<
float
>
pool_process
;
paddle
::
operators
::
math
::
Pool2dDirectCUDAFunctor
<
paddle
::
operators
::
math
::
AvgPool
<
float
>
,
float
>
pool2d_forward
;
std
::
vector
<
int
>
input_shape
=
input_shape_
;
std
::
vector
<
int
>
output_shape
=
output_shape_
;
input_shape
.
insert
(
input_shape
.
begin
(),
batchSize
);
output_shape
.
insert
(
output_shape
.
begin
(),
batchSize
);
pool2d_forward
(
idata
,
input_shape
,
output_shape
,
ksize_
,
strides_
,
paddings_
,
pool_process
,
true
,
odatas
[
0
],
stream
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h
0 → 100644
浏览文件 @
ad349e77
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
class
AvgPoolPlugin
:
public
PluginTensorRT
{
private:
bool
ceil_mode_
;
std
::
vector
<
int
>
ksize_
;
std
::
vector
<
int
>
strides_
;
std
::
vector
<
int
>
paddings_
;
std
::
vector
<
int
>
input_shape_
;
std
::
vector
<
int
>
output_shape_
;
protected:
size_t
getSerializationSize
()
override
{
return
SerializedSize
(
ceil_mode_
)
+
SerializedSize
(
ksize_
)
+
SerializedSize
(
strides_
)
+
SerializedSize
(
paddings_
)
+
SerializedSize
(
input_shape_
)
+
getBaseSerializationSize
();
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void
serialize
(
void
*
buffer
)
override
{
serializeBase
(
buffer
);
SerializeValue
(
&
buffer
,
ceil_mode_
);
SerializeValue
(
&
buffer
,
ksize_
);
SerializeValue
(
&
buffer
,
strides_
);
SerializeValue
(
&
buffer
,
paddings_
);
SerializeValue
(
&
buffer
,
input_shape_
);
}
public:
AvgPoolPlugin
(
bool
ceil_mode
,
std
::
vector
<
int
>
ksize
,
std
::
vector
<
int
>
strides
,
std
::
vector
<
int
>
paddings
,
std
::
vector
<
int
>
input_shape
)
:
ceil_mode_
(
ceil_mode
),
ksize_
(
ksize
),
strides_
(
strides
),
paddings_
(
paddings
),
input_shape_
(
input_shape
)
{
int
output_h
,
output_w
;
output_shape_
=
input_shape_
;
if
(
!
ceil_mode_
)
{
output_h
=
(
input_shape
[
1
]
-
ksize_
[
0
]
+
2
*
paddings_
[
0
])
/
strides_
[
0
]
+
1
;
output_w
=
(
input_shape
[
2
]
-
ksize_
[
1
]
+
2
*
paddings_
[
1
])
/
strides_
[
1
]
+
1
;
}
else
{
output_h
=
(
input_shape
[
1
]
-
ksize_
[
0
]
+
2
*
paddings_
[
0
]
+
strides_
[
0
]
-
1
)
/
strides_
[
0
]
+
1
;
output_w
=
(
input_shape
[
2
]
-
ksize_
[
1
]
+
2
*
paddings_
[
1
]
+
strides_
[
1
]
-
1
)
/
strides_
[
1
]
+
1
;
}
output_shape_
[
1
]
=
output_h
;
output_shape_
[
2
]
=
output_w
;
}
// It was used for tensorrt deserialization.
// It should not be called by users.
AvgPoolPlugin
(
void
const
*
serialData
,
size_t
serialLength
)
{
deserializeBase
(
serialData
,
serialLength
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
ceil_mode_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
ksize_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
strides_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
paddings_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
input_shape_
);
}
AvgPoolPlugin
*
clone
()
const
override
{
return
new
AvgPoolPlugin
(
ceil_mode_
,
ksize_
,
strides_
,
paddings_
,
input_shape_
);
}
const
char
*
getPluginType
()
const
override
{
return
"avg_pool"
;
}
int
getNbOutputs
()
const
override
{
return
1
;
}
nvinfer1
::
Dims
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputs
,
int
nbInputDims
)
override
;
int
initialize
()
override
{
return
0
;
}
int
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
};
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/operators/math/pooling.cu
浏览文件 @
ad349e77
...
...
@@ -153,6 +153,37 @@ __global__ void KernelMaxPool2DGrad(
}
}
template
<
typename
PoolProcess
,
typename
T
>
void
Pool2dDirectCUDAFunctor
<
PoolProcess
,
T
>::
operator
()(
const
T
*
input
,
const
std
::
vector
<
int
>&
input_shape
,
const
std
::
vector
<
int
>&
output_shape
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
bool
exclusive
,
T
*
output
,
cudaStream_t
stream
)
{
const
int
batch_size
=
input_shape
[
0
];
const
int
input_channels
=
input_shape
[
1
];
const
int
input_height
=
input_shape
[
2
];
const
int
input_width
=
input_shape
[
3
];
const
int
output_channels
=
output_shape
[
1
];
const
int
output_height
=
output_shape
[
2
];
const
int
output_width
=
output_shape
[
3
];
const
int
ksize_height
=
ksize
[
0
];
const
int
ksize_width
=
ksize
[
1
];
const
int
stride_height
=
strides
[
0
];
const
int
stride_width
=
strides
[
1
];
const
int
padding_height
=
paddings
[
0
];
const
int
padding_width
=
paddings
[
1
];
int
nthreads
=
batch_size
*
output_channels
*
output_height
*
output_width
;
int
blocks
=
(
nthreads
+
1024
-
1
)
/
1024
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blocks
,
1
);
KernelPool2D
<
PoolProcess
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
nthreads
,
input
,
input_channels
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
pool_compute
,
exclusive
,
output
);
}
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
...
...
@@ -291,6 +322,11 @@ class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
}
};
template
class
Pool2dDirectCUDAFunctor
<
paddle
::
operators
::
math
::
MaxPool
<
float
>,
float
>
;
template
class
Pool2dDirectCUDAFunctor
<
paddle
::
operators
::
math
::
AvgPool
<
float
>,
float
>
;
template
class
MaxPool2dGradFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
MaxPool2dGradFunctor
<
platform
::
CUDADeviceContext
,
double
>;
...
...
paddle/fluid/operators/math/pooling.h
浏览文件 @
ad349e77
...
...
@@ -82,6 +82,19 @@ class AvgPoolGrad {
* This is different from average pooling. So we rewrite the max_pool_grad:
* MaxPool2dGradFunctor, MaxPool3dGradFunctor.
*/
#ifdef PADDLE_WITH_CUDA
template
<
typename
PoolProcess
,
typename
T
>
class
Pool2dDirectCUDAFunctor
{
public:
void
operator
()(
const
T
*
input
,
const
std
::
vector
<
int
>&
input_shape
,
const
std
::
vector
<
int
>&
output_shape
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
bool
exclusive
,
T
*
output
,
cudaStream_t
stream
);
};
#endif
template
<
typename
DeviceContext
,
typename
PoolProcess
,
typename
T
>
class
Pool2dFunctor
{
public:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录