Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
529f24c2
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
529f24c2
编写于
12月 12, 2016
作者:
H
hedaoyuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
cpu cmrnorm
上级
b3f0f3d2
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
279 addition
and
168 deletion
+279
-168
paddle/cuda/src/hl_cuda_cnn.cu
paddle/cuda/src/hl_cuda_cnn.cu
+77
-115
paddle/gserver/tests/test_LayerGrad.cpp
paddle/gserver/tests/test_LayerGrad.cpp
+1
-2
paddle/math/Matrix.cpp
paddle/math/Matrix.cpp
+86
-51
paddle/math/tests/test_matrixCompare.cpp
paddle/math/tests/test_matrixCompare.cpp
+115
-0
未找到文件。
paddle/cuda/src/hl_cuda_cnn.cu
浏览文件 @
529f24c2
...
...
@@ -381,57 +381,45 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad,
CHECK_SYNC
(
"hl_avgpool_backward failed"
);
}
__global__
void
KeCMRNormFillScale
(
size_t
nthreads
,
const
real
*
in
,
__global__
void
KeCMRNormFillScale
(
size_t
imageSize
,
const
real
*
in
,
real
*
scale
,
size_t
channels
,
size_t
height
,
size_t
width
,
size_t
size
,
real
alpha
)
{
size_t
index
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
index
<
nthreads
)
{
// find out the local offset
size_t
w
=
index
%
width
;
size_t
h
=
(
index
/
width
)
%
height
;
size_t
n
=
index
/
width
/
height
;
size_t
offset
=
(
n
*
channels
*
height
+
h
)
*
width
+
w
;
size_t
step
=
height
*
width
;
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
imageSize
)
{
const
int
w
=
idx
%
width
;
const
int
h
=
(
idx
/
width
)
%
height
;
const
int
n
=
idx
/
width
/
height
;
const
int
offset
=
(
n
*
channels
*
height
+
h
)
*
width
+
w
;
in
+=
offset
;
scale
+=
offset
;
size_t
head
=
0
;
size_t
pre_pad
=
(
size
-
1
)
/
2
;
size_t
post_pad
=
size
-
pre_pad
-
1
;
real
accum_scale
=
0
;
// fill the scale at [n, :, h, w]
// accumulate values
while
(
head
<
post_pad
)
{
accum_scale
+=
in
[
head
*
step
]
*
in
[
head
*
step
];
++
head
;
}
// until we reach size, nothing needs to be subtracted
while
(
head
<
size
)
{
accum_scale
+=
in
[
head
*
step
]
*
in
[
head
*
step
];
scale
[(
head
-
post_pad
)
*
step
]
=
1.
+
accum_scale
*
alpha
;
++
head
;
}
// both add and subtract
while
(
head
<
channels
)
{
accum_scale
+=
in
[
head
*
step
]
*
in
[
head
*
step
];
accum_scale
-=
in
[(
head
-
size
)
*
step
]
*
in
[(
head
-
size
)
*
step
];
scale
[(
head
-
post_pad
)
*
step
]
=
1.
+
accum_scale
*
alpha
;
++
head
;
}
// subtract only
while
(
head
<
channels
+
post_pad
)
{
accum_scale
-=
in
[(
head
-
size
)
*
step
]
*
in
[(
head
-
size
)
*
step
];
scale
[(
head
-
post_pad
)
*
step
]
=
1.
+
accum_scale
*
alpha
;
++
head
;
const
int
step
=
height
*
width
;
const
int
pre_pad
=
(
size
-
1
)
/
2
;
const
int
post_pad
=
size
-
pre_pad
-
1
;
real
accum
=
0
;
int
index
=
0
;
while
(
index
<
channels
+
post_pad
)
{
if
(
index
<
channels
)
{
accum
+=
in
[
index
*
step
]
*
in
[
index
*
step
];
}
if
(
index
>=
size
)
{
accum
-=
in
[(
index
-
size
)
*
step
]
*
in
[(
index
-
size
)
*
step
];
}
if
(
index
>=
post_pad
)
{
scale
[(
index
-
post_pad
)
*
step
]
=
1.
+
accum
*
alpha
;
}
++
index
;
}
}
}
__global__
void
KeCMRNormOutput
(
size_t
nthreads
,
const
real
*
in
,
const
real
*
scale
,
real
negative_beta
,
real
*
out
)
{
size_
t
index
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
index
<
nthreads
)
{
__global__
void
KeCMRNormOutput
(
size_t
inputSize
,
const
real
*
in
,
const
real
*
scale
,
real
negative_beta
,
real
*
out
)
{
const
in
t
index
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
index
<
inputSize
)
{
out
[
index
]
=
in
[
index
]
*
pow
(
scale
[
index
],
negative_beta
);
}
}
...
...
@@ -440,84 +428,60 @@ void hl_CMRNorm_forward(size_t frameCnt, const real* in, real* scale,
real
*
out
,
size_t
channels
,
size_t
height
,
size_t
width
,
size_t
sizeX
,
real
alpha
,
real
beta
)
{
size_t
threadsNum
=
frameCnt
*
height
*
width
;
size_t
blocksX
=
(
threadsNum
+
1024
-
1
)
/
1024
;
size_t
blocksY
=
1
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blocksX
,
blocksY
);
KeCMRNormFillScale
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
threadsNum
,
in
,
scale
,
channels
,
height
,
width
,
sizeX
,
alpha
);
threadsNum
=
frameCnt
*
height
*
width
*
channels
;
blocksX
=
(
threadsNum
+
1024
-
1
)
/
1024
;
dim3
threads2
(
1024
,
1
);
dim3
grid2
(
blocksX
,
blocksY
);
KeCMRNormOutput
<<<
grid2
,
threads2
,
0
,
STREAM_DEFAULT
>>>
(
threadsNum
,
in
,
scale
,
beta
,
out
);
size_t
imageSize
=
frameCnt
*
height
*
width
;
int
blockSize
=
1024
;
int
gridSize
=
(
imageSize
+
1024
-
1
)
/
1024
;
KeCMRNormFillScale
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
imageSize
,
in
,
scale
,
channels
,
height
,
width
,
sizeX
,
alpha
);
size_t
inputSize
=
frameCnt
*
height
*
width
*
channels
;
blockSize
=
1024
;
gridSize
=
(
inputSize
+
1024
-
1
)
/
1024
;
KeCMRNormOutput
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
inputSize
,
in
,
scale
,
beta
,
out
);
CHECK_SYNC
(
"hl_CMRNorm_forward"
);
}
__global__
void
KeCMRNormDiff
(
size_t
nthreads
,
const
real
*
bottom_data
,
__global__
void
KeCMRNormDiff
(
size_t
imageSize
,
const
real
*
bottom_data
,
const
real
*
top_data
,
const
real
*
scale
,
const
real
*
top_diff
,
size_t
channels
,
size_t
height
,
size_t
width
,
size_t
size
,
real
negative_beta
,
real
cache_ratio
,
real
*
bottom_diff
)
{
int
index
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
index
<
nthreads
)
{
// find out the local offset
size_t
w
=
index
%
width
;
size_t
h
=
(
index
/
width
)
%
height
;
size_t
n
=
index
/
width
/
height
;
size_t
offset
=
(
n
*
channels
*
height
+
h
)
*
width
+
w
;
size_t
step
=
height
*
width
;
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
imageSize
)
{
const
int
w
=
idx
%
width
;
const
int
h
=
(
idx
/
width
)
%
height
;
const
int
n
=
idx
/
width
/
height
;
const
int
offset
=
(
n
*
channels
*
height
+
h
)
*
width
+
w
;
bottom_data
+=
offset
;
top_data
+=
offset
;
scale
+=
offset
;
top_diff
+=
offset
;
bottom_diff
+=
offset
;
int
head
=
0
;
int
pre_pad
=
size
-
(
size
+
1
)
/
2
;
int
post_pad
=
size
-
pre_pad
-
1
;
real
accum_ratio
=
0
;
// accumulate values
while
(
head
<
post_pad
)
{
accum_ratio
+=
top_diff
[
head
*
step
]
*
top_data
[
head
*
step
]
/
scale
[
head
*
step
];
++
head
;
}
// until we reach size, nothing needs to be subtracted
while
(
head
<
size
)
{
accum_ratio
+=
top_diff
[
head
*
step
]
*
top_data
[
head
*
step
]
/
scale
[
head
*
step
];
bottom_diff
[(
head
-
post_pad
)
*
step
]
+=
top_diff
[(
head
-
post_pad
)
*
step
]
*
pow
(
scale
[(
head
-
post_pad
)
*
step
],
negative_beta
)
-
cache_ratio
*
bottom_data
[(
head
-
post_pad
)
*
step
]
*
accum_ratio
;
++
head
;
}
// both add and subtract
while
(
head
<
channels
)
{
accum_ratio
+=
top_diff
[
head
*
step
]
*
top_data
[
head
*
step
]
/
scale
[
head
*
step
];
accum_ratio
-=
top_diff
[(
head
-
size
)
*
step
]
*
top_data
[(
head
-
size
)
*
step
]
/
scale
[(
head
-
size
)
*
step
];
bottom_diff
[(
head
-
post_pad
)
*
step
]
+=
top_diff
[(
head
-
post_pad
)
*
step
]
*
pow
(
scale
[(
head
-
post_pad
)
*
step
],
negative_beta
)
-
cache_ratio
*
bottom_data
[(
head
-
post_pad
)
*
step
]
*
accum_ratio
;
++
head
;
}
// subtract only
while
(
head
<
channels
+
post_pad
)
{
accum_ratio
-=
top_diff
[(
head
-
size
)
*
step
]
*
top_data
[(
head
-
size
)
*
step
]
/
scale
[(
head
-
size
)
*
step
];
bottom_diff
[(
head
-
post_pad
)
*
step
]
+=
top_diff
[(
head
-
post_pad
)
*
step
]
*
pow
(
scale
[(
head
-
post_pad
)
*
step
],
negative_beta
)
-
cache_ratio
*
bottom_data
[(
head
-
post_pad
)
*
step
]
*
accum_ratio
;
++
head
;
const
int
step
=
height
*
width
;
const
int
pre_pad
=
size
-
(
size
+
1
)
/
2
;
const
int
post_pad
=
size
-
pre_pad
-
1
;
int
index
=
0
;
real
accum
=
0
;
while
(
index
<
channels
+
post_pad
)
{
if
(
index
<
channels
)
{
accum
+=
top_diff
[
index
*
step
]
*
top_data
[
index
*
step
]
/
scale
[
index
*
step
];
}
if
(
index
>=
size
)
{
accum
-=
top_diff
[(
index
-
size
)
*
step
]
*
top_data
[(
index
-
size
)
*
step
]
/
scale
[(
index
-
size
)
*
step
];
}
if
(
index
>=
post_pad
)
{
bottom_diff
[(
index
-
post_pad
)
*
step
]
+=
top_diff
[(
index
-
post_pad
)
*
step
]
*
pow
(
scale
[(
index
-
post_pad
)
*
step
],
negative_beta
)
-
cache_ratio
*
bottom_data
[(
index
-
post_pad
)
*
step
]
*
accum
;
}
++
index
;
}
}
}
...
...
@@ -528,14 +492,12 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
real
*
inDiff
,
size_t
channels
,
size_t
height
,
size_t
width
,
size_t
sizeX
,
real
alpha
,
real
beta
)
{
size_t
threadsNum
=
frameCnt
*
height
*
width
;
size_t
blocksX
=
(
threadsNum
+
1024
-
1
)
/
1024
;
size_t
blocksY
=
1
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blocksX
,
blocksY
);
KeCMRNormDiff
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
threadsNum
,
inV
,
outV
,
scale
,
outDiff
,
channels
,
height
,
width
,
sizeX
,
alpha
,
beta
,
inDiff
);
size_t
imageSize
=
frameCnt
*
height
*
width
;
int
blockSize
=
1024
;
int
gridSize
=
(
imageSize
+
1024
-
1
)
/
1024
;
KeCMRNormDiff
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
imageSize
,
inV
,
outV
,
scale
,
outDiff
,
channels
,
height
,
width
,
sizeX
,
alpha
,
beta
,
inDiff
);
CHECK_SYNC
(
"hl_CMRNorm_backward"
);
}
...
...
paddle/gserver/tests/test_LayerGrad.cpp
浏览文件 @
529f24c2
...
...
@@ -1021,11 +1021,10 @@ void testNormLayer(const string& normType, bool trans, bool useGpu) {
testLayerGrad
(
config
,
"norm"
,
100
,
trans
,
useGpu
);
}
#ifndef PADDLE_ONLY_CPU
TEST
(
Layer
,
NormLayer
)
{
testNormLayer
(
"cmrnorm-projection"
,
/* trans= */
false
,
/* useGpu= */
true
);
testNormLayer
(
"cmrnorm-projection"
,
/* trans= */
false
,
/* useGpu= */
false
);
}
#endif
void
setPoolConfig
(
TestConfig
*
config
,
PoolConfig
*
pool
,
...
...
paddle/math/Matrix.cpp
浏览文件 @
529f24c2
...
...
@@ -2227,52 +2227,43 @@ void CpuMatrix::crossMapNormalFwd(Matrix& input,
size_t
sizeX
,
float
scale
,
float
pow
)
{
size_t
num
=
input
.
getHeight
();
CHECK
(
isContiguous
());
CHECK
(
input
.
isContiguous
());
CHECK
(
denoms
.
isContiguous
());
CHECK_EQ
(
getHeight
(),
input
.
getHeight
());
CHECK_EQ
(
getWidth
(),
input
.
getWidth
());
CHECK_EQ
(
getHeight
(),
denoms
.
getHeight
());
CHECK_EQ
(
getWidth
(),
denoms
.
getWidth
());
size_t
numSample
=
input
.
getHeight
();
size_t
numCols
=
input
.
getWidth
();
size_t
height
=
imgSizeH
;
size_t
width
=
imgSizeW
;
size_t
numCols
=
input
.
getWidth
();
CHECK
(
height
*
width
*
channels
==
input
.
getWidth
());
CHECK
(
denoms
.
getHeight
()
==
input
.
getHeight
()
&&
denoms
.
getWidth
()
==
input
.
getWidth
()
&&
input
.
getHeight
()
==
height_
&&
input
.
getWidth
()
==
width_
);
real
*
imgData
=
input
.
getData
();
real
*
diffData
=
input
.
getData
();
real
*
targetData
=
getData
();
size_t
halfSize
=
sizeX
/
2
;
size_t
imgPixels
=
height
*
width
;
// use integral vector to implement the sum in local window
real
*
integralData
=
(
real
*
)
malloc
((
channels
+
sizeX
+
1
)
*
sizeof
(
real
));
// NOLINT // TODO:
for
(
size_t
i
=
0
;
i
<=
halfSize
;
i
++
)
{
integralData
[
i
]
=
0
;
}
for
(
size_t
i
=
0
;
i
<
num
;
i
++
)
{
real
*
targetPtr
=
targetData
+
i
*
numCols
;
real
*
imgPtr
=
imgData
+
i
*
numCols
;
real
*
diffPtr
=
diffData
+
i
*
numCols
;
for
(
size_t
m
=
0
;
m
<
height
;
m
++
)
{
for
(
size_t
n
=
0
;
n
<
width
;
n
++
)
{
for
(
size_t
c
=
0
;
c
<
channels
;
c
++
)
{
integralData
[
c
+
halfSize
+
1
]
=
integralData
[
c
+
halfSize
]
+
_square
(
*
(
diffPtr
+
c
*
imgPixels
));
}
for
(
size_t
k
=
channels
+
halfSize
+
1
;
k
<=
channels
+
sizeX
;
k
++
)
{
integralData
[
k
]
=
integralData
[
channels
+
halfSize
];
CHECK
(
height
*
width
*
channels
==
numCols
);
// TODO(hedaoyuan) After commit TensorExpress code,
// Reconstruction this code to remove the temporary memory.
CpuMatrix
tmp
(
channels
,
height
*
width
);
CpuMatrix
tmp2
(
tmp
.
getData
(),
1
,
channels
*
height
*
width
);
denoms
.
zero
();
const
int
start
=
-
((
int
)
sizeX
-
1
)
/
2
;
const
int
end
=
(
int
)
sizeX
+
start
;
for
(
size_t
i
=
0
;
i
<
numSample
;
i
++
)
{
input
.
subMatrix
(
i
,
1
)
->
square2
(
tmp2
);
CpuMatrix
subDen
(
denoms
.
subMatrix
(
i
,
1
)
->
getData
(),
channels
,
height
*
width
);
for
(
int
c
=
0
;
c
<
(
int
)
channels
;
c
++
)
{
for
(
int
s
=
start
;
s
<
end
;
s
++
)
{
if
(
c
+
s
>=
0
&&
c
+
s
<
(
int
)
channels
)
{
subDen
.
subMatrix
(
c
,
1
)
->
add
(
*
tmp
.
subMatrix
(
c
+
s
,
1
));
}
for
(
size_t
k
=
0
;
k
<
channels
;
k
+=
1
)
{
real
a
=
integralData
[
k
+
sizeX
]
-
integralData
[
k
];
a
=
scale
*
a
+
1
;
targetPtr
[
k
*
imgPixels
]
=
imgPtr
[
k
*
imgPixels
]
*
_pow
(
a
,
-
pow
);
}
diffPtr
++
;
targetPtr
++
;
imgPtr
++
;
}
}
}
free
(
integralData
);
integralData
=
NULL
;
denoms
.
add
(
scale
,
(
real
)
1
);
this
->
pow2
(
denoms
,
-
pow
);
this
->
dotMul
(
input
);
}
void
CpuMatrix
::
crossMapNormalBwd
(
Matrix
&
localGrad
,
...
...
@@ -2282,19 +2273,63 @@ void CpuMatrix::crossMapNormalBwd(Matrix& localGrad,
size_t
channels
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
size
,
size_t
size
X
,
float
scale
,
float
pow
)
{
LOG
(
FATAL
)
<<
"Not implemented"
;
CHECK
(
imgSizeH
*
imgSizeW
*
channels
==
preOutV
.
getWidth
());
CHECK
(
denoms
.
getHeight
()
==
preOutV
.
getHeight
()
&&
denoms
.
getWidth
()
==
preOutV
.
getWidth
()
&&
preOutV
.
getHeight
()
==
height_
&&
preOutV
.
getWidth
()
==
width_
);
CHECK
(
denoms
.
getHeight
()
==
localGrad
.
getHeight
()
&&
denoms
.
getWidth
()
==
localGrad
.
getWidth
());
// NOLINT // TODO:
CHECK
(
isContiguous
());
CHECK
(
localGrad
.
isContiguous
());
CHECK
(
denoms
.
isContiguous
());
CHECK
(
preOutV
.
isContiguous
());
CHECK
(
localOutV
.
isContiguous
());
CHECK_EQ
(
getHeight
(),
localGrad
.
getHeight
());
CHECK_EQ
(
getWidth
(),
localGrad
.
getWidth
());
CHECK_EQ
(
getHeight
(),
denoms
.
getHeight
());
CHECK_EQ
(
getWidth
(),
denoms
.
getWidth
());
CHECK_EQ
(
getHeight
(),
preOutV
.
getHeight
());
CHECK_EQ
(
getWidth
(),
preOutV
.
getWidth
());
CHECK_EQ
(
getHeight
(),
localOutV
.
getHeight
());
CHECK_EQ
(
getWidth
(),
localOutV
.
getWidth
());
size_t
numSample
=
getHeight
();
size_t
numCols
=
getWidth
();
size_t
height
=
imgSizeH
;
size_t
width
=
imgSizeW
;
CHECK
(
height
*
width
*
channels
==
numCols
);
// TODO(hedaoyuan) After commit TensorExpress code,
// Reconstruction this code to remove the temporary memory.
CpuMatrix
tmp
(
1
,
height
*
width
);
const
int
start
=
-
((
int
)
sizeX
)
/
2
;
const
int
end
=
(
int
)
sizeX
+
start
;
const
real
ratio
=
-
(
real
)
2
*
scale
*
pow
;
for
(
size_t
i
=
0
;
i
<
numSample
;
i
++
)
{
CpuMatrix
inputDiff
(
this
->
subMatrix
(
i
,
1
)
->
getData
(),
channels
,
height
*
width
);
CpuMatrix
outDiff
(
localGrad
.
subMatrix
(
i
,
1
)
->
getData
(),
channels
,
height
*
width
);
CpuMatrix
input
(
preOutV
.
subMatrix
(
i
,
1
)
->
getData
(),
channels
,
height
*
width
);
CpuMatrix
output
(
localOutV
.
subMatrix
(
i
,
1
)
->
getData
(),
channels
,
height
*
width
);
CpuMatrix
subDen
(
denoms
.
subMatrix
(
i
,
1
)
->
getData
(),
channels
,
height
*
width
);
for
(
int
c
=
0
;
c
<
(
int
)
channels
;
c
++
)
{
tmp
.
pow2
(
*
subDen
.
subMatrix
(
c
,
1
),
-
pow
);
inputDiff
.
subMatrix
(
c
,
1
)
->
addDotMul
(
tmp
,
*
outDiff
.
subMatrix
(
c
,
1
),
(
real
)
1
,
(
real
)
1
);
for
(
int
s
=
start
;
s
<
end
;
s
++
)
{
if
(
c
+
s
>=
0
&&
c
+
s
<
(
int
)
channels
)
{
tmp
.
dotMul
(
*
outDiff
.
subMatrix
(
c
+
s
,
1
),
*
output
.
subMatrix
(
c
+
s
,
1
));
tmp
.
mulScalar
(
ratio
);
tmp
.
dotDiv
(
tmp
,
*
subDen
.
subMatrix
(
c
+
s
,
1
));
tmp
.
dotMul
(
*
input
.
subMatrix
(
c
,
1
));
inputDiff
.
subMatrix
(
c
,
1
)
->
add
(
tmp
);
}
}
}
}
}
/**
...
...
paddle/math/tests/test_matrixCompare.cpp
浏览文件 @
529f24c2
...
...
@@ -1261,6 +1261,121 @@ TEST(Matrix, MaxOutFwdBwd) {
}
}
}
void
testCrossMapNormalFwd
(
int
numSamples
,
int
channels
,
int
imgSizeH
,
int
imgSizeW
,
int
sizeX
)
{
float
scale
=
1.5
;
float
pow
=
0.5
;
int
width
=
imgSizeH
*
imgSizeW
*
channels
;
MatrixPtr
input
=
CpuMatrix
::
create
(
numSamples
,
width
,
false
,
false
);
MatrixPtr
denorms
=
CpuMatrix
::
create
(
numSamples
,
width
,
false
,
false
);
MatrixPtr
target
=
CpuMatrix
::
create
(
numSamples
,
width
,
false
,
false
);
MatrixPtr
inputGpu
=
GpuMatrix
::
create
(
numSamples
,
width
,
false
,
true
);
MatrixPtr
denormsGpu
=
GpuMatrix
::
create
(
numSamples
,
width
,
false
,
true
);
MatrixPtr
targetGpu
=
GpuMatrix
::
create
(
numSamples
,
width
,
false
,
true
);
input
->
randomizeUniform
();
target
->
randomizeUniform
();
inputGpu
->
copyFrom
(
*
input
);
targetGpu
->
copyFrom
(
*
target
);
target
->
crossMapNormalFwd
(
*
input
,
imgSizeH
,
imgSizeW
,
*
denorms
,
channels
,
sizeX
,
scale
,
pow
);
targetGpu
->
crossMapNormalFwd
(
*
inputGpu
,
imgSizeH
,
imgSizeW
,
*
denormsGpu
,
channels
,
sizeX
,
scale
,
pow
);
TensorCheckErr
(
*
target
,
*
targetGpu
);
TensorCheckErr
(
*
denorms
,
*
denormsGpu
);
}
TEST
(
Matrix
,
crossMapNormalFwd
)
{
for
(
auto
numSamples
:
{
5
,
32
})
{
for
(
auto
channels
:
{
1
,
5
,
32
})
{
for
(
auto
imgSizeH
:
{
5
,
33
,
100
})
{
for
(
auto
imgSizeW
:
{
5
,
32
,
96
})
{
for
(
auto
sizeX
:
{
1
,
2
,
3
,
5
,
7
})
{
VLOG
(
3
)
<<
" numSamples="
<<
numSamples
<<
" channels="
<<
channels
<<
" imgSizeH="
<<
imgSizeH
<<
" imgSizeW="
<<
imgSizeW
<<
" sizeX="
<<
sizeX
;
testCrossMapNormalFwd
(
numSamples
,
channels
,
imgSizeH
,
imgSizeW
,
sizeX
);
}
}
}
}
}
}
void
testCrossMapNormalBwd
(
int
numSamples
,
int
channels
,
int
imgSizeH
,
int
imgSizeW
,
int
sizeX
)
{
float
scale
=
1.5
;
float
pow
=
0.5
;
size_t
width
=
imgSizeH
*
imgSizeW
*
channels
;
MatrixPtr
localGrad
=
CpuMatrix
::
create
(
numSamples
,
width
,
false
,
false
);
MatrixPtr
denoms
=
CpuMatrix
::
create
(
numSamples
,
width
,
false
,
false
);
MatrixPtr
output
=
CpuMatrix
::
create
(
numSamples
,
width
,
false
,
false
);
MatrixPtr
preOutV
=
CpuMatrix
::
create
(
numSamples
,
width
,
false
,
false
);
MatrixPtr
localOutV
=
CpuMatrix
::
create
(
numSamples
,
width
,
false
,
false
);
localGrad
->
randomizeUniform
();
denoms
->
randomizeUniform
();
preOutV
->
randomizeUniform
();
localOutV
->
randomizeUniform
();
output
->
randomizeUniform
();
denoms
->
add
(
0.01
);
MatrixPtr
localGradGpu
=
GpuMatrix
::
create
(
numSamples
,
width
,
false
,
true
);
MatrixPtr
denomsGpu
=
GpuMatrix
::
create
(
numSamples
,
width
,
false
,
true
);
MatrixPtr
outputGpu
=
GpuMatrix
::
create
(
numSamples
,
width
,
false
,
true
);
MatrixPtr
preOutVGpu
=
GpuMatrix
::
create
(
numSamples
,
width
,
false
,
true
);
MatrixPtr
localOutVGpu
=
GpuMatrix
::
create
(
numSamples
,
width
,
false
,
true
);
localGradGpu
->
copyFrom
(
*
localGrad
);
denomsGpu
->
copyFrom
(
*
denoms
);
preOutVGpu
->
copyFrom
(
*
preOutV
);
localOutVGpu
->
copyFrom
(
*
localOutV
);
outputGpu
->
copyFrom
(
*
output
);
output
->
crossMapNormalBwd
(
*
localGrad
,
*
denoms
,
*
preOutV
,
*
localOutV
,
channels
,
imgSizeH
,
imgSizeW
,
sizeX
,
scale
,
pow
);
outputGpu
->
crossMapNormalBwd
(
*
localGradGpu
,
*
denomsGpu
,
*
preOutVGpu
,
*
localOutVGpu
,
channels
,
imgSizeH
,
imgSizeW
,
sizeX
,
scale
,
pow
);
TensorCheckErr
(
*
output
,
*
outputGpu
);
}
TEST
(
Matrix
,
crossMapNormalBwd
)
{
for
(
auto
numSamples
:
{
5
,
32
})
{
for
(
auto
channels
:
{
1
,
5
,
32
})
{
for
(
auto
imgSizeH
:
{
5
,
33
,
100
})
{
for
(
auto
imgSizeW
:
{
5
,
32
,
96
})
{
for
(
auto
sizeX
:
{
1
,
2
,
3
,
5
,
7
})
{
VLOG
(
3
)
<<
" numSamples="
<<
numSamples
<<
" channels="
<<
channels
<<
" imgSizeH="
<<
imgSizeH
<<
" imgSizeW="
<<
imgSizeW
<<
" sizeX="
<<
sizeX
;
testCrossMapNormalBwd
(
numSamples
,
channels
,
imgSizeH
,
imgSizeW
,
sizeX
);
}
}
}
}
}
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录