Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2377d719
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2377d719
编写于
8月 21, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add3DPooling
上级
0f3a3e98
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
1998 addition
and
14 deletion
+1998
-14
paddle/cuda/include/hl_cnn.h
paddle/cuda/include/hl_cnn.h
+197
-1
paddle/cuda/include/stub/hl_cnn_stub.h
paddle/cuda/include/stub/hl_cnn_stub.h
+90
-0
paddle/cuda/src/hl_cuda_cnn.cu
paddle/cuda/src/hl_cuda_cnn.cu
+424
-3
paddle/gserver/layers/Pool3DLayer.cpp
paddle/gserver/layers/Pool3DLayer.cpp
+198
-0
paddle/gserver/layers/Pool3DLayer.h
paddle/gserver/layers/Pool3DLayer.h
+48
-0
paddle/gserver/tests/test_LayerGrad.cpp
paddle/gserver/tests/test_LayerGrad.cpp
+69
-0
paddle/math/Matrix.cpp
paddle/math/Matrix.cpp
+502
-0
paddle/math/Matrix.h
paddle/math/Matrix.h
+247
-7
paddle/math/tests/test_matrixCompare.cpp
paddle/math/tests/test_matrixCompare.cpp
+204
-0
paddle/parameter/Argument.cpp
paddle/parameter/Argument.cpp
+2
-0
paddle/parameter/Argument.h
paddle/parameter/Argument.h
+5
-3
proto/ModelConfig.proto
proto/ModelConfig.proto
+12
-0
未找到文件。
paddle/cuda/include/hl_cnn.h
浏览文件 @
2377d719
...
...
@@ -173,6 +173,202 @@ extern void hl_avgpool_backward(const int frameCnt,
real
*
backGrad
,
const
int
outStride
);
/**
* @brief Maximum pool forward.
*
* @param[in] frameCnt batch size of input image.
* @param[in] inputData input data.
* @param[in] channels number of channel.
* @param[in] depth image depth.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] pooledD output image depth.
* @param[in] pooledH output image height.
* @param[in] pooledW output image width.
* @param[in] sizeZ depth of pooling window.
* @param[in] sizeY height of pooling window.
* @param[in] sizeX width of pooling window.
* @param[in] strideD pooling stride depth.
* @param[in] strideH pooling stride height.
* @param[in] strideW pooling stride width.
* @param[in] paddingD padding depth.
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples.
*
*/
extern
void
hl_maxpool3D_forward
(
const
int
frameCnt
,
const
real
*
inputData
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
pooledD
,
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingD
,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
);
/**
* @brief Maximum pool backward.
*
* @param[in] frameCnt batch size of input image.
* @param[in] inputData input data.
* @param[out] outData output data.
* @param[out] outGrad output grad data.
* @param[in] channels number of channel.
* @param[in] depth image depth.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] pooledD output image depth.
* @param[in] pooledH output image height.
* @param[in] pooledW output image width.
* @param[in] sizeZ depth of pooling window.
* @param[in] sizeY height of pooling window.
* @param[in] sizeX width of pooling window.
* @param[in] strideD pooling stride depth.
* @param[in] strideH pooling stride height.
* @param[in] strideW pooling stride width.
* @param[in] scaleA scale.
* @param[in] scaleB scale.
* @param[in] paddingD padding depth.
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[out] targetGrad output grad.
* @param[in] outStride stride between output data samples.
*
*/
extern
void
hl_maxpool3D_backward
(
const
int
frameCnt
,
const
real
*
inputData
,
const
real
*
outData
,
const
real
*
outGrad
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
pooledD
,
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingD
,
const
int
paddingH
,
const
int
paddingW
,
real
scaleA
,
real
scaleB
,
real
*
targetGrad
,
const
int
outStride
);
/**
* @brief Averge pool forward.
*
* @param[in] frameCnt batch size of input image.
* @param[in] inputData input data.
* @param[in] channels number of channel.
* @param[in] depth image depth.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] pooledD output image depth.
* @param[in] pooledH output image height.
* @param[in] pooledW output image width.
* @param[in] sizeZ depth of pooling window.
* @param[in] sizeY height of pooling window.
* @param[in] sizeX width of pooling window.
* @param[in] strideD pooling stride depth.
* @param[in] strideH pooling stride height.
* @param[in] strideW pooling stride width.
* @param[in] paddingD padding depth.
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples.
*
*/
extern
void
hl_avgpool3D_forward
(
const
int
frameCnt
,
const
real
*
inputData
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
pooledD
,
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingD
,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
);
/**
* @brief Maximum pool backward.
*
* @param[in] frameCnt batch size of input image.
* @param[in] outGrad output grad data.
* @param[in] channels number of channel.
* @param[in] depth image depth.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] pooledD output image depth.
* @param[in] pooledH output image height.
* @param[in] pooledW output image width.
* @param[in] sizeZ depth of pooling window.
* @param[in] sizeY height of pooling window.
* @param[in] sizeX width of pooling window.
* @param[in] strideD pooling stride depth.
* @param[in] strideH pooling stride height.
* @param[in] strideW pooling stride width.
* @param[in] paddingD padding depth.
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[in] scaleA scale.
* @param[in] scaleB scale.
* @param[out] backGrad output grad.
* @param[in] outStride stride between output data samples.
*
*/
extern
void
hl_avgpool3D_backward
(
const
int
frameCnt
,
const
real
*
outGrad
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
pooledD
,
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
scaleA
,
real
scaleB
,
real
*
backGrad
,
const
int
outStride
);
/**
* @brief Bilinear interpolation forward.
*
...
...
@@ -275,4 +471,4 @@ extern void hl_maxout_backward(real* inGrad,
size_t
featLen
,
size_t
groups
);
#endif
/* HL_CNN_H_ */
#endif
// HL_CNN_H_
paddle/cuda/include/stub/hl_cnn_stub.h
浏览文件 @
2377d719
...
...
@@ -87,6 +87,96 @@ inline void hl_avgpool_backward(const int frameCnt,
real
*
backGrad
,
const
int
outStride
)
{}
inline
void
hl_maxpool3D_forward
(
const
int
frameCnt
,
const
real
*
inputData
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
pooledD
,
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingD
,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
)
{}
inline
void
hl_maxpool3D_backward
(
const
int
frameCnt
,
const
real
*
inputData
,
const
real
*
outData
,
const
real
*
outGrad
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
pooledD
,
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingD
,
const
int
paddingH
,
const
int
paddingW
,
real
scaleA
,
real
scaleB
,
real
*
targetGrad
,
const
int
outStride
)
{}
inline
void
hl_avgpool3D_forward
(
const
int
frameCnt
,
const
real
*
inputData
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
pooledD
,
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingD
,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
)
{}
inline
void
hl_avgpool3D_backward
(
const
int
frameCnt
,
const
real
*
outGrad
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
pooledD
,
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
scaleA
,
real
scaleB
,
real
*
backGrad
,
const
int
outStride
)
{}
inline
void
hl_bilinear_forward
(
const
real
*
inData
,
const
size_t
inImgH
,
const
size_t
inImgW
,
...
...
paddle/cuda/src/hl_cuda_cnn.cu
浏览文件 @
2377d719
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -353,6 +350,430 @@ void hl_avgpool_backward(const int frameCnt,
CHECK_SYNC
(
"hl_avgpool_backward failed"
);
}
/////////////////
__global__
void
KeMaxPool3DForward
(
const
int
nthreads
,
const
real
*
inputData
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
pooledD
,
const
int
pooledH
,
const
int
pooledW
,
const
int
ksizeD
,
const
int
ksizeH
,
const
int
ksizeW
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
offsetD
,
const
int
offsetH
,
const
int
offsetW
,
real
*
tgtData
,
const
int
tgtStride
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
(
nthreads
);
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
pooledW
;
int
ph
=
(
index
/
pooledW
)
%
pooledH
;
int
pd
=
(
index
/
pooledW
/
pooledH
)
%
pooledD
;
int
c
=
(
index
/
pooledW
/
pooledH
/
pooledD
)
%
channels
;
int
frameNum
=
index
/
pooledW
/
pooledH
/
pooledD
/
channels
;
int
dstart
=
pd
*
strideD
-
offsetD
;
int
hstart
=
ph
*
strideH
-
offsetH
;
int
wstart
=
pw
*
strideW
-
offsetW
;
int
dend
=
min
(
dstart
+
ksizeD
,
depth
);
int
hend
=
min
(
hstart
+
ksizeH
,
height
);
int
wend
=
min
(
wstart
+
ksizeW
,
width
);
dstart
=
max
(
dstart
,
0
);
hstart
=
max
(
hstart
,
0
);
wstart
=
max
(
wstart
,
0
);
real
maxval
=
-
FLT_MAX
;
inputData
+=
(
frameNum
*
channels
+
c
)
*
depth
*
height
*
width
;
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
if
(
maxval
<
inputData
[(
d
*
height
+
h
)
*
width
+
w
])
maxval
=
inputData
[(
d
*
height
+
h
)
*
width
+
w
];
}
}
}
int
tgtIndex
=
index
%
(
pooledW
*
pooledH
*
pooledD
*
channels
)
+
frameNum
*
tgtStride
;
tgtData
[
tgtIndex
]
=
maxval
;
}
}
void
hl_maxpool3D_forward
(
const
int
frameCnt
,
const
real
*
inputData
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
pooledD
,
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingD
,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
)
{
int
num_kernels
=
pooledD
*
pooledH
*
pooledW
*
channels
*
frameCnt
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blocks
,
1
);
KeMaxPool3DForward
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
num_kernels
,
inputData
,
channels
,
depth
,
height
,
width
,
pooledD
,
pooledH
,
pooledW
,
sizeZ
,
sizeY
,
sizeX
,
strideD
,
strideH
,
strideW
,
paddingD
,
paddingH
,
paddingW
,
tgtData
,
tgtStride
);
CHECK_SYNC
(
"hl_maxpool3D_forward failed"
);
}
__global__
void
KeMaxPool3DBackward
(
const
int
nthreads
,
const
real
*
inputData
,
const
real
*
outData
,
const
real
*
outGrad
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
pooledD
,
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
padD
,
const
int
padH
,
const
int
padW
,
real
scaleA
,
real
scaleB
,
real
*
targetGrad
,
const
int
outStride
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
(
nthreads
);
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
// find out the local index
// find out the local offset
int
offsetW
=
index
%
width
+
padW
;
int
offsetH
=
(
index
/
width
)
%
height
+
padH
;
int
offsetD
=
(
index
/
width
/
height
)
%
depth
+
padD
;
int
offsetC
=
(
index
/
width
/
height
/
depth
)
%
channels
;
int
frameNum
=
index
/
width
/
height
/
depth
/
channels
;
int
pdstart
=
(
offsetD
<
sizeZ
)
?
0
:
(
offsetD
-
sizeZ
)
/
strideD
+
1
;
int
phstart
=
(
offsetH
<
sizeY
)
?
0
:
(
offsetH
-
sizeY
)
/
strideH
+
1
;
int
pwstart
=
(
offsetW
<
sizeX
)
?
0
:
(
offsetW
-
sizeX
)
/
strideW
+
1
;
int
pdend
=
min
(
offsetD
/
strideD
+
1
,
pooledD
);
int
phend
=
min
(
offsetH
/
strideH
+
1
,
pooledH
);
int
pwend
=
min
(
offsetW
/
strideW
+
1
,
pooledW
);
real
gradient
=
0
;
real
input
=
inputData
[
index
];
outData
+=
((
frameNum
*
channels
+
offsetC
)
*
pooledD
*
pooledH
*
pooledW
);
outGrad
+=
((
frameNum
*
channels
+
offsetC
)
*
pooledD
*
pooledH
*
pooledW
);
for
(
int
pd
=
pdstart
;
pd
<
pdend
;
++
pd
)
{
for
(
int
ph
=
phstart
;
ph
<
phend
;
++
ph
)
{
for
(
int
pw
=
pwstart
;
pw
<
pwend
;
++
pw
)
{
if
(
input
==
outData
[(
pd
*
pooledH
+
ph
)
*
pooledW
+
pw
])
gradient
+=
outGrad
[(
pd
*
pooledH
+
ph
)
*
pooledW
+
pw
];
}
}
}
targetGrad
[
index
]
=
scaleA
*
gradient
+
scaleB
*
targetGrad
[
index
];
}
}
void
hl_maxpool3D_backward
(
const
int
frameCnt
,
const
real
*
inputData
,
const
real
*
outData
,
const
real
*
outGrad
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
outputD
,
const
int
outputH
,
const
int
outputW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingD
,
const
int
paddingH
,
const
int
paddingW
,
real
scaleA
,
real
scaleB
,
real
*
targetGrad
,
const
int
outStride
)
{
int
num_kernels
=
depth
*
height
*
width
*
channels
*
frameCnt
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
KeMaxPool3DBackward
<<<
blocks
,
1024
,
0
,
STREAM_DEFAULT
>>>
(
num_kernels
,
inputData
,
outData
,
outGrad
,
channels
,
depth
,
height
,
width
,
outputD
,
outputH
,
outputW
,
sizeZ
,
sizeY
,
sizeX
,
strideD
,
strideH
,
strideW
,
paddingD
,
paddingH
,
paddingW
,
scaleA
,
scaleB
,
targetGrad
,
outStride
);
CHECK_SYNC
(
"hl_maxpool3D_backward"
);
}
__global__
void
KeAvgPool3DForward
(
const
int
nthreads
,
const
real
*
inputData
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
pooledD
,
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
padD
,
const
int
padH
,
const
int
padW
,
real
*
tgtData
,
const
int
tgtStride
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
(
nthreads
);
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
pooledW
;
int
ph
=
(
index
/
pooledW
)
%
pooledH
;
int
pd
=
(
index
/
pooledW
/
pooledH
)
%
pooledD
;
int
c
=
(
index
/
pooledW
/
pooledH
/
pooledD
)
%
channels
;
int
frameNum
=
index
/
pooledW
/
pooledH
/
pooledD
/
channels
;
int
dstart
=
pd
*
strideD
-
padD
;
int
hstart
=
ph
*
strideH
-
padH
;
int
wstart
=
pw
*
strideW
-
padW
;
int
dend
=
min
(
dstart
+
sizeZ
,
depth
+
padD
);
int
hend
=
min
(
hstart
+
sizeY
,
height
+
padH
);
int
wend
=
min
(
wstart
+
sizeX
,
width
+
padW
);
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
dstart
=
max
(
dstart
,
0
);
hstart
=
max
(
hstart
,
0
);
wstart
=
max
(
wstart
,
0
);
dend
=
min
(
dend
,
depth
);
hend
=
min
(
hend
,
height
);
wend
=
min
(
wend
,
width
);
real
aveval
=
0
;
inputData
+=
(
frameNum
*
channels
+
c
)
*
depth
*
height
*
width
;
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
aveval
+=
inputData
[(
d
*
height
+
h
)
*
width
+
w
];
}
}
}
int
tgtIndex
=
index
%
(
pooledW
*
pooledH
*
pooledD
*
channels
)
+
frameNum
*
tgtStride
;
tgtData
[
tgtIndex
]
=
aveval
/
pool_size
;
}
}
void
hl_avgpool3D_forward
(
const
int
frameCnt
,
const
real
*
inputData
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
pooledD
,
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingD
,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
)
{
int
num_kernels
=
pooledD
*
pooledH
*
pooledW
*
channels
*
frameCnt
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
KeAvgPool3DForward
<<<
blocks
,
1024
,
0
,
STREAM_DEFAULT
>>>
(
num_kernels
,
inputData
,
channels
,
depth
,
height
,
width
,
pooledD
,
pooledH
,
pooledW
,
sizeZ
,
sizeY
,
sizeX
,
strideD
,
strideH
,
strideW
,
paddingD
,
paddingH
,
paddingW
,
tgtData
,
tgtStride
);
CHECK_SYNC
(
"hl_avgpool3D_forward failed"
);
}
__global__
void
KeAvgPool3DBackward
(
const
int
nthreads
,
const
real
*
outGrad
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
pooledD
,
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
padD
,
const
int
padH
,
const
int
padW
,
real
scaleA
,
real
scaleB
,
real
*
tgtGrad
,
const
int
outStride
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
(
nthreads
);
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
offsetW
=
index
%
width
+
padW
;
int
offsetH
=
(
index
/
width
)
%
height
+
padH
;
int
offsetD
=
(
index
/
width
/
height
)
%
depth
+
padD
;
int
offsetC
=
(
index
/
width
/
height
/
depth
)
%
channels
;
int
frameNum
=
index
/
width
/
height
/
depth
/
channels
;
int
pdstart
=
(
offsetD
<
sizeZ
)
?
0
:
(
offsetD
-
sizeZ
)
/
strideD
+
1
;
int
phstart
=
(
offsetH
<
sizeY
)
?
0
:
(
offsetH
-
sizeY
)
/
strideH
+
1
;
int
pwstart
=
(
offsetW
<
sizeX
)
?
0
:
(
offsetW
-
sizeX
)
/
strideW
+
1
;
int
pdend
=
min
(
offsetD
/
strideD
+
1
,
pooledD
);
int
phend
=
min
(
offsetH
/
strideH
+
1
,
pooledH
);
int
pwend
=
min
(
offsetW
/
strideW
+
1
,
pooledW
);
real
gradient
=
0
;
outGrad
+=
(
frameNum
*
channels
+
offsetC
)
*
pooledD
*
pooledH
*
pooledW
;
for
(
int
pd
=
pdstart
;
pd
<
pdend
;
++
pd
)
{
for
(
int
ph
=
phstart
;
ph
<
phend
;
++
ph
)
{
for
(
int
pw
=
pwstart
;
pw
<
pwend
;
++
pw
)
{
// figure out the pooling size
int
dstart
=
pd
*
strideD
-
padD
;
int
hstart
=
ph
*
strideH
-
padH
;
int
wstart
=
pw
*
strideW
-
padW
;
int
dend
=
min
(
dstart
+
sizeZ
,
depth
+
padD
);
int
hend
=
min
(
hstart
+
sizeY
,
height
+
padH
);
int
wend
=
min
(
wstart
+
sizeX
,
width
+
padW
);
int
poolsize
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
gradient
+=
outGrad
[(
pd
*
pooledH
+
ph
)
*
pooledW
+
pw
]
/
poolsize
;
}
}
}
tgtGrad
[
index
]
=
scaleA
*
gradient
+
scaleB
*
tgtGrad
[
index
];
}
}
void
hl_avgpool3D_backward
(
const
int
frameCnt
,
const
real
*
outGrad
,
const
int
channels
,
const
int
depth
,
const
int
height
,
const
int
width
,
const
int
outputD
,
const
int
outputH
,
const
int
outputW
,
const
int
sizeZ
,
const
int
sizeY
,
const
int
sizeX
,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
scaleA
,
real
scaleB
,
real
*
backGrad
,
const
int
outStride
)
{
int
num_kernels
=
depth
*
height
*
width
*
channels
*
frameCnt
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
KeAvgPool3DBackward
<<<
blocks
,
1024
,
0
,
STREAM_DEFAULT
>>>
(
num_kernels
,
outGrad
,
channels
,
depth
,
height
,
width
,
outputD
,
outputH
,
outputW
,
sizeZ
,
sizeY
,
sizeX
,
strideD
,
strideH
,
strideW
,
paddingD
,
paddingH
,
paddingW
,
scaleA
,
scaleB
,
backGrad
,
outStride
);
CHECK_SYNC
(
"hl_avgpool3D_backward failed"
);
}
/////////////////
__global__
void
KeBilinearInterpFw
(
const
real
*
in
,
const
size_t
inImgH
,
const
size_t
inImgW
,
...
...
paddle/gserver/layers/Pool3DLayer.cpp
0 → 100644
浏览文件 @
2377d719
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Pool3DLayer.h"
#include "PoolProjectionLayer.h"
#include "paddle/utils/Logging.h"
namespace
paddle
{
REGISTER_LAYER
(
pool3d
,
Pool3DLayer
);
bool
Pool3DLayer
::
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
{
Layer
::
init
(
layerMap
,
parameterMap
);
/* the size of inputs for pool-layer is 1 */
CHECK_EQ
(
config_
.
inputs_size
(),
1
);
const
PoolConfig
&
conf
=
config_
.
inputs
(
0
).
pool_conf
();
poolType_
=
conf
.
pool_type
();
channels_
=
conf
.
channels
();
sizeX_
=
conf
.
size_x
();
sizeY_
=
conf
.
size_y
();
sizeZ_
=
conf
.
size_z
();
strideW_
=
conf
.
stride
();
strideH_
=
conf
.
stride_y
();
strideD_
=
conf
.
stride_z
();
imgSizeW_
=
conf
.
img_size
();
imgSizeH_
=
conf
.
img_size_y
();
imgSizeD_
=
conf
.
img_size_z
();
paddingW_
=
conf
.
padding
();
paddingH_
=
conf
.
padding_y
();
paddingD_
=
conf
.
padding_z
();
outputW_
=
conf
.
output_x
();
outputH_
=
conf
.
output_y
();
outputD_
=
conf
.
output_z
();
return
true
;
}
size_t
Pool3DLayer
::
getSize
()
{
CHECK_EQ
(
inputLayers_
.
size
(),
1UL
);
size_t
layerSize
=
0
;
// imgSizeD_ = inputLayers_[0]->getOutput().getFrameDepth();
// imgSizeH_ = inputLayers_[0]->getOutput().getFrameHeight();
// imgSizeW_ = inputLayers_[0]->getOutput().getFrameWidth();
if
(
imgSizeH_
==
0
)
{
// imgSizeH_ = imgSizeY_;
}
if
(
imgSizeW_
==
0
)
{
// imgSizeW_ = imgSize_;
}
outputD_
=
outputSize
(
imgSizeD_
,
sizeZ_
,
paddingD_
,
strideD_
,
/* caffeMode */
false
);
outputH_
=
outputSize
(
imgSizeH_
,
sizeY_
,
paddingH_
,
strideH_
,
/* caffeMode */
false
);
outputW_
=
outputSize
(
imgSizeW_
,
sizeX_
,
paddingW_
,
strideW_
,
/* caffeMode */
false
);
layerSize
=
outputD_
*
outputH_
*
outputW_
*
channels_
;
getOutput
().
setFrameHeight
(
outputH_
);
getOutput
().
setFrameWidth
(
outputW_
);
getOutput
().
setFrameDepth
(
outputD_
);
return
layerSize
;
}
void
Pool3DLayer
::
forward
(
PassType
passType
)
{
Layer
::
forward
(
passType
);
const
MatrixPtr
&
inMat
=
inputLayers_
[
0
]
->
getOutputValue
();
int
batchSize
=
inMat
->
getHeight
();
int
outWidth
=
getSize
();
resetOutput
(
batchSize
,
outWidth
);
const
MatrixPtr
outMat
=
getOutputValue
();
if
(
poolType_
==
"avg"
)
{
outMat
->
avgPool3DForward
(
*
inMat
,
imgSizeD_
,
imgSizeH_
,
imgSizeW_
,
channels_
,
sizeZ_
,
sizeY_
,
sizeX_
,
strideD_
,
strideH_
,
strideW_
,
outputD_
,
outputH_
,
outputW_
,
paddingD_
,
paddingH_
,
paddingW_
);
}
else
if
(
poolType_
==
"max"
)
{
outMat
->
maxPool3DForward
(
*
inMat
,
imgSizeD_
,
imgSizeH_
,
imgSizeW_
,
channels_
,
sizeZ_
,
sizeY_
,
sizeX_
,
strideD_
,
strideH_
,
strideW_
,
outputD_
,
outputH_
,
outputW_
,
paddingD_
,
paddingH_
,
paddingW_
);
}
else
{
LOG
(
FATAL
)
<<
"Unknown pool type: "
<<
poolType_
;
}
forwardActivation
();
}
void
Pool3DLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
backwardActivation
();
(
void
)
callback
;
if
(
NULL
==
getInputGrad
(
0
))
return
;
MatrixPtr
inMat
=
inputLayers_
[
0
]
->
getOutputValue
();
MatrixPtr
inGradMat
=
inputLayers_
[
0
]
->
getOutputGrad
();
MatrixPtr
outMat
=
getOutputValue
();
MatrixPtr
outGradMat
=
getOutputGrad
();
if
(
poolType_
==
"avg"
)
{
inGradMat
->
avgPool3DBackward
(
*
outGradMat
,
imgSizeD_
,
imgSizeH_
,
imgSizeW_
,
sizeZ_
,
sizeY_
,
sizeZ_
,
strideD_
,
strideH_
,
strideW_
,
outputD_
,
outputH_
,
outputW_
,
1
,
1
,
paddingD_
,
paddingH_
,
paddingW_
);
}
else
if
(
poolType_
==
"max"
)
{
inGradMat
->
maxPool3DBackward
(
*
inMat
,
imgSizeD_
,
imgSizeH_
,
imgSizeW_
,
*
outGradMat
,
*
outMat
,
sizeZ_
,
sizeY_
,
sizeZ_
,
strideD_
,
strideH_
,
strideW_
,
outputD_
,
outputH_
,
outputW_
,
1
,
1
,
paddingD_
,
paddingH_
,
paddingW_
);
}
else
{
LOG
(
FATAL
)
<<
"Unknown pool type: "
<<
poolType_
;
}
}
}
// namespace paddle
paddle/gserver/layers/Pool3DLayer.h
0 → 100644
浏览文件 @
2377d719
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "Layer.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
namespace
paddle
{
/**
* @brief Basic parent layer of pooling
* Pools the input within regions
*/
class
Pool3DLayer
:
public
Layer
{
public:
explicit
Pool3DLayer
(
const
LayerConfig
&
config
)
:
Layer
(
config
)
{}
~
Pool3DLayer
()
{}
bool
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
override
;
void
forward
(
PassType
passType
)
override
;
void
backward
(
const
UpdateCallback
&
callback
)
override
;
size_t
getSize
();
protected:
int
channels_
;
int
sizeX_
,
sizeY_
,
sizeZ_
;
int
strideW_
,
strideH_
,
strideD_
;
int
paddingW_
,
paddingH_
,
paddingD_
;
int
imgSizeW_
,
imgSizeH_
,
imgSizeD_
;
int
outputW_
,
outputH_
,
outputD_
;
std
::
string
poolType_
;
};
}
// namespace paddle
paddle/gserver/tests/test_LayerGrad.cpp
浏览文件 @
2377d719
...
...
@@ -1206,6 +1206,75 @@ TEST(Layer, PoolLayer) {
#endif
}
void
setPool3DConfig
(
TestConfig
*
config
,
PoolConfig
*
pool
,
const
string
&
poolType
)
{
// filter size
const
int
NUM_FILTERS
=
16
;
const
int
FILTER_SIZE
=
3
;
const
int
FILTER_SIZE_Y
=
3
;
const
int
FILTER_SIZE_Z
=
3
;
const
int
CHANNELS
=
16
;
(
*
config
).
biasSize
=
0
;
(
*
config
).
layerConfig
.
set_type
(
"pool3d"
);
(
*
config
).
layerConfig
.
set_num_filters
(
NUM_FILTERS
);
int
kw
=
FILTER_SIZE
,
kh
=
FILTER_SIZE_Y
,
kd
=
FILTER_SIZE_Z
;
int
pw
=
0
,
ph
=
0
,
pd
=
0
;
int
sw
=
2
,
sh
=
2
,
sd
=
2
;
pool
->
set_pool_type
(
poolType
);
pool
->
set_pool_type
(
"avg"
);
pool
->
set_channels
(
CHANNELS
);
pool
->
set_size_x
(
kw
);
pool
->
set_size_y
(
kh
);
pool
->
set_size_z
(
kd
);
pool
->
set_padding
(
0
);
pool
->
set_padding_y
(
0
);
pool
->
set_padding_z
(
0
);
pool
->
set_stride
(
sw
);
pool
->
set_stride_y
(
sh
);
pool
->
set_stride_z
(
sd
);
pool
->
set_start
(
0
);
int
ow
=
outputSize
(
pool
->
img_size
(),
kw
,
pw
,
sw
,
/* caffeMode */
false
);
int
oh
=
outputSize
(
pool
->
img_size_y
(),
kh
,
ph
,
sh
,
/* caffeMode */
false
);
int
od
=
outputSize
(
pool
->
img_size_z
(),
kd
,
pd
,
sd
,
/* caffeMode */
false
);
pool
->
set_output_x
(
ow
);
pool
->
set_output_y
(
oh
);
pool
->
set_output_z
(
od
);
}
void
testPool3DLayer
(
const
string
&
poolType
,
bool
trans
,
bool
useGpu
)
{
TestConfig
config
;
config
.
inputDefs
.
push_back
({
INPUT_DATA
,
"layer_0"
,
11664
,
0
});
LayerInputConfig
*
input
=
config
.
layerConfig
.
add_inputs
();
PoolConfig
*
pool
=
input
->
mutable_pool_conf
();
const
int
IMAGE_SIZE
=
9
;
const
int
IMAGE_SIZE_Y
=
9
;
const
int
IMAGE_SIZE_Z
=
9
;
pool
->
set_img_size
(
IMAGE_SIZE
);
pool
->
set_img_size_y
(
IMAGE_SIZE_Y
);
pool
->
set_img_size_z
(
IMAGE_SIZE_Z
);
setPool3DConfig
(
&
config
,
pool
,
poolType
);
config
.
layerConfig
.
set_size
(
pool
->
output_x
()
*
pool
->
output_y
()
*
pool
->
channels
());
testLayerGrad
(
config
,
"pool3d"
,
100
,
trans
,
useGpu
);
}
TEST
(
Layer
,
Pool3DLayer
)
{
testPool3DLayer
(
"avg"
,
/* trans= */
false
,
/* useGpu= */
false
);
testPool3DLayer
(
"max"
,
/* trans= */
false
,
/* useGpu= */
false
);
#ifndef PADDLE_ONLY_CPU
testPool3DLayer
(
"avg"
,
/* trans= */
false
,
/* useGpu= */
true
);
testPool3DLayer
(
"max"
,
/* trans= */
false
,
/* useGpu= */
true
);
#endif
}
void
testSppLayer
(
const
string
&
poolType
,
const
int
pyramidHeight
,
bool
trans
,
...
...
paddle/math/Matrix.cpp
浏览文件 @
2377d719
...
...
@@ -1190,6 +1190,224 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad,
outGrad
.
getStride
());
}
void
GpuMatrix
::
maxPool3DForward
(
Matrix
&
inputMat
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
)
{
CHECK
(
inputMat
.
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
real
*
inputData
=
inputMat
.
getData
();
size_t
num
=
inputMat
.
getHeight
();
size_t
width
=
imgSizeW
;
size_t
height
=
imgSizeH
;
size_t
depth
=
imgSizeD
;
CHECK
(
depth
*
height
*
width
*
channels
==
inputMat
.
getWidth
());
CHECK
(
height_
==
inputMat
.
getHeight
());
CHECK
(
width_
==
outputD
*
outputH
*
outputW
*
channels
);
hl_maxpool3D_forward
(
num
,
inputData
,
channels
,
depth
,
height
,
width
,
outputD
,
outputH
,
outputW
,
sizeZ
,
sizeY
,
sizeX
,
strideD
,
strideH
,
strideW
,
paddingD
,
paddingH
,
paddingW
,
data_
,
getStride
());
}
void
GpuMatrix
::
maxPool3DBackward
(
Matrix
&
inputMat
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
Matrix
&
outGrad
,
Matrix
&
outV
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
)
{
CHECK
(
inputMat
.
useGpu_
==
true
&&
outGrad
.
useGpu_
==
true
&&
outV
.
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
real
*
inputData
=
inputMat
.
getData
();
real
*
outData
=
outV
.
getData
();
real
*
outDiff
=
outGrad
.
getData
();
size_t
frameNum
=
inputMat
.
getHeight
();
size_t
channels
=
outV
.
getWidth
()
/
outputD
/
outputH
/
outputW
;
size_t
width
=
imgSizeW
;
size_t
height
=
imgSizeH
;
size_t
depth
=
imgSizeD
;
CHECK
(
depth
*
height
*
width
*
channels
==
inputMat
.
getWidth
());
CHECK
(
height_
==
inputMat
.
getHeight
());
CHECK
(
width_
==
depth
*
width
*
height
*
channels
);
CHECK
(
outGrad
.
getHeight
()
==
outV
.
getHeight
()
&&
outGrad
.
getWidth
()
==
outV
.
getWidth
());
hl_maxpool3D_backward
(
frameNum
,
inputData
,
outData
,
outDiff
,
channels
,
depth
,
height
,
width
,
outputD
,
outputH
,
outputW
,
sizeZ
,
sizeY
,
sizeX
,
strideD
,
strideH
,
strideW
,
paddingD
,
paddingH
,
paddingW
,
scaleTargets
,
scaleOutput
,
data_
,
outGrad
.
getStride
());
}
void
GpuMatrix
::
avgPool3DForward
(
Matrix
&
inputMat
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
)
{
CHECK
(
inputMat
.
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
real
*
inputData
=
inputMat
.
getData
();
size_t
frameNum
=
inputMat
.
getHeight
();
size_t
height
=
imgSizeH
;
size_t
width
=
imgSizeW
;
size_t
depth
=
imgSizeD
;
CHECK
(
depth
*
height
*
width
*
channels
==
inputMat
.
getWidth
());
CHECK
(
height_
==
inputMat
.
getHeight
());
CHECK
(
width_
==
outputD
*
outputH
*
outputW
*
channels
);
hl_avgpool3D_forward
(
frameNum
,
inputData
,
channels
,
depth
,
height
,
width
,
outputD
,
outputH
,
outputW
,
sizeZ
,
sizeY
,
sizeX
,
strideD
,
strideH
,
strideW
,
paddingD
,
paddingH
,
paddingW
,
data_
,
getStride
());
}
void
GpuMatrix
::
avgPool3DBackward
(
Matrix
&
outGrad
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
)
{
CHECK
(
outGrad
.
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
real
*
outDiff
=
outGrad
.
getData
();
size_t
frameNum
=
outGrad
.
getHeight
();
size_t
channels
=
outGrad
.
getWidth
()
/
outputD
/
outputH
/
outputW
;
size_t
height
=
imgSizeH
;
size_t
width
=
imgSizeW
;
size_t
depth
=
imgSizeD
;
CHECK
(
depth
*
height
*
width
*
channels
==
width_
);
CHECK
(
height_
==
outGrad
.
getHeight
());
CHECK
(
outGrad
.
getWidth
()
==
outputD
*
outputH
*
outputW
*
channels
);
hl_avgpool3D_backward
(
frameNum
,
outDiff
,
channels
,
depth
,
height
,
width
,
outputD
,
outputH
,
outputW
,
sizeZ
,
sizeY
,
sizeX
,
strideD
,
strideH
,
strideW
,
paddingD
,
paddingH
,
paddingW
,
scaleTargets
,
scaleOutput
,
data_
,
outGrad
.
getStride
());
}
void
GpuMatrix
::
maxSequenceForward
(
Matrix
&
input
,
const
IVector
&
sequence
,
IVector
&
index
)
{
...
...
@@ -1930,6 +2148,290 @@ void CpuMatrix::avgPoolBackward(Matrix& input,
}
}
void
CpuMatrix
::
maxPool3DForward
(
Matrix
&
inputMat
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
)
{
real
*
inputData
=
inputMat
.
getData
();
real
*
outData
=
data_
;
size_t
num
=
inputMat
.
getHeight
();
size_t
inWidth
=
imgSizeW
;
size_t
inHeight
=
imgSizeH
;
size_t
inDepth
=
imgSizeD
;
CHECK
(
inHeight
*
inWidth
*
inDepth
==
inputMat
.
getWidth
()
/
channels
);
CHECK_EQ
(
num
,
this
->
getHeight
());
CHECK_EQ
(
channels
*
outputH
*
outputW
*
outputD
,
this
->
getWidth
());
size_t
outStride
=
getStride
();
/* initialize the data_ */
for
(
size_t
i
=
0
;
i
<
height_
;
i
++
)
{
for
(
size_t
j
=
0
;
j
<
width_
;
j
++
)
{
outData
[(
i
)
*
outStride
+
j
]
=
-
(
real
)
FLT_MAX
;
}
}
/* pool max one by one */
for
(
size_t
n
=
0
;
n
<
num
;
++
n
)
{
// frame by frame
if
(
!
isContiguous
())
{
outData
=
data_
+
n
*
outStride
;
}
for
(
size_t
c
=
0
;
c
<
channels
;
++
c
)
{
// channel by channel
for
(
size_t
pd
=
0
;
pd
<
outputD
;
++
pd
)
{
for
(
size_t
ph
=
0
;
ph
<
outputH
;
++
ph
)
{
for
(
size_t
pw
=
0
;
pw
<
outputW
;
++
pw
)
{
int
dstart
=
pd
*
strideD
-
paddingD
;
int
hstart
=
ph
*
strideH
-
paddingH
;
int
wstart
=
pw
*
strideW
-
paddingW
;
int
dend
=
std
::
min
(
dstart
+
sizeZ
,
inDepth
);
int
hend
=
std
::
min
(
hstart
+
sizeY
,
inHeight
);
int
wend
=
std
::
min
(
wstart
+
sizeX
,
inWidth
);
dstart
=
std
::
max
(
dstart
,
0
);
hstart
=
std
::
max
(
hstart
,
0
);
wstart
=
std
::
max
(
wstart
,
0
);
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
outData
[(
pd
*
outputH
+
ph
)
*
outputW
+
pw
]
=
std
::
max
(
outData
[(
pd
*
outputH
+
ph
)
*
outputW
+
pw
],
inputData
[(
d
*
inHeight
+
h
)
*
inWidth
+
w
]);
}
}
}
}
}
}
// compute offset
inputData
+=
inDepth
*
inHeight
*
inWidth
;
outData
+=
outputD
*
outputH
*
outputW
;
}
}
}
void
CpuMatrix
::
maxPool3DBackward
(
Matrix
&
image
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
Matrix
&
outGrad
,
Matrix
&
outV
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
)
{
size_t
num
=
image
.
getHeight
();
size_t
channels
=
size_t
(
width_
/
imgSizeD
/
imgSizeH
/
imgSizeW
);
CHECK
(
image
.
getWidth
()
==
imgSizeD
*
imgSizeH
*
imgSizeW
*
channels
);
CHECK
(
image
.
getHeight
()
==
height_
&&
image
.
getWidth
()
==
width_
);
CHECK
(
outV
.
getHeight
()
==
outGrad
.
getHeight
()
&&
outV
.
getWidth
()
==
outGrad
.
getWidth
());
real
*
tgtGrad
=
data_
;
real
*
inData
=
image
.
getData
();
real
*
otData
=
outV
.
getData
();
real
*
otGrad
=
outGrad
.
getData
();
size_t
outStride
=
outV
.
getStride
();
real
*
origOutData
=
otData
;
real
*
origOutGrad
=
otGrad
;
for
(
size_t
n
=
0
;
n
<
num
;
++
n
)
{
if
(
!
outV
.
isContiguous
())
{
otData
=
origOutData
+
n
*
outStride
;
otGrad
=
origOutGrad
+
n
*
outStride
;
}
for
(
size_t
c
=
0
;
c
<
channels
;
++
c
)
{
for
(
size_t
pd
=
0
;
pd
<
outputD
;
++
pd
)
{
for
(
size_t
ph
=
0
;
ph
<
outputH
;
++
ph
)
{
for
(
size_t
pw
=
0
;
pw
<
outputW
;
++
pw
)
{
int
dstart
=
pd
*
strideD
-
paddingD
;
int
hstart
=
ph
*
strideH
-
paddingH
;
int
wstart
=
pw
*
strideW
-
paddingW
;
int
dend
=
std
::
min
(
dstart
+
sizeZ
,
imgSizeD
);
int
hend
=
std
::
min
(
hstart
+
sizeY
,
imgSizeH
);
int
wend
=
std
::
min
(
wstart
+
sizeX
,
imgSizeW
);
dstart
=
std
::
max
(
dstart
,
0
);
hstart
=
std
::
max
(
hstart
,
0
);
wstart
=
std
::
max
(
wstart
,
0
);
for
(
int
d
=
0
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
tgtGrad
[(
d
*
imgSizeH
+
h
)
*
imgSizeW
+
w
]
=
scaleTargets
*
tgtGrad
[(
d
*
imgSizeH
+
h
)
*
imgSizeW
+
w
]
+
scaleOutput
*
otGrad
[(
pd
*
outputH
+
ph
)
*
outputW
+
pw
]
*
(
inData
[(
d
*
imgSizeH
+
h
)
*
imgSizeW
+
w
]
==
otData
[(
pd
*
outputH
+
ph
)
*
outputW
+
pw
]);
}
}
}
}
}
}
// offset
inData
+=
imgSizeD
*
imgSizeH
*
imgSizeW
;
tgtGrad
+=
imgSizeD
*
imgSizeH
*
imgSizeW
;
otData
+=
outputD
*
outputH
*
outputW
;
otGrad
+=
outputD
*
outputH
*
outputW
;
}
}
}
void
CpuMatrix
::
avgPool3DForward
(
Matrix
&
input
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
)
{
// The main loop
size_t
num
=
input
.
getHeight
();
size_t
inDepth
=
imgSizeD
;
size_t
inHeight
=
imgSizeH
;
size_t
inWidth
=
imgSizeW
;
CHECK
(
inDepth
*
inHeight
*
inWidth
*
channels
==
input
.
getWidth
());
CHECK
(
outputD
*
outputH
*
outputW
*
channels
*
num
==
height_
*
width_
);
real
*
tgtData
=
data_
;
real
*
inData
=
input
.
getData
();
for
(
size_t
n
=
0
;
n
<
num
;
++
n
)
{
if
(
!
isContiguous
())
{
tgtData
=
data_
+
n
*
getStride
();
}
for
(
size_t
c
=
0
;
c
<
channels
;
++
c
)
{
for
(
size_t
pd
=
0
;
pd
<
outputD
;
++
pd
)
{
for
(
size_t
ph
=
0
;
ph
<
outputH
;
++
ph
)
{
for
(
size_t
pw
=
0
;
pw
<
outputW
;
++
pw
)
{
int
dstart
=
pd
*
strideD
-
paddingD
;
int
hstart
=
ph
*
strideH
-
paddingH
;
int
wstart
=
pw
*
strideW
-
paddingW
;
int
dend
=
std
::
min
(
dstart
+
sizeZ
,
inDepth
+
paddingD
);
int
hend
=
std
::
min
(
hstart
+
sizeY
,
inHeight
+
paddingH
);
int
wend
=
std
::
min
(
wstart
+
sizeX
,
inWidth
+
paddingW
);
int
poolSize
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
dstart
=
std
::
max
(
dstart
,
0
);
hstart
=
std
::
max
(
hstart
,
0
);
wstart
=
std
::
max
(
wstart
,
0
);
dend
=
std
::
min
(
dend
,
static_cast
<
int
>
(
inDepth
));
hend
=
std
::
min
(
hend
,
static_cast
<
int
>
(
inHeight
));
wend
=
std
::
min
(
wend
,
static_cast
<
int
>
(
inWidth
));
CHECK
(
poolSize
);
tgtData
[(
pd
*
outputH
+
ph
)
*
outputW
+
pw
]
=
0
;
// clear
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
tgtData
[(
pd
*
outputH
+
ph
)
*
outputW
+
pw
]
+=
inData
[(
d
*
inHeight
+
h
)
*
inWidth
+
w
];
}
}
}
tgtData
[(
pd
*
outputH
+
ph
)
*
outputW
+
pw
]
/=
poolSize
;
}
}
}
// compute offset
inData
+=
inDepth
*
inHeight
*
inWidth
;
tgtData
+=
outputD
*
outputH
*
outputW
;
}
}
}
void
CpuMatrix
::
avgPool3DBackward
(
Matrix
&
input
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
)
{
size_t
num
=
input
.
getHeight
();
size_t
channels
=
input
.
getWidth
()
/
outputD
/
outputH
/
outputW
;
CHECK
(
imgSizeD
*
imgSizeH
*
imgSizeW
*
channels
==
getWidth
());
real
*
inData
=
input
.
getData
();
real
*
outData
=
getData
();
for
(
size_t
n
=
0
;
n
<
num
;
++
n
)
{
if
(
!
input
.
isContiguous
())
{
inData
=
input
.
getData
()
+
n
*
input
.
getStride
();
}
for
(
size_t
c
=
0
;
c
<
channels
;
++
c
)
{
for
(
size_t
pd
=
0
;
pd
<
outputD
;
++
pd
)
{
for
(
size_t
ph
=
0
;
ph
<
outputH
;
++
ph
)
{
for
(
size_t
pw
=
0
;
pw
<
outputW
;
++
pw
)
{
int
dstart
=
pd
*
strideD
-
paddingD
;
int
hstart
=
ph
*
strideH
-
paddingH
;
int
wstart
=
pw
*
strideW
-
paddingW
;
int
dend
=
std
::
min
(
dstart
+
sizeZ
,
imgSizeD
+
paddingD
);
int
hend
=
std
::
min
(
hstart
+
sizeY
,
imgSizeH
+
paddingH
);
int
wend
=
std
::
min
(
wstart
+
sizeX
,
imgSizeW
+
paddingW
);
int
poolSize
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
dstart
=
std
::
max
(
dstart
,
0
);
hstart
=
std
::
max
(
hstart
,
0
);
wstart
=
std
::
max
(
wstart
,
0
);
dend
=
std
::
min
(
dend
,
static_cast
<
int
>
(
imgSizeD
));
hend
=
std
::
min
(
hend
,
static_cast
<
int
>
(
imgSizeH
));
wend
=
std
::
min
(
wend
,
static_cast
<
int
>
(
imgSizeW
));
CHECK
(
poolSize
);
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
outData
[(
d
*
imgSizeH
+
h
)
*
imgSizeW
+
w
]
+=
inData
[(
pd
*
outputH
+
ph
)
*
outputW
+
pw
]
/
poolSize
;
}
}
}
}
}
}
// offset
outData
+=
imgSizeD
*
imgSizeH
*
imgSizeW
;
inData
+=
outputD
*
outputH
*
outputW
;
}
}
}
/**
* Input: one or more sequences. Each sequence contains some instances.
* Output: output size is the number of input sequences (NOT input instances).
...
...
paddle/math/Matrix.h
浏览文件 @
2377d719
...
...
@@ -928,15 +928,102 @@ public:
size_t
paddingW
)
{
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
/**
* Input: one or more sequences. Each sequence contains some instances.
*
* Output: output size is the number of input sequences (NOT input
* instances).
*
* output[i] is set to max_input[i].
* Pooling 3D forward operation, pick out the largest element
* in the sizeX of value
*/
virtual
void
maxPool3DForward
(
Matrix
&
inputMat
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
)
{
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
virtual
void
maxPool3DBackward
(
Matrix
&
image
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
Matrix
&
outGrad
,
Matrix
&
outV
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
)
{
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
virtual
void
avgPool3DForward
(
Matrix
&
input
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
)
{
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
virtual
void
avgPool3DBackward
(
Matrix
&
input
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
)
{
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
/**
* Input: one or more sequences. Each sequence contains some instances.
*
* Output: output size is the number of input sequences (NOT input
* instances).
*
* output[i] is set to max_input[i].
*/
virtual
void
maxSequenceForward
(
Matrix
&
input
,
const
IVector
&
sequence
,
IVector
&
index
)
{
...
...
@@ -1348,6 +1435,83 @@ public:
size_t
paddingH
,
size_t
paddingW
);
/////////////////////////
void
maxPool3DForward
(
Matrix
&
inputMat
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
);
void
maxPool3DBackward
(
Matrix
&
image
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
Matrix
&
outGrad
,
Matrix
&
outV
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
);
void
avgPool3DForward
(
Matrix
&
input
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
);
void
avgPool3DBackward
(
Matrix
&
input
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
);
void
maxSequenceForward
(
Matrix
&
input
,
const
IVector
&
sequence
,
IVector
&
index
);
...
...
@@ -1506,6 +1670,82 @@ public:
real
scaleOutput
,
size_t
paddingH
,
size_t
paddingW
);
//////////////////////
void
maxPool3DForward
(
Matrix
&
inputMat
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
);
void
maxPool3DBackward
(
Matrix
&
image
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
Matrix
&
outGrad
,
Matrix
&
outV
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
);
void
avgPool3DForward
(
Matrix
&
input
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
);
void
avgPool3DBackward
(
Matrix
&
input
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
sizeZ
,
size_t
sizeY
,
size_t
sizeX
,
size_t
strideD
,
size_t
strideH
,
size_t
strideW
,
size_t
outputD
,
size_t
outputH
,
size_t
outputW
,
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingD
,
size_t
paddingH
,
size_t
paddingW
);
void
maxSequenceForward
(
Matrix
&
input
,
const
IVector
&
sequence
,
...
...
paddle/math/tests/test_matrixCompare.cpp
浏览文件 @
2377d719
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include "TensorCheck.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h"
#include "paddle/testing/TestUtil.h"
...
...
@@ -1203,4 +1204,207 @@ TEST(Matrix, warpCTC) {
}
}
/////
void
testMatrixPool3D
(
int
depth
,
int
height
,
int
width
)
{
int
channel
=
3
;
int
filterX
=
3
,
filterY
=
4
,
filterZ
=
5
;
int
strideX
=
2
,
strideY
=
2
,
strideZ
=
2
;
int
padX
=
1
,
padY
=
1
,
padZ
=
1
;
MatrixPtr
cpuImage
=
std
::
make_shared
<
CpuMatrix
>
(
1
,
channel
*
depth
*
height
*
width
);
MatrixPtr
gpuImage
=
std
::
make_shared
<
GpuMatrix
>
(
1
,
channel
*
depth
*
height
*
width
);
int
outD
=
outputSize
(
depth
,
filterZ
,
padZ
,
strideZ
,
true
);
int
outH
=
outputSize
(
height
,
filterY
,
padZ
,
strideY
,
true
);
int
outW
=
outputSize
(
width
,
filterX
,
padZ
,
strideX
,
true
);
int
colBufWidth
=
outD
*
outH
*
outW
;
MatrixPtr
cpuOutput
=
std
::
make_shared
<
CpuMatrix
>
(
1
,
channel
*
colBufWidth
);
MatrixPtr
gpuOutput
=
std
::
make_shared
<
GpuMatrix
>
(
1
,
channel
*
colBufWidth
);
cpuImage
->
randomizeUniform
();
gpuImage
->
copyFrom
(
*
cpuImage
);
// std::cout << "test maxPool3DForward...\n";
cpuOutput
->
maxPool3DForward
(
*
cpuImage
,
depth
,
height
,
width
,
channel
,
filterZ
,
filterY
,
filterX
,
strideZ
,
strideY
,
strideX
,
outD
,
outH
,
outW
,
padZ
,
padY
,
padX
);
gpuOutput
->
maxPool3DForward
(
*
gpuImage
,
depth
,
height
,
width
,
channel
,
filterZ
,
filterY
,
filterX
,
strideZ
,
strideY
,
strideX
,
outD
,
outH
,
outW
,
padZ
,
padY
,
padX
);
TensorCheckErr
(
*
cpuOutput
,
*
gpuOutput
);
cpuImage
->
randomizeUniform
();
gpuImage
->
copyFrom
(
*
cpuImage
);
// std::cout << "test avgPool3DForward...\n";
cpuOutput
->
avgPool3DForward
(
*
cpuImage
,
depth
,
height
,
width
,
channel
,
filterZ
,
filterY
,
filterX
,
strideZ
,
strideY
,
strideX
,
outD
,
outH
,
outW
,
padZ
,
padY
,
padX
);
gpuOutput
->
avgPool3DForward
(
*
gpuImage
,
depth
,
height
,
width
,
channel
,
filterZ
,
filterY
,
filterX
,
strideZ
,
strideY
,
strideX
,
outD
,
outH
,
outW
,
padZ
,
padY
,
padX
);
TensorCheckErr
(
*
cpuOutput
,
*
gpuOutput
);
cpuImage
->
randomizeUniform
();
gpuImage
->
copyFrom
(
*
cpuImage
);
cpuOutput
->
randomizeUniform
();
gpuOutput
->
copyFrom
(
*
cpuOutput
);
// std::cout << "test avgPool3DBackward...\n";
cpuImage
->
avgPool3DBackward
(
*
cpuOutput
,
depth
,
height
,
width
,
filterZ
,
filterY
,
filterX
,
strideZ
,
strideY
,
strideX
,
outD
,
outH
,
outW
,
1
,
1
,
padZ
,
padY
,
padX
);
gpuImage
->
avgPool3DBackward
(
*
gpuOutput
,
depth
,
height
,
width
,
filterZ
,
filterY
,
filterX
,
strideZ
,
strideY
,
strideX
,
outD
,
outH
,
outW
,
1
,
1
,
padZ
,
padY
,
padX
);
TensorCheckErr
(
*
cpuImage
,
*
gpuImage
);
cpuImage
->
randomizeUniform
();
gpuImage
->
copyFrom
(
*
cpuImage
);
cpuOutput
->
randomizeUniform
();
gpuOutput
->
copyFrom
(
*
cpuOutput
);
// std::cout << "test maxPool3DBackward...\n";
cpuImage
->
maxPool3DBackward
(
*
cpuImage
,
depth
,
height
,
width
,
*
cpuOutput
,
*
cpuOutput
,
filterZ
,
filterY
,
filterX
,
strideZ
,
strideY
,
strideX
,
outD
,
outH
,
outW
,
1
,
1
,
padZ
,
padY
,
padX
);
gpuImage
->
maxPool3DBackward
(
*
gpuImage
,
depth
,
height
,
width
,
*
gpuOutput
,
*
gpuOutput
,
filterZ
,
filterY
,
filterX
,
strideZ
,
strideY
,
strideX
,
outD
,
outH
,
outW
,
1
,
1
,
padZ
,
padY
,
padX
);
TensorCheckErr
(
*
cpuImage
,
*
gpuImage
);
}
TEST
(
Matrix
,
Pool3D
)
{
for
(
auto
depth
:
{
9
,
16
,
64
,
128
})
{
for
(
auto
height
:
{
9
,
11
,
128
,
256
})
{
for
(
auto
width
:
{
9
,
32
,
128
})
{
VLOG
(
3
)
<<
"depth="
<<
depth
<<
" height="
<<
height
<<
" width="
<<
width
;
testMatrixPool3D
(
depth
,
height
,
width
);
}
}
}
}
#endif
paddle/parameter/Argument.cpp
浏览文件 @
2377d719
...
...
@@ -186,6 +186,7 @@ void Argument::resizeAndCopyFrom(const Argument& src,
resizeAndCopy
(
strs
,
src
.
strs
,
useGpu
,
stream
);
frameWidth
=
src
.
frameWidth
;
frameHeight
=
src
.
frameHeight
;
frameDepth
=
src
.
frameDepth
;
}
int32_t
Argument
::
resizeAndCopyFrom
(
const
Argument
&
src
,
...
...
@@ -206,6 +207,7 @@ int32_t Argument::resizeAndCopyFrom(const Argument& src,
dataId
=
src
.
dataId
;
frameWidth
=
src
.
frameWidth
;
frameHeight
=
src
.
frameHeight
;
frameDepth
=
src
.
frameDepth
;
if
(
!
src
.
sequenceStartPositions
)
{
// non-sequence input, copy samples directly
...
...
paddle/parameter/Argument.h
浏览文件 @
2377d719
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -35,6 +32,7 @@ struct Argument {
strs
(
nullptr
),
frameHeight
(
0
),
frameWidth
(
0
),
frameDepth
(
0
),
sequenceStartPositions
(
nullptr
),
subSequenceStartPositions
(
nullptr
),
cpuSequenceDims
(
nullptr
),
...
...
@@ -64,6 +62,7 @@ struct Argument {
allCount
=
argument
.
allCount
;
frameHeight
=
argument
.
frameHeight
;
frameWidth
=
argument
.
frameWidth
;
frameDepth
=
argument
.
frameDepth
;
dataId
=
argument
.
dataId
;
}
...
...
@@ -76,6 +75,7 @@ struct Argument {
// A dataBatch includes batchSize frames, one frame maybe not only vector
size_t
frameHeight
;
size_t
frameWidth
;
size_t
frameDepth
;
// If NULL, each position is treated independently.
// Otherwise, its size should be #NumberOfSequences + 1.
...
...
@@ -136,8 +136,10 @@ struct Argument {
}
size_t
getFrameHeight
()
const
{
return
frameHeight
;
}
size_t
getFrameWidth
()
const
{
return
frameWidth
;
}
size_t
getFrameDepth
()
const
{
return
frameDepth
;
}
void
setFrameHeight
(
size_t
h
)
{
frameHeight
=
h
;
}
void
setFrameWidth
(
size_t
w
)
{
frameWidth
=
w
;
}
void
setFrameDepth
(
size_t
d
)
{
frameDepth
=
d
;
}
int64_t
getNumSequences
()
const
{
return
sequenceStartPositions
?
sequenceStartPositions
->
getSize
()
-
1
...
...
proto/ModelConfig.proto
浏览文件 @
2377d719
...
...
@@ -82,6 +82,12 @@ message ConvConfig {
// if not set, use img_size
optional
uint32
img_size_y
=
14
;
optional
uint32
filter_size_z
=
15
[
default
=
1
];
optional
uint32
padding_z
=
16
[
default
=
1
];
optional
uint32
stride_z
=
17
[
default
=
1
];
optional
uint32
output_z
=
18
[
default
=
1
];
optional
uint32
img_size_z
=
19
[
default
=
1
];
}
message
PoolConfig
{
...
...
@@ -124,6 +130,12 @@ message PoolConfig {
// if not set, use padding
optional
uint32
padding_y
=
13
;
optional
uint32
size_z
=
14
[
default
=
1
];
optional
uint32
stride_z
=
15
[
default
=
1
];
optional
uint32
output_z
=
16
[
default
=
1
];
optional
uint32
img_size_z
=
17
[
default
=
1
];
optional
uint32
padding_z
=
18
[
default
=
1
];
}
message
SppConfig
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录