Commit 8fded24c
Authored Feb 21, 2017 by Liang Zhao

implement top k classification error in class matrix

Parent: d2565128

Showing 8 changed files with 168 additions and 146 deletions (+168 −146)
paddle/cuda/include/hl_matrix.h           +0   −13
paddle/cuda/include/hl_top_k.h            +27  −1
paddle/cuda/src/hl_cuda_matrix.cu         +0   −53
paddle/cuda/src/hl_top_k.cu               +78  −0
paddle/gserver/evaluators/Evaluator.cpp   +6   −46
paddle/gserver/layers/Layer.h             +1   −0
paddle/math/Matrix.cpp                    +50  −30
paddle/math/Matrix.h                      +6   −3
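This commit moves top-k classification error out of the evaluator's ad-hoc CPU loop and into the Matrix/CUDA layers. As a baseline for what the new classificationError(output, label, topkSize) computes, here is a small reference sketch (illustrative only, not part of the commit): error[i] is 0 exactly when label[i] is among the k highest-scoring columns of row i, and 1 otherwise, with k clamped to the row width just as the CUDA host code below does.

#include <algorithm>
#include <vector>

// Reference sketch (not in the commit): error[i] = 0 iff label[i] is among
// the k largest entries of row i of output (numSamples x dim), else 1.
void topkClassificationErrorRef(const std::vector<float>& output,
                                const std::vector<int>& label,
                                std::vector<float>& error,
                                size_t numSamples, size_t dim, size_t k) {
  k = std::min(k, dim);  // clamp, matching the CUDA host code below
  for (size_t i = 0; i < numSamples; ++i) {
    std::vector<int> ids(dim);
    for (size_t j = 0; j < dim; ++j) ids[j] = static_cast<int>(j);
    // Move the indices of the k largest row entries to the front.
    std::partial_sort(ids.begin(), ids.begin() + k, ids.end(),
                      [&](int a, int b) {
                        return output[i * dim + a] > output[i * dim + b];
                      });
    bool hit = std::find(ids.begin(), ids.begin() + k,
                         label[i]) != ids.begin() + k;
    error[i] = hit ? 0.0f : 1.0f;
  }
}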
paddle/cuda/include/hl_matrix.h

@@ -69,19 +69,6 @@ extern void hl_sequence_softmax_forward(real* A_d,
                                         const int* index,
                                         int numSequence);

-/**
- * @brief Matrix classification error.
- *
- * @param[in]  A_d    input matrix (M x N).
- * @param[in]  B_d    input vector (M x 1).
- * @param[out] C_d    output vector (M x 1).
- * @param[in]  dimM   matrix height.
- * @param[in]  dimN   matrix width.
- *
- */
-extern void hl_matrix_classification_error(
-    real* A_d, int* B_d, real* C_d, int dimM, int dimN);
-
 /**
  * @brief Matrix cross entropy.
  *
...
paddle/cuda/include/hl_top_k.h

@@ -58,4 +58,30 @@ extern void hl_sparse_matrix_top_k(real* topVal,
                                    int beamSize,
                                    int numSamples);

+/**
+ * @brief Matrix classification error.
+ *
+ * @param[out] topVal      top k element.
+ * @param[in]  ldv         leading dimension of topVal.
+ * @param[out] topIds      top k index.
+ * @param[in]  src         input value.
+ * @param[in]  lds         leading dimension of src.
+ * @param[in]  dim         width of input value.
+ * @param[in]  topkSize    size of top k element.
+ * @param[in]  numSamples  height of input value.
+ * @param[in]  label       ground truth label.
+ * @param[out] recResult   top-k classification error.
+ *
+ */
+extern void hl_matrix_classification_error(real* topVal,
+                                           int ldv,
+                                           int* topIds,
+                                           real* src,
+                                           int lds,
+                                           int dim,
+                                           int topkSize,
+                                           int numSamples,
+                                           int* label,
+                                           real* recResult);
+
-#endif /* HL_TOP_K_H_ */
+#endif  // HL_TOP_K_H_
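A direct host-side call of this new entry point might look like the sketch below (a hypothetical helper, not part of the commit; the real caller is GpuMatrix::classificationError in paddle/math/Matrix.cpp, shown later on this page). The cudaMalloc scratch-buffer management is an assumption for illustration; rows are assumed densely packed, so both leading dimensions equal the row widths.

// Hypothetical wrapper: per-sample top-k error for device buffers.
// d_src:   numSamples x dim scores, d_label: numSamples ground-truth ids,
// d_err:   numSamples output flags (0 = correct, 1 = error).
void computeTopkErrorOnGpu(real* d_src, int* d_label, real* d_err,
                           int numSamples, int dim, int topk) {
  real* d_topVal = nullptr;
  int* d_topIds = nullptr;
  cudaMalloc((void**)&d_topVal, sizeof(real) * numSamples * topk);
  cudaMalloc((void**)&d_topIds, sizeof(int) * numSamples * topk);
  hl_matrix_classification_error(d_topVal, /* ldv */ topk, d_topIds,
                                 d_src, /* lds */ dim, dim,
                                 /* topkSize */ topk, numSamples,
                                 d_label, d_err);
  cudaFree(d_topVal);
  cudaFree(d_topIds);
}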
paddle/cuda/src/hl_cuda_matrix.cu

@@ -265,59 +265,6 @@ void hl_matrix_softmax_derivative(real *grad_d,
   CHECK_SYNC("hl_matrix_softmax_derivative failed");
 }

-template <int blockSize>
-__global__ void KeMatrixClassificationError(real* in_A,
-                                            int* in_B,
-                                            real* out_C,
-                                            int dimN) {
-  __shared__ real max_s[blockSize];
-  __shared__ int max_l[blockSize];
-  const int tid = threadIdx.x;
-  const int rowId = blockIdx.x;
-
-  max_s[tid] = -1e30f;
-  in_A += rowId * dimN;
-  real tmp;
-  for (int colId = tid; colId < dimN; colId += blockSize) {
-    tmp = in_A[colId];
-    if (max_s[tid] < tmp) {
-      max_s[tid] = tmp;
-      max_l[tid] = colId;
-    }
-  }
-  __syncthreads();
-
-  for (int stride = blockSize / 2; stride > 0; stride = stride / 2) {
-    if (tid < stride) {
-      if (max_s[tid] < max_s[tid + stride]) {
-        max_s[tid] = max_s[tid + stride];
-        max_l[tid] = max_l[tid + stride];
-      }
-    }
-    __syncthreads();
-  }
-  __syncthreads();
-
-  if (tid == 0) {
-    out_C[rowId] = (max_l[0] == in_B[rowId] ? 0 : 1.0f);
-  }
-}
-
-void hl_matrix_classification_error(real* A_d,
-                                    int* B_d,
-                                    real* C_d,
-                                    int dimM,
-                                    int dimN) {
-  CHECK_NOTNULL(A_d);
-  CHECK_NOTNULL(B_d);
-  CHECK_NOTNULL(C_d);
-
-  // each sample is calculated by one block
-  KeMatrixClassificationError<1024><<<dimM, 1024, 0, STREAM_DEFAULT>>>(
-      A_d, B_d, C_d, dimN);
-  CHECK_SYNC("hl_matrix_classification_error");
-}
-
 __global__ void KeMatrixMultiBinaryCrossEntropy(real* output,
                                                 real* entropy,
                                                 int* row,
...
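For orientation: the kernel removed above computed the k = 1 special case, a per-row argmax via a shared-memory tree reduction, then compared the winning index against the label. A scalar equivalent of what it produced per row (a sketch, not in the commit):

// Scalar equivalent of the removed top-1 kernel, for one row of width dim:
// returns 0 when the argmax column equals the label, 1 otherwise.
float top1ErrorRow(const real* row, int dim, int label) {
  int argmax = 0;
  for (int j = 1; j < dim; ++j) {
    if (row[j] > row[argmax]) argmax = j;
  }
  return argmax == label ? 0.0f : 1.0f;
}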
paddle/cuda/src/hl_top_k.cu

@@ -384,3 +384,81 @@ void hl_sparse_matrix_top_k(real* topVal, int ldv,
   CHECK_SYNC("hl_sparse_matrix_top_k failed");
 }
+
+/**
+ * Each block computes one sample.
+ * In a block:
+ * 1. every thread gets a local top-maxLength list;
+ * 2. merge into shTopK, block reduce and get the max value;
+ * 3. repeat the second step until one thread's topK list is empty;
+ * 4. repeat the first step until the full top-k set is obtained.
+ */
+template <int maxLength, int blockSize>
+__global__ void KeMatrixTopKClassificationError(real* topVal,
+                                                int ldv,
+                                                int* topIds,
+                                                real* src,
+                                                int lds,
+                                                int dim,
+                                                int beamSize,
+                                                int* label,
+                                                real* recResult) {
+  __shared__ Pair shTopK[blockSize];
+  __shared__ int maxId[blockSize / 2];
+  const int tid = threadIdx.x;
+  const int warp = threadIdx.x / 32;
+  src += blockIdx.x * lds;
+  topVal += blockIdx.x * ldv;
+  topIds += blockIdx.x * beamSize;
+
+  Pair topK[maxLength];  // NOLINT
+  int beam = maxLength;
+  Pair max;
+  bool isEmpty = false;
+  bool firstStep = true;
+  int topkSize = beamSize;
+
+  for (int k = 0; k < maxLength; k++) {
+    topK[k].set(-HL_FLOAT_MAX, -1);
+  }
+
+  while (beamSize) {
+    threadGetTopK<maxLength, blockSize>(
+        topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);
+
+    shTopK[tid] = topK[0];
+    blockReduce<maxLength, blockSize>(
+        shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
+  }
+
+  __syncthreads();
+  if (tid == 0) {
+    for (int i = 0; i < topkSize; i++) {
+      if (*--topIds == label[blockIdx.x]) {
+        recResult[blockIdx.x] = 0;
+        break;
+      }
+      recResult[blockIdx.x] = 1.0f;
+    }
+  }
+}
+
+void hl_matrix_classification_error(real* topVal,
+                                    int ldv,
+                                    int* topIds,
+                                    real* src,
+                                    int lds,
+                                    int dim,
+                                    int topkSize,
+                                    int numSamples,
+                                    int* label,
+                                    real* recResult) {
+  CHECK_NOTNULL(topVal);
+  CHECK_NOTNULL(topIds);
+  CHECK_NOTNULL(src);
+
+  if (topkSize > dim) topkSize = dim;
+
+  dim3 threads(256, 1);
+  dim3 grid(numSamples, 1);
+  KeMatrixTopKClassificationError<5, 256><<<grid, threads, 0, STREAM_DEFAULT>>>(
+      topVal, ldv, topIds, src, lds, dim, topkSize, label, recResult);
+
+  CHECK_SYNC("hl_matrix_top_k classification error failed");
+}
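Two details are worth noting in the code above. The host side instantiates KeMatrixTopKClassificationError<5, 256>, so each of the 256 threads keeps at most 5 local candidates, and topkSize is clamped to dim before launch. And in the final tid == 0 loop, blockReduce has already advanced topIds past the k indices it wrote, so *--topIds walks them backwards. A host-side mirror of that final check (a sketch, not in the commit; assumes k >= 1 as in the kernel):

// Mirror of the kernel's final check for one sample: idsEnd points one past
// the k written top-k indices; the result is 0 iff the label is among them.
real topkErrorFromIds(const int* idsEnd, int k, int label) {
  real err = 1.0f;
  for (int i = 0; i < k; ++i) {
    if (*--idsEnd == label) {
      err = 0.0f;
      break;
    }
  }
  return err;
}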
paddle/gserver/evaluators/Evaluator.cpp

@@ -39,12 +39,13 @@ void Evaluator::eval(const NeuralNetwork& nn) {
  */
 class ClassificationErrorEvaluator : public Evaluator {
 public:
+  /*
   ClassificationErrorEvaluator() : totalScore2_(0) {}
   virtual void start() {
     Evaluator::start();
     totalScore2_ = 0;
-  }
+  } */

   virtual void updateSamplesNum(const std::vector<Argument>& arguments) {
     if (3 == arguments.size()) {
...

@@ -83,42 +84,11 @@ public:
                          1,
                          /* trans= */ false,
                          useGpu(arguments[0].deviceId));
-    const MatrixPtr errorMat2 = Matrix::create(output->getHeight(),
-                                               1,
-                                               /* trans= */ false,
-                                               false);
     errorMat->zeroMem();
     if (label != nullptr) {
-      errorMat->classificationError(*output, *label);  // top-1 error
-      if (config_.top_k() > 1) {
-        size_t height = output->getHeight();
-        size_t width = config_.top_k();
-        IVector::resizeOrCreate(
-            maxIds_, height * width, useGpu(arguments[0].deviceId));
-        Matrix::resizeOrCreate(
-            maxValues_, height, width, false, useGpu(arguments[0].deviceId));
-        output->rowMax(*maxIds_, *maxValues_);  // top-k values
-        IVectorPtr dest = IVector::create(maxIds_->getSize(), false);
-        IVectorPtr dest2 = IVector::create(label->getSize(), false);
-        dest->copyFrom(*maxIds_);
-        dest2->copyFrom(*label);
-        int* ids = dest->getData();
-        int* lbl = dest2->getData();
-        for (size_t i = 0; i < height; ++i) {
-          bool contain = false;
-          for (size_t j = 0; j < width && !contain; ++j) {
-            contain = (ids[i * width + j] == lbl[i]);
-          }
-          if (!contain) {
-            totalScore2_ += 1.0;  // update top-k error
-          }
-        }
-      }
+      errorMat->classificationError(*output, *label, config_.top_k());
     } else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) ||
                dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) {
       errorMat->classificationErrorMulti(
...

@@ -139,9 +109,8 @@ public:
     os << config_.name() << "="
        << (numSamples_ ? totalScore_ / numSamples_ : 0);
   } else {
-    os << "top_1_error=" << (numSamples_ ? totalScore_ / numSamples_ : 0)
-       << " top_" << config_.top_k() << "_error="
-       << (numSamples_ ? totalScore2_ / numSamples_ : 0);
+    os << " top_" << config_.top_k() << "_error="
+       << (numSamples_ ? totalScore_ / numSamples_ : 0);
   }
...

@@ -151,17 +120,8 @@ public:
   }

   virtual void distributeEval(ParameterClient2* client) {
-    double data[3] = {totalScore_, totalScore2_, numSamples_};
-    client->reduce(data, data, 3, FLAGS_trainer_id, 0);
-    totalScore_ = data[0];
-    totalScore2_ = data[1];
-    numSamples_ = data[2];
+    mergeResultsOfAllClients(client);
   }
-
-private:
-  IVectorPtr maxIds_;
-  MatrixPtr maxValues_;
-  double totalScore2_;
 };

 /**
...
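Taken together with the Matrix.cpp changes below, the evaluator's per-batch work now reduces to roughly the following (a paraphrase assembled from the calls in this file, not a verbatim excerpt; useGpuFlag stands in for useGpu(arguments[0].deviceId)):

// Paraphrased evaluator flow after this change: top-1 and top-k share one
// code path, selected by the configured top_k value.
MatrixPtr errorMat = Matrix::create(output->getHeight(), 1,
                                    /* trans= */ false, useGpuFlag);
errorMat->zeroMem();
errorMat->classificationError(*output, *label, config_.top_k());
// errorMat now holds one 0/1 top-k error flag per sample; the evaluator
// accumulates these into totalScore_ and reports totalScore_ / numSamples_.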
paddle/gserver/layers/Layer.h

@@ -311,6 +311,7 @@ public:
       return *output->second;
     } else {
       LOG(FATAL) << "No specific output " << str;
+      return *((Argument*)nullptr);
     }
   }
 }
...
paddle/math/Matrix.cpp

@@ -793,19 +793,32 @@ void GpuMatrix::maxoutBackward(Matrix& a,
 }

 /* calculate the error of classification */
-void GpuMatrix::classificationError(Matrix& output, IVector& label) {
-  auto output_ptr = dynamic_cast<const GpuMatrix*>(&output);
-  auto label_ptr = dynamic_cast<const GpuIVector*>(&label);
-  CHECK(output_ptr && label_ptr) << "Invalid argument pointer";
-  CHECK(height_ == output_ptr->height_ && width_ == 1)
-      << "Matrix dimensions are not equal";
+void GpuMatrix::classificationError(Matrix& output,
+                                    IVector& label,
+                                    size_t topkSize) {
+  auto gpuOutput = dynamic_cast<GpuMatrix*>(&output);
+  auto gpuLabel = dynamic_cast<GpuIVector*>(&label);
+  size_t numSamples = this->getHeight();
+  GpuMatrixPtr gpuTopVal = std::make_shared<GpuMatrix>(numSamples, topkSize);
+  GpuIVectorPtr gpuTopIds = std::make_shared<GpuIVector>(numSamples * topkSize);

-  hl_matrix_classification_error((real*)output_ptr->data_,
-                                 (int*)label_ptr->getData(),
-                                 data_,
-                                 height_,
-                                 output_ptr->width_);
+  CHECK(gpuOutput && gpuLabel) << "Invalid argument pointer";
+  CHECK(gpuTopVal && gpuTopIds) << "Allocate GPU memory failed";
+  CHECK(gpuLabel->getSize() == numSamples) << "Vector size is not equal";
+  CHECK(numSamples == gpuOutput->getHeight() && this->getWidth() == 1)
+      << "Matrix dimensions are not equal";
+
+  size_t dim = gpuOutput->getWidth();
+  hl_matrix_classification_error(gpuTopVal->getData(),
+                                 gpuTopVal->getStride(),
+                                 gpuTopIds->getData(),
+                                 gpuOutput->getData(),
+                                 gpuOutput->getStride(),
+                                 dim,
+                                 topkSize,
+                                 numSamples,
+                                 gpuLabel->getData(),
+                                 this->getData());
 }

 /* copy -log(output[i * width + label]) to this->data[i] */
...
@@ -3200,32 +3213,39 @@ void CpuMatrix::rowNormalizeL1(Matrix& out) {
 }

 /* calculate classification error */
-void CpuMatrix::classificationError(Matrix& output, IVector& label) {
-  CHECK(dynamic_cast<const CpuMatrix*>(&output));
-  CHECK(dynamic_cast<const CpuIVector*>(&label));
+void CpuMatrix::classificationError(Matrix& output,
+                                    IVector& label,
+                                    size_t topkSize) {
+  size_t numSamples = this->getHeight();
+  auto cpuOutput = dynamic_cast<CpuMatrix*>(&output);
+  auto cpuLabel = dynamic_cast<CpuIVector*>(&label);
+  IVectorPtr cpuTopIds = std::make_shared<CpuIVector>(numSamples * topkSize);
+  MatrixPtr cpuTopVal = std::make_shared<CpuMatrix>(numSamples, topkSize);

-  CHECK_EQ(getWidth(), (size_t)1);
-  size_t numSamples = getHeight();
-  CHECK_EQ(label.getSize(), numSamples);
-  CHECK_EQ(output.getHeight(), numSamples);
+  CHECK(cpuOutput && cpuLabel) << "Invalid argument pointer";
+  CHECK(cpuTopIds && cpuTopVal) << "Allocate cpu memory failed";
+  CHECK(cpuLabel->getSize() == numSamples) << "Vector size is not equal";
+  CHECK(cpuOutput->getHeight() == numSamples && this->getWidth() == 1)
+      << "Matrix dimensions are not equal";

-  size_t dim = output.getWidth();
-  real* out = output.getData();
-  int* lbl = label.getData();
-  real maxData = 0.0;
-  int maxIndex = -1;
+  // top k matrix classification
+  cpuOutput->rowMax(*cpuTopIds, *cpuTopVal);
+
+  size_t dim = cpuOutput->getWidth();
+  real* result = this->getData();
+  int* ids = cpuTopIds->getData();
+  int* lbl = cpuLabel->getData();
   for (size_t i = 0; i < numSamples; ++i) {
     CHECK_GE(lbl[i], 0);
     CHECK_LT((size_t)lbl[i], dim);
-    maxData = out[i * dim];
-    maxIndex = 0;
-    for (size_t j = 0; j < dim; ++j) {
-      if (maxData < out[i * dim + j]) {
-        maxIndex = j;
-        maxData = out[i * dim + j];
-      }
-    }
-    getData()[i] = (maxIndex != lbl[i]);
+    for (size_t j = 0; j < topkSize; ++j) {
+      if (ids[j + i * topkSize] == lbl[i]) {
+        result[i] = 0;
+        break;
+      }
+      result[i] = 1.0f;
+    }
   }
 }
...
paddle/math/Matrix.h

@@ -836,8 +836,11 @@ public:
   * output[i] = 1 if row i is an error.
   *
   * output[i] = 0 if row i is correct.
   *
   */
-  virtual void classificationError(Matrix& output, IVector& label) {
+  virtual void classificationError(Matrix& output,
+                                   IVector& label,
+                                   size_t topkSize = 1) {
     LOG(FATAL) << "Not implemented";
   }
...

@@ -1314,7 +1317,7 @@ public:
   void check(std::ostream& os, Matrix& refMat, bool printDiff = true);
   void randomizeUniform();

-  void classificationError(Matrix& output, IVector& label);
+  void classificationError(Matrix& output, IVector& label, size_t topkSize = 1);

   void convExpand(Matrix& feature,
                   int feaImgHeight,
...

@@ -1739,7 +1742,7 @@ public:
   void randomizeUniform();

-  void classificationError(Matrix& output, IVector& label);
+  void classificationError(Matrix& output, IVector& label, size_t topkSize = 1);

   void addByBitCode(size_t numClasses, const IVector& codes, const Matrix& vec);
...
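Because topkSize defaults to 1 in all three declarations, every existing call site compiles unchanged and keeps its top-1 behavior; callers opt into top-k explicitly. A minimal usage sketch (errorMat, output, and label are hypothetical handles with the shapes documented above):

// errorMat: numSamples x 1 result, output: numSamples x numClasses scores,
// label: numSamples ground-truth ids.
errorMat->classificationError(*output, *label);     // top-1, as before
errorMat->classificationError(*output, *label, 5);  // top-5 error per row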