Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
860bf192
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
860bf192
编写于
8月 24, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add maxPoolIdx
上级
790379f1
变更
7
展开全部
隐藏空白更改
内联
并排
Showing
7 changed file
with
473 addition
and
287 deletion
+473
-287
paddle/cuda/include/hl_cnn.h
paddle/cuda/include/hl_cnn.h
+2
-2
paddle/cuda/include/stub/hl_cnn_stub.h
paddle/cuda/include/stub/hl_cnn_stub.h
+2
-2
paddle/cuda/src/hl_cuda_cnn.cu
paddle/cuda/src/hl_cuda_cnn.cu
+39
-34
paddle/gserver/layers/Pool3DLayer.cpp
paddle/gserver/layers/Pool3DLayer.cpp
+6
-5
paddle/math/Matrix.cpp
paddle/math/Matrix.cpp
+35
-51
paddle/math/Matrix.h
paddle/math/Matrix.h
+9
-9
paddle/math/tests/test_matrixCompare.cpp
paddle/math/tests/test_matrixCompare.cpp
+380
-184
未找到文件。
paddle/cuda/include/hl_cnn.h
浏览文件 @
860bf192
...
...
@@ -192,11 +192,10 @@ extern void hl_maxpool3D_forward(const int frameCnt,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
real
*
maxPoolIdxData
,
const
int
tgtStride
);
extern
void
hl_maxpool3D_backward
(
const
int
frameCnt
,
const
real
*
inputData
,
const
real
*
outData
,
const
real
*
outGrad
,
const
int
channels
,
const
int
depth
,
...
...
@@ -217,6 +216,7 @@ extern void hl_maxpool3D_backward(const int frameCnt,
real
scaleA
,
real
scaleB
,
real
*
targetGrad
,
real
*
maxPoolIdxData
,
const
int
outStride
);
extern
void
hl_avgpool3D_forward
(
const
int
frameCnt
,
...
...
paddle/cuda/include/stub/hl_cnn_stub.h
浏览文件 @
860bf192
...
...
@@ -106,11 +106,10 @@ inline void hl_maxpool3D_forward(const int frameCnt,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
real
*
maxPoolIdxData
,
const
int
tgtStride
)
{}
inline
void
hl_maxpool3D_backward
(
const
int
frameCnt
,
const
real
*
inputData
,
const
real
*
outData
,
const
real
*
outGrad
,
const
int
channels
,
const
int
depth
,
...
...
@@ -131,6 +130,7 @@ inline void hl_maxpool3D_backward(const int frameCnt,
real
scaleA
,
real
scaleB
,
real
*
targetGrad
,
real
*
maxPoolIdxData
,
const
int
outStride
)
{}
inline
void
hl_avgpool3D_forward
(
const
int
frameCnt
,
...
...
paddle/cuda/src/hl_cuda_cnn.cu
浏览文件 @
860bf192
...
...
@@ -366,10 +366,11 @@ __global__ void KeMaxPool3DForward(const int nthreads,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
offset
D
,
const
int
offset
H
,
const
int
offset
W
,
const
int
pad
D
,
const
int
pad
H
,
const
int
pad
W
,
real
*
tgtData
,
real
*
maxPoolIdxData
,
const
int
tgtStride
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
(
nthreads
);
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
...
...
@@ -378,9 +379,9 @@ __global__ void KeMaxPool3DForward(const int nthreads,
int
pd
=
(
index
/
pooledW
/
pooledH
)
%
pooledD
;
int
c
=
(
index
/
pooledW
/
pooledH
/
pooledD
)
%
channels
;
int
frameNum
=
index
/
pooledW
/
pooledH
/
pooledD
/
channels
;
int
dstart
=
pd
*
strideD
-
offset
D
;
int
hstart
=
ph
*
strideH
-
offset
H
;
int
wstart
=
pw
*
strideW
-
offset
W
;
int
dstart
=
pd
*
strideD
-
pad
D
;
int
hstart
=
ph
*
strideH
-
pad
H
;
int
wstart
=
pw
*
strideW
-
pad
W
;
int
dend
=
min
(
dstart
+
ksizeD
,
depth
);
int
hend
=
min
(
hstart
+
ksizeH
,
height
);
int
wend
=
min
(
wstart
+
ksizeW
,
width
);
...
...
@@ -388,18 +389,22 @@ __global__ void KeMaxPool3DForward(const int nthreads,
hstart
=
max
(
hstart
,
0
);
wstart
=
max
(
wstart
,
0
);
real
maxval
=
-
FLT_MAX
;
int
maxIdx
=
-
1
;
inputData
+=
(
frameNum
*
channels
+
c
)
*
depth
*
height
*
width
;
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
if
(
maxval
<
inputData
[(
d
*
height
+
h
)
*
width
+
w
])
if
(
maxval
<
inputData
[(
d
*
height
+
h
)
*
width
+
w
])
{
maxval
=
inputData
[(
d
*
height
+
h
)
*
width
+
w
];
maxIdx
=
(
d
*
height
+
h
)
*
width
+
w
;
}
}
}
}
int
tgtIndex
=
index
%
(
pooledW
*
pooledH
*
pooledD
*
channels
)
+
frameNum
*
tgtStride
;
tgtData
[
tgtIndex
]
=
maxval
;
maxPoolIdxData
[
tgtIndex
]
=
maxIdx
;
}
}
...
...
@@ -418,10 +423,11 @@ void hl_maxpool3D_forward(const int frameCnt,
const
int
strideD
,
const
int
strideH
,
const
int
strideW
,
const
int
pad
ding
D
,
const
int
pad
ding
H
,
const
int
pad
ding
W
,
const
int
padD
,
const
int
padH
,
const
int
padW
,
real
*
tgtData
,
real
*
maxPoolIdxData
,
const
int
tgtStride
)
{
int
num_kernels
=
pooledD
*
pooledH
*
pooledW
*
channels
*
frameCnt
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
...
...
@@ -443,17 +449,16 @@ void hl_maxpool3D_forward(const int frameCnt,
strideD
,
strideH
,
strideW
,
pad
ding
D
,
pad
ding
H
,
pad
ding
W
,
padD
,
padH
,
padW
,
tgtData
,
maxPoolIdxData
,
tgtStride
);
CHECK_SYNC
(
"hl_maxpool3D_forward failed"
);
}
__global__
void
KeMaxPool3DBackward
(
const
int
nthreads
,
const
real
*
inputData
,
const
real
*
outData
,
const
real
*
outGrad
,
const
int
channels
,
const
int
depth
,
...
...
@@ -474,33 +479,35 @@ __global__ void KeMaxPool3DBackward(const int nthreads,
real
scaleA
,
real
scaleB
,
real
*
targetGrad
,
real
*
maxPoolIdxData
,
const
int
outStride
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
(
nthreads
);
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
// find out the local index
// find out the local offset
int
offsetW
=
index
%
width
+
padW
;
int
offsetH
=
(
index
/
width
)
%
height
+
padH
;
int
offsetD
=
(
index
/
width
/
height
)
%
depth
+
padD
;
int
offsetW
=
index
%
width
;
int
offsetH
=
(
index
/
width
)
%
height
;
int
offsetD
=
(
index
/
width
/
height
)
%
depth
;
int
offsetC
=
(
index
/
width
/
height
/
depth
)
%
channels
;
int
frameNum
=
index
/
width
/
height
/
depth
/
channels
;
int
pdstart
=
(
offsetD
<
sizeZ
)
?
0
:
(
offsetD
-
sizeZ
)
/
strideD
+
1
;
int
phstart
=
(
offsetH
<
sizeY
)
?
0
:
(
offsetH
-
sizeY
)
/
strideH
+
1
;
int
pwstart
=
(
offsetW
<
sizeX
)
?
0
:
(
offsetW
-
sizeX
)
/
strideW
+
1
;
int
pdend
=
min
(
offsetD
/
strideD
+
1
,
pooledD
);
int
phend
=
min
(
offsetH
/
strideH
+
1
,
pooledH
);
int
pwend
=
min
(
offsetW
/
strideW
+
1
,
pooledW
);
int
pdstart
=
(
offsetD
+
padD
<
sizeZ
)
?
0
:
(
offsetD
+
padD
-
sizeZ
)
/
strideD
+
1
;
int
phstart
=
(
offsetH
+
padH
<
sizeY
)
?
0
:
(
offsetH
+
padH
-
sizeY
)
/
strideH
+
1
;
int
pwstart
=
(
offsetW
+
padW
<
sizeX
)
?
0
:
(
offsetW
+
padW
-
sizeX
)
/
strideW
+
1
;
int
pdend
=
min
((
offsetD
+
padD
)
/
strideD
+
1
,
pooledD
);
int
phend
=
min
((
offsetH
+
padH
)
/
strideH
+
1
,
pooledH
);
int
pwend
=
min
((
offsetW
+
padW
)
/
strideW
+
1
,
pooledW
);
real
gradient
=
0
;
real
input
=
inputData
[
index
];
outData
+=
((
frameNum
*
channels
+
offsetC
)
*
pooledD
*
pooledH
*
pooledW
);
outGrad
+=
((
frameNum
*
channels
+
offsetC
)
*
pooledD
*
pooledH
*
pooledW
);
maxPoolIdxData
+=
((
frameNum
*
channels
+
offsetC
)
*
pooledD
*
pooledH
*
pooledW
);
for
(
int
pd
=
pdstart
;
pd
<
pdend
;
++
pd
)
{
for
(
int
ph
=
phstart
;
ph
<
phend
;
++
ph
)
{
for
(
int
pw
=
pwstart
;
pw
<
pwend
;
++
pw
)
{
if
(
input
==
outData
[(
pd
*
pooledH
+
ph
)
*
pooledW
+
pw
])
if
(((
offsetD
*
height
+
offsetH
)
*
width
+
offsetW
)
==
maxPoolIdxData
[(
pd
*
pooledH
+
ph
)
*
pooledW
+
pw
])
gradient
+=
outGrad
[(
pd
*
pooledH
+
ph
)
*
pooledW
+
pw
];
}
}
...
...
@@ -510,8 +517,6 @@ __global__ void KeMaxPool3DBackward(const int nthreads,
}
void
hl_maxpool3D_backward
(
const
int
frameCnt
,
const
real
*
inputData
,
const
real
*
outData
,
const
real
*
outGrad
,
const
int
channels
,
const
int
depth
,
...
...
@@ -532,13 +537,12 @@ void hl_maxpool3D_backward(const int frameCnt,
real
scaleA
,
real
scaleB
,
real
*
targetGrad
,
real
*
maxPoolIdxData
,
const
int
outStride
)
{
int
num_kernels
=
depth
*
height
*
width
*
channels
*
frameCnt
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
KeMaxPool3DBackward
<<<
blocks
,
1024
,
0
,
STREAM_DEFAULT
>>>
(
num_kernels
,
inputData
,
outData
,
outGrad
,
channels
,
depth
,
...
...
@@ -559,6 +563,7 @@ void hl_maxpool3D_backward(const int frameCnt,
scaleA
,
scaleB
,
targetGrad
,
maxPoolIdxData
,
outStride
);
CHECK_SYNC
(
"hl_maxpool3D_backward"
);
}
...
...
paddle/gserver/layers/Pool3DLayer.cpp
浏览文件 @
860bf192
...
...
@@ -72,9 +72,10 @@ size_t Pool3DLayer::getSize() {
void
Pool3DLayer
::
forward
(
PassType
passType
)
{
Layer
::
forward
(
passType
);
const
MatrixPtr
&
inMat
=
inputLayers_
[
0
]
->
getOutputValue
();
in
t
batchSize
=
inMat
->
getHeight
();
in
t
outWidth
=
getSize
();
size_
t
batchSize
=
inMat
->
getHeight
();
size_
t
outWidth
=
getSize
();
resetOutput
(
batchSize
,
outWidth
);
Matrix
::
resizeOrCreate
(
maxPoolIdx_
,
batchSize
,
outWidth
,
false
,
useGpu_
);
const
MatrixPtr
outMat
=
getOutputValue
();
if
(
poolType_
==
"avg"
)
{
...
...
@@ -97,6 +98,7 @@ void Pool3DLayer::forward(PassType passType) {
paddingW_
);
}
else
if
(
poolType_
==
"max"
)
{
outMat
->
maxPool3DForward
(
*
inMat
,
*
maxPoolIdx_
,
channels_
,
imgSizeD_
,
imgSizeH_
,
...
...
@@ -149,9 +151,8 @@ void Pool3DLayer::backward(const UpdateCallback& callback) {
1.0
,
1.0
);
}
else
if
(
poolType_
==
"max"
)
{
inGradMat
->
maxPool3DBackward
(
*
inMat
,
*
outGradMat
,
*
outMat
,
inGradMat
->
maxPool3DBackward
(
*
outGradMat
,
*
maxPoolIdx_
,
imgSizeD_
,
imgSizeH_
,
imgSizeW_
,
...
...
paddle/math/Matrix.cpp
浏览文件 @
860bf192
...
...
@@ -1191,6 +1191,7 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad,
}
void
GpuMatrix
::
maxPool3DForward
(
Matrix
&
inputMat
,
Matrix
&
maxPoolIdx
,
size_t
channels
,
size_t
imgSizeD
,
size_t
imgSizeH
,
...
...
@@ -1210,6 +1211,7 @@ void GpuMatrix::maxPool3DForward(Matrix& inputMat,
CHECK
(
inputMat
.
useGpu_
)
<<
"Matrix type are not correct"
;
real
*
inputData
=
inputMat
.
getData
();
real
*
maxPoolIdxData
=
maxPoolIdx
.
getData
();
size_t
num
=
inputMat
.
getHeight
();
size_t
width
=
imgSizeW
;
size_t
height
=
imgSizeH
;
...
...
@@ -1237,12 +1239,12 @@ void GpuMatrix::maxPool3DForward(Matrix& inputMat,
paddingH
,
paddingW
,
getData
(),
maxPoolIdxData
,
getStride
());
}
void
GpuMatrix
::
maxPool3DBackward
(
Matrix
&
inputMat
,
Matrix
&
outGrad
,
Matrix
&
outV
,
void
GpuMatrix
::
maxPool3DBackward
(
Matrix
&
outGrad
,
Matrix
&
maxPoolIdx
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
...
...
@@ -1260,26 +1262,21 @@ void GpuMatrix::maxPool3DBackward(Matrix& inputMat,
size_t
paddingW
,
real
scaleTargets
,
real
scaleOutput
)
{
CHECK
(
inputMat
.
useGpu_
&&
outGrad
.
useGpu_
&&
outV
.
useGpu_
)
<<
"Matrix type are not equal"
;
CHECK
(
outGrad
.
useGpu_
&&
maxPoolIdx
.
useGpu_
)
<<
"Matrix type are not equal"
;
real
*
inputData
=
inputMat
.
getData
();
real
*
outData
=
outV
.
getData
();
real
*
outDiff
=
outGrad
.
getData
();
size_t
frameNum
=
inputMat
.
getHeight
();
size_t
channels
=
outV
.
getWidth
()
/
outputD
/
outputH
/
outputW
;
real
*
maxPoolIdxData
=
maxPoolIdx
.
getData
();
size_t
frameNum
=
getHeight
();
size_t
channels
=
outGrad
.
getWidth
()
/
outputD
/
outputH
/
outputW
;
size_t
width
=
imgSizeW
;
size_t
height
=
imgSizeH
;
size_t
depth
=
imgSizeD
;
CHECK
(
depth
*
height
*
width
*
channels
==
inputMat
.
getWidth
());
CHECK
(
height_
==
inputMat
.
getHeight
());
CHECK
(
depth
*
height
*
width
*
channels
==
getWidth
());
CHECK
(
width_
==
depth
*
width
*
height
*
channels
);
CHECK
(
outGrad
.
getHeight
()
==
outV
.
getHeight
()
&&
outGrad
.
getWidth
()
==
outV
.
getWidth
());
CHECK
(
outGrad
.
getHeight
()
==
maxPoolIdx
.
getHeight
()
&&
outGrad
.
getWidth
()
==
maxPoolIdx
.
getWidth
());
hl_maxpool3D_backward
(
frameNum
,
inputData
,
outData
,
outDiff
,
channels
,
depth
,
...
...
@@ -1300,6 +1297,7 @@ void GpuMatrix::maxPool3DBackward(Matrix& inputMat,
scaleTargets
,
scaleOutput
,
getData
(),
maxPoolIdxData
,
outGrad
.
getStride
());
}
...
...
@@ -2148,6 +2146,7 @@ void CpuMatrix::avgPoolBackward(Matrix& input,
}
void
CpuMatrix
::
maxPool3DForward
(
Matrix
&
inputMat
,
Matrix
&
maxPoolIdx
,
size_t
channels
,
size_t
imgSizeD
,
size_t
imgSizeH
,
...
...
@@ -2166,6 +2165,7 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat,
size_t
paddingW
)
{
real
*
inputData
=
inputMat
.
getData
();
real
*
outData
=
getData
();
real
*
maxPoolIdxData
=
maxPoolIdx
.
getData
();
size_t
num
=
inputMat
.
getHeight
();
size_t
inWidth
=
imgSizeW
;
size_t
inHeight
=
imgSizeH
;
...
...
@@ -2179,6 +2179,7 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat,
for
(
size_t
i
=
0
;
i
<
height_
;
i
++
)
{
for
(
size_t
j
=
0
;
j
<
width_
;
j
++
)
{
outData
[(
i
)
*
outStride
+
j
]
=
-
(
real
)
FLT_MAX
;
maxPoolIdxData
[(
i
)
*
outStride
+
j
]
=
-
1
;
}
}
...
...
@@ -2186,6 +2187,7 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat,
for
(
size_t
n
=
0
;
n
<
num
;
++
n
)
{
// frame by frame
if
(
!
isContiguous
())
{
outData
=
getData
()
+
n
*
outStride
;
maxPoolIdxData
=
maxPoolIdx
.
getData
()
+
n
*
outStride
;
}
for
(
size_t
c
=
0
;
c
<
channels
;
++
c
)
{
// channel by channel
for
(
size_t
pd
=
0
;
pd
<
outputD
;
++
pd
)
{
...
...
@@ -2200,6 +2202,7 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat,
dstart
=
std
::
max
(
dstart
,
0
);
hstart
=
std
::
max
(
hstart
,
0
);
wstart
=
std
::
max
(
wstart
,
0
);
int
maxIdx
=
-
1
;
real
maxOutData
=
outData
[(
pd
*
outputH
+
ph
)
*
outputW
+
pw
];
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
...
...
@@ -2207,24 +2210,26 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat,
if
(
maxOutData
<
inputData
[(
d
*
inHeight
+
h
)
*
inWidth
+
w
])
{
maxOutData
=
inputData
[(
d
*
inHeight
+
h
)
*
inWidth
+
w
];
maxIdx
=
(
d
*
inHeight
+
h
)
*
inWidth
+
w
;
}
}
}
}
outData
[(
pd
*
outputH
+
ph
)
*
outputW
+
pw
]
=
maxOutData
;
maxPoolIdxData
[(
pd
*
outputH
+
ph
)
*
outputW
+
pw
]
=
maxIdx
;
}
}
}
// compute offset
inputData
+=
inDepth
*
inHeight
*
inWidth
;
outData
+=
outputD
*
outputH
*
outputW
;
maxPoolIdxData
+=
outputD
*
outputH
*
outputW
;
}
}
}
void
CpuMatrix
::
maxPool3DBackward
(
Matrix
&
image
,
Matrix
&
outGrad
,
Matrix
&
outV
,
void
CpuMatrix
::
maxPool3DBackward
(
Matrix
&
outGrad
,
Matrix
&
maxPoolIdx
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
...
...
@@ -2242,59 +2247,38 @@ void CpuMatrix::maxPool3DBackward(Matrix& image,
size_t
paddingW
,
real
scaleTargets
,
real
scaleOutput
)
{
size_t
num
=
image
.
getHeight
();
size_t
num
=
getHeight
();
size_t
channels
=
size_t
(
width_
/
imgSizeD
/
imgSizeH
/
imgSizeW
);
CHECK
(
image
.
getWidth
()
==
imgSizeD
*
imgSizeH
*
imgSizeW
*
channels
);
CHECK
(
image
.
getHeight
()
==
height_
&&
image
.
getWidth
()
==
width_
);
CHECK
(
outV
.
getHeight
()
==
outGrad
.
getHeight
()
&&
outV
.
getWidth
()
==
outGrad
.
getWidth
());
CHECK
(
maxPoolIdx
.
getHeight
()
==
outGrad
.
getHeight
()
&&
maxPoolIdx
.
getWidth
()
==
outGrad
.
getWidth
());
real
*
tgtGrad
=
getData
();
real
*
inData
=
image
.
getData
();
real
*
otData
=
outV
.
getData
();
real
*
otGrad
=
outGrad
.
getData
();
real
*
maxPoolIdxData
=
maxPoolIdx
.
getData
();
size_t
outStride
=
out
V
.
getStride
();
size_t
outStride
=
out
Grad
.
getStride
();
;
for
(
size_t
n
=
0
;
n
<
num
;
++
n
)
{
if
(
!
outV
.
isContiguous
())
{
otData
=
outV
.
getData
()
+
n
*
outStride
;
if
(
!
outGrad
.
isContiguous
())
{
otGrad
=
outGrad
.
getData
()
+
n
*
outStride
;
maxPoolIdxData
=
maxPoolIdx
.
getData
()
+
n
*
outStride
;
}
for
(
size_t
c
=
0
;
c
<
channels
;
++
c
)
{
for
(
size_t
pd
=
0
;
pd
<
outputD
;
++
pd
)
{
for
(
size_t
ph
=
0
;
ph
<
outputH
;
++
ph
)
{
for
(
size_t
pw
=
0
;
pw
<
outputW
;
++
pw
)
{
int
dstart
=
pd
*
strideD
-
paddingD
;
int
hstart
=
ph
*
strideH
-
paddingH
;
int
wstart
=
pw
*
strideW
-
paddingW
;
int
dend
=
std
::
min
(
dstart
+
sizeZ
,
imgSizeD
);
int
hend
=
std
::
min
(
hstart
+
sizeY
,
imgSizeH
);
int
wend
=
std
::
min
(
wstart
+
sizeX
,
imgSizeW
);
dstart
=
std
::
max
(
dstart
,
0
);
hstart
=
std
::
max
(
hstart
,
0
);
wstart
=
std
::
max
(
wstart
,
0
);
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
tgtGrad
[(
d
*
imgSizeH
+
h
)
*
imgSizeW
+
w
]
=
scaleTargets
*
tgtGrad
[(
d
*
imgSizeH
+
h
)
*
imgSizeW
+
w
]
+
scaleOutput
*
otGrad
[(
pd
*
outputH
+
ph
)
*
outputW
+
pw
]
*
(
inData
[(
d
*
imgSizeH
+
h
)
*
imgSizeW
+
w
]
==
otData
[(
pd
*
outputH
+
ph
)
*
outputW
+
pw
]);
}
}
}
const
size_t
index
=
(
pd
*
outputH
+
ph
)
*
outputW
+
pw
;
const
size_t
tgtIdx
=
static_cast
<
size_t
>
(
maxPoolIdxData
[
index
]);
tgtGrad
[
tgtIdx
]
=
scaleTargets
*
tgtGrad
[
tgtIdx
]
+
scaleOutput
*
otGrad
[
index
];
}
}
}
// offset
inData
+=
imgSizeD
*
imgSizeH
*
imgSizeW
;
tgtGrad
+=
imgSizeD
*
imgSizeH
*
imgSizeW
;
otData
+=
outputD
*
outputH
*
outputW
;
otGrad
+=
outputD
*
outputH
*
outputW
;
maxPoolIdxData
+=
outputD
*
outputH
*
outputW
;
}
}
}
...
...
paddle/math/Matrix.h
浏览文件 @
860bf192
...
...
@@ -933,6 +933,7 @@ public:
* in the sizeX of value
*/
virtual
void
maxPool3DForward
(
Matrix
&
inputMat
,
Matrix
&
maxPoolIdx
,
size_t
channels
,
size_t
imgSizeD
,
size_t
imgSizeH
,
...
...
@@ -952,9 +953,8 @@ public:
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
virtual
void
maxPool3DBackward
(
Matrix
&
image
,
Matrix
&
outGrad
,
Matrix
&
outV
,
virtual
void
maxPool3DBackward
(
Matrix
&
outGrad
,
Matrix
&
maxPoolIdx
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
...
...
@@ -1436,6 +1436,7 @@ public:
size_t
paddingW
);
void
maxPool3DForward
(
Matrix
&
inputMat
,
Matrix
&
maxPoolIdx
,
size_t
channels
,
size_t
imgSizeD
,
size_t
imgSizeH
,
...
...
@@ -1453,9 +1454,8 @@ public:
size_t
paddingH
,
size_t
paddingW
);
void
maxPool3DBackward
(
Matrix
&
image
,
Matrix
&
outGrad
,
Matrix
&
outV
,
void
maxPool3DBackward
(
Matrix
&
outGrad
,
Matrix
&
maxPoolIdx
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
...
...
@@ -1671,6 +1671,7 @@ public:
size_t
paddingW
);
void
maxPool3DForward
(
Matrix
&
inputMat
,
Matrix
&
maxPoolIdx
,
size_t
channels
,
size_t
imgSizeD
,
size_t
imgSizeH
,
...
...
@@ -1688,9 +1689,8 @@ public:
size_t
paddingH
,
size_t
paddingW
);
void
maxPool3DBackward
(
Matrix
&
image
,
Matrix
&
outGrad
,
Matrix
&
outV
,
void
maxPool3DBackward
(
Matrix
&
outGrad
,
Matrix
&
maxPoolIdx
,
size_t
imgSizeD
,
size_t
imgSizeH
,
size_t
imgSizeW
,
...
...
paddle/math/tests/test_matrixCompare.cpp
浏览文件 @
860bf192
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录