Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
903d5c7e
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
903d5c7e
编写于
9月 07, 2016
作者:
H
He
提交者:
Yu Yang
9月 08, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bug fix for hl_matrix_classification_error
上级
fdd40e55
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
43 addition
and
23 deletion
+43
-23
paddle/cuda/src/hl_cuda_matrix.cu
paddle/cuda/src/hl_cuda_matrix.cu
+13
-20
paddle/math/tests/test_matrixCompare.cpp
paddle/math/tests/test_matrixCompare.cpp
+30
-3
未找到文件。
paddle/cuda/src/hl_cuda_matrix.cu
浏览文件 @
903d5c7e
...
...
@@ -266,25 +266,21 @@ template<int blockSize>
__global__
void
KeMatrixClassificationError
(
real
*
in_A
,
int
*
in_B
,
real
*
out_C
,
int
dimM
,
int
dimN
)
{
__shared__
real
max_s
[
blockSize
];
__shared__
int
max_l
[
blockSize
];
int
cnt
=
(
dimN
+
blockSize
-
1
)
/
blockSize
;
int
tid
=
threadIdx
.
x
;
int
lmt
=
tid
;
int
index
=
0
;
real
t
;
const
int
tid
=
threadIdx
.
x
;
const
int
rowId
=
blockIdx
.
x
;
max_s
[
tid
]
=
-
1e30
f
;
for
(
int
ii
=
0
;
ii
<
cnt
&&
lmt
<
dimN
;
ii
++
)
{
index
=
blockIdx
.
y
*
dimN
+
lmt
;
t
=
in_A
[
index
];
if
(
max_s
[
tid
]
<
t
)
{
max_s
[
tid
]
=
t
;
max_l
[
tid
]
=
lmt
;
in_A
+=
rowId
*
dimN
;
real
tmp
;
for
(
int
colId
=
tid
;
colId
<
dimN
;
colId
+=
blockSize
)
{
tmp
=
in_A
[
colId
];
if
(
max_s
[
tid
]
<
tmp
)
{
max_s
[
tid
]
=
tmp
;
max_l
[
tid
]
=
colId
;
}
lmt
+=
blockSize
;
}
__syncthreads
();
...
...
@@ -300,7 +296,7 @@ __global__ void KeMatrixClassificationError(real* in_A,
__syncthreads
();
if
(
tid
==
0
)
{
out_C
[
blockIdx
.
y
]
=
(
max_l
[
0
]
==
in_B
[
blockIdx
.
y
]
?
0
:
1.0
f
);
out_C
[
rowId
]
=
(
max_l
[
0
]
==
in_B
[
rowId
]
?
0
:
1.0
f
);
}
}
...
...
@@ -313,12 +309,9 @@ void hl_matrix_classification_error(real* A_d,
CHECK_NOTNULL
(
B_d
);
CHECK_NOTNULL
(
C_d
);
int
blocksX
=
1
;
int
blocksY
=
dimM
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blocksX
,
blocksY
);
KeMatrixClassificationError
<
1024
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
A_d
,
B_d
,
C_d
,
dimM
,
dimN
);
// each sample is calculated by one block
KeMatrixClassificationError
<
1024
><<<
dimM
,
1024
,
0
,
STREAM_DEFAULT
>>>
(
A_d
,
B_d
,
C_d
,
dimN
);
CHECK_SYNC
(
"hl_matrix_classification_error"
);
}
...
...
paddle/math/tests/test_matrixCompare.cpp
浏览文件 @
903d5c7e
...
...
@@ -1697,7 +1697,6 @@ TEST(Matrix, cosSimDerivate) {
}
}
void
testParamReluForward
(
int
height
,
int
width
,
int
w_height
,
int
w_width
)
{
MatrixPtr
output
=
CpuMatrix
::
create
(
height
,
width
,
false
,
false
);
...
...
@@ -1736,7 +1735,6 @@ TEST(Matrix, paramReluForward) {
}
}
void
testParamReluBackwardW
(
int
height
,
int
width
,
int
w_height
,
int
w_width
)
{
MatrixPtr
oGrad
=
CpuMatrix
::
create
(
height
,
width
,
false
,
false
);
...
...
@@ -1775,7 +1773,6 @@ TEST(Matrix, paramReluBackwardW) {
}
}
void
testParamReluBackwardDiff
(
int
height
,
int
width
,
int
w_height
,
int
w_width
)
{
MatrixPtr
oGrad
=
CpuMatrix
::
create
(
height
,
width
,
false
,
false
);
...
...
@@ -1819,6 +1816,36 @@ TEST(Matrix, paramReluBackwardDiff) {
}
}
void
testClassificationError
(
int
numSamples
,
int
dim
)
{
MatrixPtr
cpuError
=
std
::
make_shared
<
CpuMatrix
>
(
numSamples
,
1
);
MatrixPtr
gpuError
=
std
::
make_shared
<
GpuMatrix
>
(
numSamples
,
1
);
MatrixPtr
cpuOutput
=
std
::
make_shared
<
CpuMatrix
>
(
numSamples
,
dim
);
MatrixPtr
gpuOutput
=
std
::
make_shared
<
GpuMatrix
>
(
numSamples
,
dim
);
IVectorPtr
cpuLabel
=
std
::
make_shared
<
CpuIVector
>
(
numSamples
);
IVectorPtr
gpuLabel
=
std
::
make_shared
<
GpuIVector
>
(
numSamples
);
cpuOutput
->
randomizeUniform
();
cpuLabel
->
rand
(
dim
);
gpuOutput
->
copyFrom
(
*
cpuOutput
);
gpuLabel
->
copyFrom
(
*
cpuLabel
);
cpuError
->
classificationError
(
cpuOutput
,
cpuLabel
);
gpuError
->
classificationError
(
gpuOutput
,
gpuLabel
);
MatrixPtr
check
=
std
::
make_shared
<
CpuMatrix
>
(
numSamples
,
1
);
check
->
copyFrom
(
*
gpuError
);
MatrixCheckEqual
(
*
cpuError
,
*
check
);
}
TEST
(
Matrix
,
classificationError
)
{
for
(
auto
numSamples
:
{
1
,
10
,
100
,
1000
,
70000
})
{
for
(
auto
dim
:
{
1
,
10
,
100
,
1000
})
{
VLOG
(
3
)
<<
" numSamples="
<<
numSamples
<<
" dim="
<<
dim
;
testClassificationError
(
numSamples
,
dim
);
}
}
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
initMain
(
argc
,
argv
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录