Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9bcd9f66
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9bcd9f66
编写于
5月 02, 2018
作者:
C
chengduo
提交者:
Abhinav Arora
5月 02, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix cpplint error (#10329)
上级
55f0d840
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
154 addition
and
120 deletion
+154
-120
paddle/fluid/operators/math/pooling.cc
paddle/fluid/operators/math/pooling.cc
+55
-52
paddle/fluid/operators/math/pooling.cu
paddle/fluid/operators/math/pooling.cu
+50
-34
paddle/fluid/operators/math/pooling.h
paddle/fluid/operators/math/pooling.h
+49
-34
未找到文件。
paddle/fluid/operators/math/pooling.cc
浏览文件 @
9bcd9f66
...
...
@@ -11,8 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/pooling.h"
#include <algorithm>
#include <vector>
namespace
paddle
{
namespace
operators
{
...
...
@@ -27,9 +28,10 @@ template <typename PoolProcess, typename T>
class
Pool2dFunctor
<
platform
::
CPUDeviceContext
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -63,11 +65,11 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
T
ele
=
pool_process
.
initial
();
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
pool_process
.
compute
(
ele
,
input_data
[
h
*
input_width
+
w
]
);
pool_process
.
compute
(
input_data
[
h
*
input_width
+
w
],
&
ele
);
}
}
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
pool_process
.
finalize
(
ele
,
(
static_cast
<
T
>
(
pool_size
))
);
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
ph
*
output_width
+
pw
]
=
ele
;
}
}
...
...
@@ -86,13 +88,12 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
template
<
typename
PoolProcess
,
class
T
>
class
Pool2dGradFunctor
<
platform
::
CPUDeviceContext
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
,
framework
::
Tensor
*
input_grad
)
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -131,8 +132,8 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
input_data
[
h
*
input_width
+
w
],
output_data
[
ph
*
output_width
+
pw
],
output_grad_data
[
ph
*
output_width
+
pw
],
input_grad_data
[
h
*
input_width
+
w
]
,
static_cast
<
T
>
(
scale
)
);
static_cast
<
T
>
(
scale
)
,
input_grad_data
+
h
*
input_width
+
w
);
}
}
}
...
...
@@ -154,12 +155,11 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
template
<
class
T
>
class
MaxPool2dGradFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -246,9 +246,10 @@ template <typename PoolProcess, class T>
class
Pool3dFunctor
<
platform
::
CPUDeviceContext
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
...
...
@@ -293,14 +294,14 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
pool_process
.
compute
(
ele
,
input_data
[(
d
*
input_height
+
h
)
*
input_width
+
w
]
);
input_data
[(
d
*
input_height
+
h
)
*
input_width
+
w
]
,
&
ele
);
}
}
}
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
pool_process
.
finalize
(
ele
,
static_cast
<
T
>
(
pool_size
)
);
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
output_idx
]
=
ele
;
}
}
...
...
@@ -320,13 +321,12 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
template
<
typename
PoolProcess
,
class
T
>
class
Pool3dGradFunctor
<
platform
::
CPUDeviceContext
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
,
framework
::
Tensor
*
input_grad
)
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
...
...
@@ -379,8 +379,8 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
(
pd
*
output_height
+
ph
)
*
output_width
+
pw
;
pool_grad_process
.
compute
(
input_data
[
input_idx
],
output_data
[
output_idx
],
output_grad_data
[
output_idx
],
input_grad_data
[
input_idx
],
static_cast
<
T
>
(
scale
)
);
output_grad_data
[
output_idx
],
static_cast
<
T
>
(
scale
),
input_grad_data
+
input_idx
);
}
}
}
...
...
@@ -404,12 +404,11 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
template
<
class
T
>
class
MaxPool3dGradFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
...
...
@@ -510,9 +509,10 @@ template <typename T1, typename T2>
class
MaxPool2dWithIndexFunctor
<
platform
::
CPUDeviceContext
,
T1
,
T2
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -576,8 +576,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUDeviceContext, T1, T2> {
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
mask
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input_grad
->
dims
()[
0
];
const
int
input_height
=
input_grad
->
dims
()[
2
];
...
...
@@ -628,9 +629,10 @@ template <typename T1, typename T2>
class
MaxPool3dWithIndexFunctor
<
platform
::
CPUDeviceContext
,
T1
,
T2
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
...
...
@@ -708,8 +710,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUDeviceContext, T1, T2> {
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
mask
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input_grad
->
dims
()[
0
];
const
int
input_depth
=
input_grad
->
dims
()[
2
];
...
...
paddle/fluid/operators/math/pooling.cu
浏览文件 @
9bcd9f66
...
...
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/cuda_primitives.h"
...
...
@@ -47,11 +49,11 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
T
ele
=
pool_process
.
initial
();
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
pool_process
.
compute
(
ele
,
input_data
[
h
*
input_width
+
w
]
);
pool_process
.
compute
(
input_data
[
h
*
input_width
+
w
],
&
ele
);
}
}
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
pool_process
.
finalize
(
ele
,
(
static_cast
<
T
>
(
pool_size
))
);
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
index
]
=
ele
;
}
}
...
...
@@ -96,8 +98,8 @@ __global__ void KernelPool2DGrad(
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
output_sub_idx
=
ph
*
output_width
+
pw
;
pool_process
.
compute
(
input
,
output_data
[
output_sub_idx
],
output_grad
[
output_sub_idx
],
gradient
,
static_cast
<
T
>
(
1.0
/
pool_size
));
output_grad
[
output_sub_idx
],
static_cast
<
T
>
(
1.0
/
pool_size
)
,
&
gradient
);
}
}
input_grad
[
index
]
=
gradient
;
...
...
@@ -158,9 +160,10 @@ template <typename PoolProcess, typename T>
class
Pool2dFunctor
<
platform
::
CUDADeviceContext
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
...
...
@@ -201,9 +204,11 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
input_grad
)
{
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
...
...
@@ -246,8 +251,10 @@ class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
...
...
@@ -340,12 +347,12 @@ __global__ void KernelPool3D(const int nthreads, const T* input_data,
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
pool_process
.
compute
(
ele
,
input_data
[(
d
*
input_height
+
h
)
*
input_width
+
w
]
);
input_data
[(
d
*
input_height
+
h
)
*
input_width
+
w
],
&
ele
);
}
}
}
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
pool_process
.
finalize
(
ele
,
static_cast
<
T
>
(
pool_size
)
);
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
index
]
=
ele
;
}
}
...
...
@@ -405,8 +412,8 @@ __global__ void KernelPool3DGrad(
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
output_sub_idx
=
(
pd
*
output_height
+
ph
)
*
output_width
+
pw
;
pool_process
.
compute
(
input
,
output_data
[
output_sub_idx
],
output_grad
[
output_sub_idx
],
gradient
,
static_cast
<
T
>
(
1.0
/
pool_size
));
output_grad
[
output_sub_idx
],
static_cast
<
T
>
(
1.0
/
pool_size
)
,
&
gradient
);
}
}
}
...
...
@@ -474,9 +481,10 @@ template <typename PoolProcess, class T>
class
Pool3dFunctor
<
platform
::
CUDADeviceContext
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_depth
=
input
.
dims
()[
2
];
...
...
@@ -525,9 +533,11 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
input_grad
)
{
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_depth
=
input
.
dims
()[
2
];
...
...
@@ -578,8 +588,10 @@ class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
...
...
@@ -736,9 +748,10 @@ template <typename T1, typename T2>
class
MaxPool2dWithIndexFunctor
<
platform
::
CUDADeviceContext
,
T1
,
T2
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
...
...
@@ -779,8 +792,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
mask
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input_grad
->
dims
()[
0
];
const
int
input_channels
=
input_grad
->
dims
()[
1
];
...
...
@@ -937,9 +951,10 @@ template <typename T1, typename T2>
class
MaxPool3dWithIndexFunctor
<
platform
::
CUDADeviceContext
,
T1
,
T2
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_depth
=
input
.
dims
()[
2
];
...
...
@@ -987,8 +1002,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
mask
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input_grad
->
dims
()[
0
];
const
int
input_channels
=
input_grad
->
dims
()[
1
];
...
...
paddle/fluid/operators/math/pooling.h
浏览文件 @
9bcd9f66
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
...
...
@@ -23,8 +24,8 @@ namespace operators {
namespace
math
{
#define FLT_MAX \
__FLT_MAX__ //
It might need to be placed in another file, but I'm still
// wondering where to put it.
__FLT_MAX__ //
TODO(zcd) :It might need to be placed in another file, but I'm
//
still
wondering where to put it.
/*
* \brief Extracting simple operations from pooling.
...
...
@@ -40,33 +41,33 @@ template <class T>
class
MaxPool
{
public:
DEVICE
inline
T
initial
()
{
return
static_cast
<
T
>
(
-
FLT_MAX
);
}
DEVICE
inline
void
compute
(
T
&
y
,
const
T
&
x
)
{
y
=
y
>
x
?
y
:
x
;
}
DEVICE
inline
void
finalize
(
T
&
y
,
const
T
&
pool_field
)
{}
DEVICE
inline
void
compute
(
const
T
&
x
,
T
*
y
)
{
*
y
=
*
y
>
x
?
*
y
:
x
;
}
DEVICE
inline
void
finalize
(
const
T
&
pool_field
,
T
*
y
)
{}
};
template
<
class
T
>
class
AvgPool
{
public:
DEVICE
inline
T
initial
()
{
return
static_cast
<
T
>
(
0
);
}
DEVICE
inline
void
compute
(
T
&
y
,
const
T
&
x
)
{
y
+=
x
;
}
DEVICE
inline
void
finalize
(
T
&
y
,
const
T
&
pool_field
)
{
y
/=
pool_field
;
}
DEVICE
inline
void
compute
(
const
T
&
x
,
T
*
y
)
{
*
y
+=
x
;
}
DEVICE
inline
void
finalize
(
const
T
&
pool_field
,
T
*
y
)
{
*
y
/=
pool_field
;
}
};
template
<
class
T
>
class
MaxPoolGrad
{
public:
DEVICE
inline
void
compute
(
const
T
&
x
,
const
T
&
y
,
const
T
&
dy
,
T
&
dx
,
T
scale
)
{
dx
+=
dy
*
(
x
==
y
);
DEVICE
inline
void
compute
(
const
T
&
x
,
const
T
&
y
,
const
T
&
dy
,
T
scale
,
T
*
dx
)
{
*
dx
+=
dy
*
(
x
==
y
);
}
};
template
<
class
T
>
class
AvgPoolGrad
{
public:
DEVICE
inline
void
compute
(
const
T
&
x
,
const
T
&
y
,
const
T
&
dy
,
T
&
dx
,
T
scale
)
{
dx
+=
(
scale
*
dy
);
DEVICE
inline
void
compute
(
const
T
&
x
,
const
T
&
y
,
const
T
&
dy
,
T
scale
,
T
*
dx
)
{
*
dx
+=
(
scale
*
dy
);
}
};
...
...
@@ -88,8 +89,9 @@ template <typename DeviceContext, typename PoolProcess, typename T>
class
Pool2dFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
output
);
};
...
...
@@ -98,9 +100,11 @@ class Pool2dGradFunctor {
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
input_grad
);
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
input_grad
);
};
template
<
typename
DeviceContext
,
class
T
>
...
...
@@ -108,8 +112,10 @@ class MaxPool2dGradFunctor {
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
);
};
...
...
@@ -117,8 +123,9 @@ template <typename DeviceContext, typename PoolProcess, typename T>
class
Pool3dFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
output
);
};
...
...
@@ -127,9 +134,11 @@ class Pool3dGradFunctor {
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
input_grad
);
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
input_grad
);
};
template
<
typename
DeviceContext
,
class
T
>
...
...
@@ -137,8 +146,10 @@ class MaxPool3dGradFunctor {
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
);
};
...
...
@@ -153,8 +164,9 @@ template <typename DeviceContext, typename T1, typename T2>
class
MaxPool2dWithIndexFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
);
};
...
...
@@ -163,8 +175,9 @@ class MaxPool2dWithIndexGradFunctor {
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
mask
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
);
};
...
...
@@ -172,8 +185,9 @@ template <typename DeviceContext, typename T1, typename T2>
class
MaxPool3dWithIndexFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
);
};
...
...
@@ -182,8 +196,9 @@ class MaxPool3dWithIndexGradFunctor {
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
mask
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
);
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录