Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f2c7c9bc
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f2c7c9bc
编写于
2月 22, 2017
作者:
W
wangkuiyi
提交者:
GitHub
2月 22, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1337 from lzhao4ever/topk-error
Add top-k error
上级
59ca13a3
046349dd
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
236 addition
and
122 deletion
+236
-122
paddle/cuda/include/hl_matrix.h
paddle/cuda/include/hl_matrix.h
+0
-13
paddle/cuda/include/hl_top_k.h
paddle/cuda/include/hl_top_k.h
+27
-1
paddle/cuda/include/stub/hl_matrix_stub.h
paddle/cuda/include/stub/hl_matrix_stub.h
+10
-2
paddle/cuda/src/hl_cuda_matrix.cu
paddle/cuda/src/hl_cuda_matrix.cu
+0
-53
paddle/cuda/src/hl_top_k.cu
paddle/cuda/src/hl_top_k.cu
+78
-0
paddle/gserver/evaluators/Evaluator.cpp
paddle/gserver/evaluators/Evaluator.cpp
+21
-1
paddle/gserver/layers/Layer.h
paddle/gserver/layers/Layer.h
+1
-0
paddle/gserver/tests/test_Evaluator.cpp
paddle/gserver/tests/test_Evaluator.cpp
+1
-0
paddle/math/Matrix.cpp
paddle/math/Matrix.cpp
+53
-31
paddle/math/Matrix.h
paddle/math/Matrix.h
+6
-3
paddle/math/tests/test_matrixCompare.cpp
paddle/math/tests/test_matrixCompare.cpp
+12
-7
paddle/parameter/Parameter.cpp
paddle/parameter/Parameter.cpp
+0
-4
proto/ModelConfig.proto
proto/ModelConfig.proto
+4
-0
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+3
-0
python/paddle/trainer_config_helpers/evaluators.py
python/paddle/trainer_config_helpers/evaluators.py
+10
-0
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+10
-7
未找到文件。
paddle/cuda/include/hl_matrix.h
浏览文件 @
f2c7c9bc
...
@@ -69,19 +69,6 @@ extern void hl_sequence_softmax_forward(real* A_d,
...
@@ -69,19 +69,6 @@ extern void hl_sequence_softmax_forward(real* A_d,
const
int
*
index
,
const
int
*
index
,
int
numSequence
);
int
numSequence
);
/**
* @brief Matrix classification error.
*
* @param[in] A_d input matrix (M x N).
* @param[in] B_d input vector (M x 1).
* @param[out] C_d output vector (M x 1).
* @param[in] dimM matrix height.
* @param[in] dimN matrix width.
*
*/
extern
void
hl_matrix_classification_error
(
real
*
A_d
,
int
*
B_d
,
real
*
C_d
,
int
dimM
,
int
dimN
);
/**
/**
* @brief Matrix cross entropy.
* @brief Matrix cross entropy.
*
*
...
...
paddle/cuda/include/hl_top_k.h
浏览文件 @
f2c7c9bc
...
@@ -58,4 +58,30 @@ extern void hl_sparse_matrix_top_k(real* topVal,
...
@@ -58,4 +58,30 @@ extern void hl_sparse_matrix_top_k(real* topVal,
int
beamSize
,
int
beamSize
,
int
numSamples
);
int
numSamples
);
#endif
/* HL_TOP_K_H_ */
/**
* @brief Matrix classification error.
*
* @param[out] topVal top k element.
* @param[in] ldv leading dimension of topVal.
* @param[out] topIds top k index.
* @param[in] src input value.
* @param[in] lds leading dimension of src.
* @param[in] dim width of input value.
* @param[in] topkSize size of top k element.
* @param[in] numSamples height of input value.
* @param[in] label ground truth label.
* @param[out] recResult top-k classification error.
*
*/
extern
void
hl_matrix_classification_error
(
real
*
topVal
,
int
ldv
,
int
*
topIds
,
real
*
src
,
int
lds
,
int
dim
,
int
topkSize
,
int
numSamples
,
int
*
label
,
real
*
recResult
);
#endif // HL_TOP_K_H_
paddle/cuda/include/stub/hl_matrix_stub.h
浏览文件 @
f2c7c9bc
...
@@ -35,8 +35,16 @@ inline void hl_sequence_softmax_forward(real* A_d,
...
@@ -35,8 +35,16 @@ inline void hl_sequence_softmax_forward(real* A_d,
inline
void
hl_matrix_softmax_derivative
(
inline
void
hl_matrix_softmax_derivative
(
real
*
grad_d
,
real
*
output_d
,
real
*
sftmaxSum_d
,
int
dimM
,
int
dimN
)
{}
real
*
grad_d
,
real
*
output_d
,
real
*
sftmaxSum_d
,
int
dimM
,
int
dimN
)
{}
inline
void
hl_matrix_classification_error
(
inline
void
hl_matrix_classification_error
(
real
*
topVal
,
real
*
A_d
,
int
*
B_d
,
real
*
C_d
,
int
dimM
,
int
dimN
)
{}
int
ldv
,
int
*
topIds
,
real
*
src
,
int
lds
,
int
dim
,
int
topkSize
,
int
numSamples
,
int
*
label
,
real
*
recResult
)
{}
inline
void
hl_matrix_cross_entropy
(
inline
void
hl_matrix_cross_entropy
(
real
*
A_d
,
real
*
C_d
,
int
*
label_d
,
int
dimM
,
int
dimN
)
{}
real
*
A_d
,
real
*
C_d
,
int
*
label_d
,
int
dimM
,
int
dimN
)
{}
...
...
paddle/cuda/src/hl_cuda_matrix.cu
浏览文件 @
f2c7c9bc
...
@@ -265,59 +265,6 @@ void hl_matrix_softmax_derivative(real *grad_d,
...
@@ -265,59 +265,6 @@ void hl_matrix_softmax_derivative(real *grad_d,
CHECK_SYNC
(
"hl_matrix_softmax_derivative failed"
);
CHECK_SYNC
(
"hl_matrix_softmax_derivative failed"
);
}
}
template
<
int
blockSize
>
__global__
void
KeMatrixClassificationError
(
real
*
in_A
,
int
*
in_B
,
real
*
out_C
,
int
dimN
)
{
__shared__
real
max_s
[
blockSize
];
__shared__
int
max_l
[
blockSize
];
const
int
tid
=
threadIdx
.
x
;
const
int
rowId
=
blockIdx
.
x
;
max_s
[
tid
]
=
-
1e30
f
;
in_A
+=
rowId
*
dimN
;
real
tmp
;
for
(
int
colId
=
tid
;
colId
<
dimN
;
colId
+=
blockSize
)
{
tmp
=
in_A
[
colId
];
if
(
max_s
[
tid
]
<
tmp
)
{
max_s
[
tid
]
=
tmp
;
max_l
[
tid
]
=
colId
;
}
}
__syncthreads
();
for
(
int
stride
=
blockSize
/
2
;
stride
>
0
;
stride
=
stride
/
2
)
{
if
(
tid
<
stride
)
{
if
(
max_s
[
tid
]
<
max_s
[
tid
+
stride
])
{
max_s
[
tid
]
=
max_s
[
tid
+
stride
];
max_l
[
tid
]
=
max_l
[
tid
+
stride
];
}
}
__syncthreads
();
}
__syncthreads
();
if
(
tid
==
0
)
{
out_C
[
rowId
]
=
(
max_l
[
0
]
==
in_B
[
rowId
]
?
0
:
1.0
f
);
}
}
void
hl_matrix_classification_error
(
real
*
A_d
,
int
*
B_d
,
real
*
C_d
,
int
dimM
,
int
dimN
)
{
CHECK_NOTNULL
(
A_d
);
CHECK_NOTNULL
(
B_d
);
CHECK_NOTNULL
(
C_d
);
// each sample is calculated by one block
KeMatrixClassificationError
<
1024
><<<
dimM
,
1024
,
0
,
STREAM_DEFAULT
>>>
(
A_d
,
B_d
,
C_d
,
dimN
);
CHECK_SYNC
(
"hl_matrix_classification_error"
);
}
__global__
void
KeMatrixMultiBinaryCrossEntropy
(
real
*
output
,
__global__
void
KeMatrixMultiBinaryCrossEntropy
(
real
*
output
,
real
*
entropy
,
real
*
entropy
,
int
*
row
,
int
*
row
,
...
...
paddle/cuda/src/hl_top_k.cu
浏览文件 @
f2c7c9bc
...
@@ -384,3 +384,81 @@ void hl_sparse_matrix_top_k(real* topVal, int ldv,
...
@@ -384,3 +384,81 @@ void hl_sparse_matrix_top_k(real* topVal, int ldv,
CHECK_SYNC
(
"hl_sparse_matrix_top_k failed"
);
CHECK_SYNC
(
"hl_sparse_matrix_top_k failed"
);
}
}
/**
* Each block compute one sample.
* In a block:
* 1. every thread get top maxLength value;
* 2. merge to shTopK, block reduce and get max value;
* 3. go to the second setp, until one thread's topK value is null;
* 4. go to the first setp, until get the topK value.
*/
template
<
int
maxLength
,
int
blockSize
>
__global__
void
KeMatrixTopKClassificationError
(
real
*
topVal
,
int
ldv
,
int
*
topIds
,
real
*
src
,
int
lds
,
int
dim
,
int
beamSize
,
int
*
label
,
real
*
recResult
)
{
__shared__
Pair
shTopK
[
blockSize
];
__shared__
int
maxId
[
blockSize
/
2
];
const
int
tid
=
threadIdx
.
x
;
const
int
warp
=
threadIdx
.
x
/
32
;
src
+=
blockIdx
.
x
*
lds
;
topVal
+=
blockIdx
.
x
*
ldv
;
topIds
+=
blockIdx
.
x
*
beamSize
;
Pair
topK
[
maxLength
];
// NOLINT
int
beam
=
maxLength
;
Pair
max
;
bool
isEmpty
=
false
;
bool
firstStep
=
true
;
int
topkSize
=
beamSize
;
for
(
int
k
=
0
;
k
<
maxLength
;
k
++
)
{
topK
[
k
].
set
(
-
HL_FLOAT_MAX
,
-
1
);
}
while
(
beamSize
)
{
threadGetTopK
<
maxLength
,
blockSize
>
(
topK
,
beam
,
beamSize
,
src
,
firstStep
,
isEmpty
,
max
,
dim
,
tid
);
shTopK
[
tid
]
=
topK
[
0
];
blockReduce
<
maxLength
,
blockSize
>
(
shTopK
,
maxId
,
topK
,
&
topVal
,
&
topIds
,
beam
,
beamSize
,
tid
,
warp
);
}
__syncthreads
();
if
(
tid
==
0
)
{
for
(
int
i
=
0
;
i
<
topkSize
;
i
++
)
{
if
(
*--
topIds
==
label
[
blockIdx
.
x
])
{
recResult
[
blockIdx
.
x
]
=
0
;
break
;
}
recResult
[
blockIdx
.
x
]
=
1.0
f
;
}
}
}
void
hl_matrix_classification_error
(
real
*
topVal
,
int
ldv
,
int
*
topIds
,
real
*
src
,
int
lds
,
int
dim
,
int
topkSize
,
int
numSamples
,
int
*
label
,
real
*
recResult
)
{
CHECK_NOTNULL
(
topVal
);
CHECK_NOTNULL
(
topIds
);
CHECK_NOTNULL
(
src
);
if
(
topkSize
>
dim
)
topkSize
=
dim
;
dim3
threads
(
256
,
1
);
dim3
grid
(
numSamples
,
1
);
KeMatrixTopKClassificationError
<
5
,
256
>
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
topVal
,
ldv
,
topIds
,
src
,
lds
,
dim
,
topkSize
,
label
,
recResult
);
CHECK_SYNC
(
"hl_matrix_top_k classification error failed"
);
}
paddle/gserver/evaluators/Evaluator.cpp
浏览文件 @
f2c7c9bc
...
@@ -39,6 +39,14 @@ void Evaluator::eval(const NeuralNetwork& nn) {
...
@@ -39,6 +39,14 @@ void Evaluator::eval(const NeuralNetwork& nn) {
*/
*/
class
ClassificationErrorEvaluator
:
public
Evaluator
{
class
ClassificationErrorEvaluator
:
public
Evaluator
{
public:
public:
/*
ClassificationErrorEvaluator() : totalScore2_(0) {}
virtual void start() {
Evaluator::start();
totalScore2_ = 0;
} */
virtual
void
updateSamplesNum
(
const
std
::
vector
<
Argument
>&
arguments
)
{
virtual
void
updateSamplesNum
(
const
std
::
vector
<
Argument
>&
arguments
)
{
if
(
3
==
arguments
.
size
())
{
if
(
3
==
arguments
.
size
())
{
numSamples_
+=
arguments
[
2
].
value
->
getSum
();
numSamples_
+=
arguments
[
2
].
value
->
getSum
();
...
@@ -76,9 +84,11 @@ public:
...
@@ -76,9 +84,11 @@ public:
1
,
1
,
/* trans= */
false
,
/* trans= */
false
,
useGpu
(
arguments
[
0
].
deviceId
));
useGpu
(
arguments
[
0
].
deviceId
));
errorMat
->
zeroMem
();
errorMat
->
zeroMem
();
if
(
label
!=
nullptr
)
{
if
(
label
!=
nullptr
)
{
errorMat
->
classificationError
(
*
output
,
*
label
);
errorMat
->
classificationError
(
*
output
,
*
label
,
config_
.
top_k
()
);
}
else
if
(
dynamic_cast
<
CpuSparseMatrix
*>
(
multiBinaryLabel
.
get
())
||
}
else
if
(
dynamic_cast
<
CpuSparseMatrix
*>
(
multiBinaryLabel
.
get
())
||
dynamic_cast
<
GpuSparseMatrix
*>
(
multiBinaryLabel
.
get
()))
{
dynamic_cast
<
GpuSparseMatrix
*>
(
multiBinaryLabel
.
get
()))
{
errorMat
->
classificationErrorMulti
(
errorMat
->
classificationErrorMulti
(
...
@@ -94,6 +104,16 @@ public:
...
@@ -94,6 +104,16 @@ public:
return
errorMat
;
return
errorMat
;
}
}
void
printStats
(
std
::
ostream
&
os
)
const
{
if
(
config_
.
top_k
()
==
1
)
{
os
<<
config_
.
name
()
<<
"="
<<
(
numSamples_
?
totalScore_
/
numSamples_
:
0
);
}
else
{
os
<<
" top_"
<<
config_
.
top_k
()
<<
"_error="
<<
(
numSamples_
?
totalScore_
/
numSamples_
:
0
);
}
}
virtual
real
evalImp
(
std
::
vector
<
Argument
>&
arguments
)
{
virtual
real
evalImp
(
std
::
vector
<
Argument
>&
arguments
)
{
MatrixPtr
errorMat
=
calcError
(
arguments
);
MatrixPtr
errorMat
=
calcError
(
arguments
);
return
errorMat
->
getSum
();
return
errorMat
->
getSum
();
...
...
paddle/gserver/layers/Layer.h
浏览文件 @
f2c7c9bc
...
@@ -311,6 +311,7 @@ public:
...
@@ -311,6 +311,7 @@ public:
return
*
output
->
second
;
return
*
output
->
second
;
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"No specific output "
<<
str
;
LOG
(
FATAL
)
<<
"No specific output "
<<
str
;
return
*
((
Argument
*
)
nullptr
);
}
}
}
}
}
}
...
...
paddle/gserver/tests/test_Evaluator.cpp
浏览文件 @
f2c7c9bc
...
@@ -129,6 +129,7 @@ void testEvaluatorAll(TestConfig testConf,
...
@@ -129,6 +129,7 @@ void testEvaluatorAll(TestConfig testConf,
TEST
(
Evaluator
,
classification_error
)
{
TEST
(
Evaluator
,
classification_error
)
{
TestConfig
config
;
TestConfig
config
;
config
.
evaluatorConfig
.
set_type
(
"classification_error"
);
config
.
evaluatorConfig
.
set_type
(
"classification_error"
);
config
.
evaluatorConfig
.
set_top_k
(
5
);
config
.
inputDefs
.
push_back
({
INPUT_DATA
,
"output"
,
50
});
config
.
inputDefs
.
push_back
({
INPUT_DATA
,
"output"
,
50
});
config
.
inputDefs
.
push_back
({
INPUT_LABEL
,
"label"
,
50
});
config
.
inputDefs
.
push_back
({
INPUT_LABEL
,
"label"
,
50
});
...
...
paddle/math/Matrix.cpp
浏览文件 @
f2c7c9bc
...
@@ -732,6 +732,7 @@ void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
...
@@ -732,6 +732,7 @@ void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
size_t
beam
=
maxVal
.
getWidth
();
size_t
beam
=
maxVal
.
getWidth
();
CHECK_EQ
(
maxIds
.
getSize
(),
numSamples
*
beam
);
CHECK_EQ
(
maxIds
.
getSize
(),
numSamples
*
beam
);
CHECK_EQ
(
maxVal
.
getHeight
(),
numSamples
);
CHECK_EQ
(
maxVal
.
getHeight
(),
numSamples
);
CHECK_EQ
(
maxVal
.
getWidth
(),
beam
);
hl_matrix_top_k
(
maxVal
.
getData
(),
hl_matrix_top_k
(
maxVal
.
getData
(),
maxVal
.
getStride
(),
maxVal
.
getStride
(),
...
@@ -792,19 +793,32 @@ void GpuMatrix::maxoutBackward(Matrix& a,
...
@@ -792,19 +793,32 @@ void GpuMatrix::maxoutBackward(Matrix& a,
}
}
/*calulate the error of classification */
/*calulate the error of classification */
void
GpuMatrix
::
classificationError
(
Matrix
&
output
,
IVector
&
label
)
{
void
GpuMatrix
::
classificationError
(
Matrix
&
output
,
auto
output_ptr
=
dynamic_cast
<
const
GpuMatrix
*>
(
&
output
);
IVector
&
label
,
auto
label_ptr
=
dynamic_cast
<
const
GpuIVector
*>
(
&
label
);
size_t
topkSize
)
{
CHECK
(
output_ptr
&&
label_ptr
)
<<
"Invalid argument pointer"
;
auto
gpuOutput
=
dynamic_cast
<
GpuMatrix
*>
(
&
output
);
auto
gpuLabel
=
dynamic_cast
<
GpuIVector
*>
(
&
label
);
CHECK
(
height_
==
output_ptr
->
height_
&&
width_
==
1
)
size_t
numSamples
=
this
->
getHeight
();
GpuMatrixPtr
gpuTopVal
=
std
::
make_shared
<
GpuMatrix
>
(
numSamples
,
topkSize
);
GpuIVectorPtr
gpuTopIds
=
std
::
make_shared
<
GpuIVector
>
(
numSamples
*
topkSize
);
CHECK
(
gpuOutput
&&
gpuLabel
)
<<
"Invalid argument pointer"
;
CHECK
(
gpuTopVal
&&
gpuTopIds
)
<<
"Allocate GPU memory failed"
;
CHECK
(
gpuLabel
->
getSize
()
==
numSamples
)
<<
"Vector size is not equal"
;
CHECK
(
numSamples
==
gpuOutput
->
getHeight
()
&&
this
->
getWidth
()
==
1
)
<<
"Matrix dimensions are not equal"
;
<<
"Matrix dimensions are not equal"
;
hl_matrix_classification_error
((
real
*
)
output_ptr
->
data_
,
size_t
dim
=
gpuOutput
->
getWidth
();
(
int
*
)
label_ptr
->
getData
(),
hl_matrix_classification_error
(
gpuTopVal
->
getData
(),
data_
,
gpuTopVal
->
getStride
(),
height_
,
gpuTopIds
->
getData
(),
output_ptr
->
width_
);
gpuOutput
->
getData
(),
gpuOutput
->
getStride
(),
dim
,
topkSize
,
numSamples
,
gpuLabel
->
getData
(),
this
->
getData
());
}
}
/* copy -log(output[i * width + label]) to this->data[i] */
/* copy -log(output[i * width + label]) to this->data[i] */
...
@@ -3039,7 +3053,7 @@ void CpuMatrix::rowMax(Matrix& max) {
...
@@ -3039,7 +3053,7 @@ void CpuMatrix::rowMax(Matrix& max) {
max
.
maxRows
(
*
this
);
max
.
maxRows
(
*
this
);
}
}
/*
get beam size of max ids and values
*/
/*
Get the top k elements of each row of this matrix
*/
void
CpuMatrix
::
rowMax
(
IVector
&
maxIds
,
Matrix
&
maxVal
)
{
void
CpuMatrix
::
rowMax
(
IVector
&
maxIds
,
Matrix
&
maxVal
)
{
CHECK
(
isContiguous
());
CHECK
(
isContiguous
());
CHECK
(
!
maxIds
.
useGpu
()
&&
!
maxVal
.
useGpu
())
<<
"Matrix type are not equal"
;
CHECK
(
!
maxIds
.
useGpu
()
&&
!
maxVal
.
useGpu
())
<<
"Matrix type are not equal"
;
...
@@ -3047,6 +3061,7 @@ void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
...
@@ -3047,6 +3061,7 @@ void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
size_t
beam
=
maxVal
.
getWidth
();
size_t
beam
=
maxVal
.
getWidth
();
CHECK_EQ
(
maxIds
.
getSize
(),
numSamples
*
beam
);
CHECK_EQ
(
maxIds
.
getSize
(),
numSamples
*
beam
);
CHECK_EQ
(
maxVal
.
getHeight
(),
numSamples
);
CHECK_EQ
(
maxVal
.
getHeight
(),
numSamples
);
CHECK_EQ
(
maxVal
.
getWidth
(),
beam
);
real
*
a
=
getData
();
real
*
a
=
getData
();
int
*
s
=
maxIds
.
getData
();
int
*
s
=
maxIds
.
getData
();
...
@@ -3198,32 +3213,39 @@ void CpuMatrix::rowNormalizeL1(Matrix& out) {
...
@@ -3198,32 +3213,39 @@ void CpuMatrix::rowNormalizeL1(Matrix& out) {
}
}
/* calulate classification error */
/* calulate classification error */
void
CpuMatrix
::
classificationError
(
Matrix
&
output
,
IVector
&
label
)
{
void
CpuMatrix
::
classificationError
(
Matrix
&
output
,
CHECK
(
dynamic_cast
<
const
CpuMatrix
*>
(
&
output
));
IVector
&
label
,
CHECK
(
dynamic_cast
<
const
CpuIVector
*>
(
&
label
));
size_t
topkSize
)
{
size_t
numSamples
=
this
->
getHeight
();
auto
cpuOutput
=
dynamic_cast
<
CpuMatrix
*>
(
&
output
);
auto
cpuLabel
=
dynamic_cast
<
CpuIVector
*>
(
&
label
);
IVectorPtr
cpuTopIds
=
std
::
make_shared
<
CpuIVector
>
(
numSamples
*
topkSize
);
MatrixPtr
cpuTopVal
=
std
::
make_shared
<
CpuMatrix
>
(
numSamples
,
topkSize
);
CHECK
(
cpuOutput
&&
cpuLabel
)
<<
"Invalid argument pointer"
;
CHECK
(
cpuTopIds
&&
cpuTopVal
)
<<
"Allocate cpu memory failed"
;
CHECK
(
cpuLabel
->
getSize
()
==
numSamples
)
<<
"Vector size is not equal"
;
CHECK
(
cpuOutput
->
getHeight
()
==
numSamples
&&
this
->
getWidth
()
==
1
)
<<
"Matrix dimensions are not equal"
;
CHECK_EQ
(
getWidth
(),
(
size_t
)
1
);
// top k matrix classification
size_t
numSamples
=
getHeight
();
cpuOutput
->
rowMax
(
*
cpuTopIds
,
*
cpuTopVal
);
CHECK_EQ
(
label
.
getSize
(),
numSamples
);
CHECK_EQ
(
output
.
getHeight
(),
numSamples
);
size_t
dim
=
output
.
getWidth
();
size_t
dim
=
cpuOutput
->
getWidth
();
real
*
out
=
output
.
getData
();
real
*
result
=
this
->
getData
();
int
*
lbl
=
label
.
getData
();
int
*
ids
=
cpuTopIds
->
getData
();
real
maxData
=
0.0
;
int
*
lbl
=
cpuLabel
->
getData
();
int
maxIndex
=
-
1
;
for
(
size_t
i
=
0
;
i
<
numSamples
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
numSamples
;
++
i
)
{
CHECK_GE
(
lbl
[
i
],
0
);
CHECK_GE
(
lbl
[
i
],
0
);
CHECK_LT
((
size_t
)
lbl
[
i
],
dim
);
CHECK_LT
((
size_t
)
lbl
[
i
],
dim
);
maxData
=
out
[
i
*
dim
];
maxIndex
=
0
;
for
(
size_t
j
=
0
;
j
<
topkSize
;
++
j
)
{
for
(
size_t
j
=
0
;
j
<
dim
;
++
j
)
{
if
(
ids
[
j
+
i
*
topkSize
]
==
lbl
[
i
])
{
if
(
maxData
<
out
[
i
*
dim
+
j
])
{
result
[
i
]
=
0
;
maxIndex
=
j
;
break
;
maxData
=
out
[
i
*
dim
+
j
];
}
}
result
[
i
]
=
1.0
f
;
}
}
getData
()[
i
]
=
(
maxIndex
!=
lbl
[
i
]);
}
}
}
}
...
...
paddle/math/Matrix.h
浏览文件 @
f2c7c9bc
...
@@ -836,8 +836,11 @@ public:
...
@@ -836,8 +836,11 @@ public:
* output[i] = 1 if row i is an error.
* output[i] = 1 if row i is an error.
*
*
* output[i] = 0 if row i is correct.
* output[i] = 0 if row i is correct.
*
*/
*/
virtual
void
classificationError
(
Matrix
&
output
,
IVector
&
label
)
{
virtual
void
classificationError
(
Matrix
&
output
,
IVector
&
label
,
size_t
topkSize
=
1
)
{
LOG
(
FATAL
)
<<
"Not implemented"
;
LOG
(
FATAL
)
<<
"Not implemented"
;
}
}
...
@@ -1314,7 +1317,7 @@ public:
...
@@ -1314,7 +1317,7 @@ public:
void
check
(
std
::
ostream
&
os
,
Matrix
&
refMat
,
bool
printDiff
=
true
);
void
check
(
std
::
ostream
&
os
,
Matrix
&
refMat
,
bool
printDiff
=
true
);
void
randomizeUniform
();
void
randomizeUniform
();
void
classificationError
(
Matrix
&
output
,
IVector
&
label
);
void
classificationError
(
Matrix
&
output
,
IVector
&
label
,
size_t
topkSize
=
1
);
void
convExpand
(
Matrix
&
feature
,
void
convExpand
(
Matrix
&
feature
,
int
feaImgHeight
,
int
feaImgHeight
,
...
@@ -1739,7 +1742,7 @@ public:
...
@@ -1739,7 +1742,7 @@ public:
void
randomizeUniform
();
void
randomizeUniform
();
void
classificationError
(
Matrix
&
output
,
IVector
&
label
);
void
classificationError
(
Matrix
&
output
,
IVector
&
label
,
size_t
topkSize
=
1
);
void
addByBitCode
(
size_t
numClasses
,
const
IVector
&
codes
,
const
Matrix
&
vec
);
void
addByBitCode
(
size_t
numClasses
,
const
IVector
&
codes
,
const
Matrix
&
vec
);
...
...
paddle/math/tests/test_matrixCompare.cpp
浏览文件 @
f2c7c9bc
...
@@ -764,7 +764,7 @@ TEST(Matrix, paramReluBackwardDiff) {
...
@@ -764,7 +764,7 @@ TEST(Matrix, paramReluBackwardDiff) {
}
}
}
}
void
testClassificationError
(
int
numSamples
,
int
dim
)
{
void
testClassificationError
(
int
numSamples
,
int
dim
,
int
topkSize
)
{
MatrixPtr
cpuError
=
std
::
make_shared
<
CpuMatrix
>
(
numSamples
,
1
);
MatrixPtr
cpuError
=
std
::
make_shared
<
CpuMatrix
>
(
numSamples
,
1
);
MatrixPtr
gpuError
=
std
::
make_shared
<
GpuMatrix
>
(
numSamples
,
1
);
MatrixPtr
gpuError
=
std
::
make_shared
<
GpuMatrix
>
(
numSamples
,
1
);
MatrixPtr
cpuOutput
=
std
::
make_shared
<
CpuMatrix
>
(
numSamples
,
dim
);
MatrixPtr
cpuOutput
=
std
::
make_shared
<
CpuMatrix
>
(
numSamples
,
dim
);
...
@@ -777,17 +777,22 @@ void testClassificationError(int numSamples, int dim) {
...
@@ -777,17 +777,22 @@ void testClassificationError(int numSamples, int dim) {
gpuOutput
->
copyFrom
(
*
cpuOutput
);
gpuOutput
->
copyFrom
(
*
cpuOutput
);
gpuLabel
->
copyFrom
(
*
cpuLabel
);
gpuLabel
->
copyFrom
(
*
cpuLabel
);
cpuError
->
classificationError
(
*
cpuOutput
,
*
cpuLabel
);
cpuError
->
classificationError
(
*
cpuOutput
,
*
cpuLabel
,
topkSize
);
gpuError
->
classificationError
(
*
gpuOutput
,
*
gpuLabel
);
gpuError
->
classificationError
(
*
gpuOutput
,
*
gpuLabel
,
topkSize
);
TensorCheckEqual
(
*
cpuError
,
*
gpuError
);
TensorCheckEqual
(
*
cpuError
,
*
gpuError
);
}
}
TEST
(
Matrix
,
classificationError
)
{
TEST
(
Matrix
,
classificationError
)
{
for
(
auto
numSamples
:
{
1
,
10
,
100
,
1000
,
70000
})
{
for
(
auto
numSamples
:
{
1
,
5
,
31
,
90
,
150
,
300
})
{
for
(
auto
dim
:
{
1
,
10
,
100
,
1000
})
{
for
(
auto
dim
:
VLOG
(
3
)
<<
" numSamples="
<<
numSamples
<<
" dim="
<<
dim
;
{
1
,
5
,
8
,
10
,
15
,
64
,
80
,
120
,
256
,
300
,
1280
,
5120
,
50000
})
{
testClassificationError
(
numSamples
,
dim
);
for
(
auto
topkSize
:
{
1
,
5
,
10
,
20
,
40
,
(
int
)
rand
()
%
dim
+
1
})
{
if
(
topkSize
>
dim
)
continue
;
VLOG
(
3
)
<<
" sample= "
<<
numSamples
<<
" topkSize= "
<<
topkSize
<<
" dim= "
<<
dim
;
testClassificationError
(
numSamples
,
dim
,
topkSize
);
}
}
}
}
}
}
}
...
...
paddle/parameter/Parameter.cpp
浏览文件 @
f2c7c9bc
...
@@ -375,10 +375,6 @@ bool Parameter::load(const std::string& filename) {
...
@@ -375,10 +375,6 @@ bool Parameter::load(const std::string& filename) {
std
::
ifstream
fs
(
filename
,
std
::
ios_base
::
binary
);
std
::
ifstream
fs
(
filename
,
std
::
ios_base
::
binary
);
if
(
!
fs
)
{
if
(
!
fs
)
{
LOG
(
INFO
)
<<
"missing parameters ["
<<
filename
<<
"] while loading model."
;
LOG
(
INFO
)
<<
"missing parameters ["
<<
filename
<<
"] while loading model."
;
if
(
isStatic
())
{
LOG
(
FATAL
)
<<
getName
()
<<
" is static but missing, not allowed."
;
return
false
;
}
if
(
kMissParameterFail
==
FLAGS_load_missing_parameter_strategy
)
{
if
(
kMissParameterFail
==
FLAGS_load_missing_parameter_strategy
)
{
LOG
(
FATAL
)
<<
getName
()
<<
" missing, not allowed."
;
LOG
(
FATAL
)
<<
getName
()
<<
" missing, not allowed."
;
return
false
;
return
false
;
...
...
proto/ModelConfig.proto
浏览文件 @
f2c7c9bc
...
@@ -475,6 +475,10 @@ message EvaluatorConfig {
...
@@ -475,6 +475,10 @@ message EvaluatorConfig {
// Used by ChunkEvaluator
// Used by ChunkEvaluator
// chunk of these types are not counted
// chunk of these types are not counted
repeated
int32
excluded_chunk_types
=
12
;
repeated
int32
excluded_chunk_types
=
12
;
// Used by ClassificationErrorEvaluator
// top # classification error
optional
int32
top_k
=
13
[
default
=
1
];
}
}
message
LinkConfig
{
message
LinkConfig
{
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
f2c7c9bc
...
@@ -1253,6 +1253,7 @@ def Evaluator(
...
@@ -1253,6 +1253,7 @@ def Evaluator(
dict_file
=
None
,
dict_file
=
None
,
result_file
=
None
,
result_file
=
None
,
num_results
=
None
,
num_results
=
None
,
top_k
=
None
,
delimited
=
None
,
delimited
=
None
,
excluded_chunk_types
=
None
,
):
excluded_chunk_types
=
None
,
):
evaluator
=
g_config
.
model_config
.
evaluators
.
add
()
evaluator
=
g_config
.
model_config
.
evaluators
.
add
()
...
@@ -1280,6 +1281,8 @@ def Evaluator(
...
@@ -1280,6 +1281,8 @@ def Evaluator(
evaluator
.
result_file
=
result_file
evaluator
.
result_file
=
result_file
if
num_results
is
not
None
:
if
num_results
is
not
None
:
evaluator
.
num_results
=
num_results
evaluator
.
num_results
=
num_results
if
top_k
is
not
None
:
evaluator
.
top_k
=
top_k
if
delimited
is
not
None
:
if
delimited
is
not
None
:
evaluator
.
delimited
=
delimited
evaluator
.
delimited
=
delimited
...
...
python/paddle/trainer_config_helpers/evaluators.py
浏览文件 @
f2c7c9bc
...
@@ -71,6 +71,7 @@ def evaluator_base(
...
@@ -71,6 +71,7 @@ def evaluator_base(
result_file
=
None
,
result_file
=
None
,
num_results
=
None
,
num_results
=
None
,
delimited
=
None
,
delimited
=
None
,
top_k
=
None
,
excluded_chunk_types
=
None
,
):
excluded_chunk_types
=
None
,
):
"""
"""
Evaluator will evaluate the network status while training/testing.
Evaluator will evaluate the network status while training/testing.
...
@@ -104,12 +105,15 @@ def evaluator_base(
...
@@ -104,12 +105,15 @@ def evaluator_base(
:param weight: An input layer which is a weight for each sample.
:param weight: An input layer which is a weight for each sample.
Each evaluator may calculate differently to use this weight.
Each evaluator may calculate differently to use this weight.
:type weight: LayerOutput.
:type weight: LayerOutput.
:param top_k: number k in top-k error rate
:type top_k: int
"""
"""
# inputs type assertions.
# inputs type assertions.
assert
classification_threshold
is
None
or
isinstance
(
assert
classification_threshold
is
None
or
isinstance
(
classification_threshold
,
float
)
classification_threshold
,
float
)
assert
positive_label
is
None
or
isinstance
(
positive_label
,
int
)
assert
positive_label
is
None
or
isinstance
(
positive_label
,
int
)
assert
num_results
is
None
or
isinstance
(
num_results
,
int
)
assert
num_results
is
None
or
isinstance
(
num_results
,
int
)
assert
top_k
is
None
or
isinstance
(
top_k
,
int
)
if
not
isinstance
(
input
,
list
):
if
not
isinstance
(
input
,
list
):
input
=
[
input
]
input
=
[
input
]
...
@@ -130,6 +134,8 @@ def evaluator_base(
...
@@ -130,6 +134,8 @@ def evaluator_base(
dict_file
=
dict_file
,
dict_file
=
dict_file
,
result_file
=
result_file
,
result_file
=
result_file
,
delimited
=
delimited
,
delimited
=
delimited
,
num_results
=
num_results
,
top_k
=
top_k
,
excluded_chunk_types
=
excluded_chunk_types
,
)
excluded_chunk_types
=
excluded_chunk_types
,
)
...
@@ -139,6 +145,7 @@ def classification_error_evaluator(input,
...
@@ -139,6 +145,7 @@ def classification_error_evaluator(input,
label
,
label
,
name
=
None
,
name
=
None
,
weight
=
None
,
weight
=
None
,
top_k
=
None
,
threshold
=
None
):
threshold
=
None
):
"""
"""
Classification Error Evaluator. It will print error rate for classification.
Classification Error Evaluator. It will print error rate for classification.
...
@@ -167,6 +174,8 @@ def classification_error_evaluator(input,
...
@@ -167,6 +174,8 @@ def classification_error_evaluator(input,
then means not set weight. The larger weight it is, the more
then means not set weight. The larger weight it is, the more
important this sample is.
important this sample is.
:type weight: LayerOutput
:type weight: LayerOutput
:param top_k: number k in top-k error rate
:type top_k: int
:param threshold: The classification threshold.
:param threshold: The classification threshold.
:type threshold: float
:type threshold: float
:return: None.
:return: None.
...
@@ -178,6 +187,7 @@ def classification_error_evaluator(input,
...
@@ -178,6 +187,7 @@ def classification_error_evaluator(input,
input
=
input
,
input
=
input
,
label
=
label
,
label
=
label
,
weight
=
weight
,
weight
=
weight
,
top_k
=
top_k
,
classification_threshold
=
threshold
,
)
classification_threshold
=
threshold
,
)
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
f2c7c9bc
...
@@ -2870,8 +2870,8 @@ def gru_step_layer(input,
...
@@ -2870,8 +2870,8 @@ def gru_step_layer(input,
:param name:
:param name:
:param gate_act:
:param gate_act:
:param bias_attr:
:param bias_attr:
:param param_attr: the parameter_attribute for transforming the output_mem
:param param_attr: the parameter_attribute for transforming the output_mem
from previous step.
from previous step.
:param layer_attr:
:param layer_attr:
:return: LayerOutput object.
:return: LayerOutput object.
:rtype: LayerOutput
:rtype: LayerOutput
...
@@ -2882,10 +2882,10 @@ def gru_step_layer(input,
...
@@ -2882,10 +2882,10 @@ def gru_step_layer(input,
Layer
(
Layer
(
name
=
name
,
name
=
name
,
type
=
LayerType
.
GRU_STEP_LAYER
,
type
=
LayerType
.
GRU_STEP_LAYER
,
# The parameter here is for transforming the output_mem. The input has
# The parameter here is for transforming the output_mem. The input has
# already been transformed outside this module so it does not need
# already been transformed outside this module so it does not need
# parameter associated with it.
# parameter associated with it.
# The parameter here is instead grouped with input is due to
# The parameter here is instead grouped with input is due to
# backward model compatibility.
# backward model compatibility.
inputs
=
[
Input
(
input
.
name
,
**
param_attr
.
attr
),
output_mem
.
name
],
inputs
=
[
Input
(
input
.
name
,
**
param_attr
.
attr
),
output_mem
.
name
],
bias
=
ParamAttr
.
to_bias
(
bias_attr
),
bias
=
ParamAttr
.
to_bias
(
bias_attr
),
...
@@ -3536,6 +3536,7 @@ def classification_cost(input,
...
@@ -3536,6 +3536,7 @@ def classification_cost(input,
label
,
label
,
weight
=
None
,
weight
=
None
,
name
=
None
,
name
=
None
,
top_k
=
None
,
evaluator
=
classification_error_evaluator
,
evaluator
=
classification_error_evaluator
,
layer_attr
=
None
):
layer_attr
=
None
):
"""
"""
...
@@ -3550,6 +3551,8 @@ def classification_cost(input,
...
@@ -3550,6 +3551,8 @@ def classification_cost(input,
:param weight: The weight affects the cost, namely the scale of cost.
:param weight: The weight affects the cost, namely the scale of cost.
It is an optional argument.
It is an optional argument.
:type weight: LayerOutput
:type weight: LayerOutput
:param top_k: number k in top-k error rate
:type top_k: int
:param evaluator: Evaluator method.
:param evaluator: Evaluator method.
:param layer_attr: layer's extra attribute.
:param layer_attr: layer's extra attribute.
:type layer_attr: ExtraLayerAttribute
:type layer_attr: ExtraLayerAttribute
...
@@ -3577,7 +3580,7 @@ def classification_cost(input,
...
@@ -3577,7 +3580,7 @@ def classification_cost(input,
assert
isinstance
(
e
.
for_classification
,
bool
)
assert
isinstance
(
e
.
for_classification
,
bool
)
assert
e
.
for_classification
assert
e
.
for_classification
e
(
name
=
e
.
__name__
,
input
=
input
,
label
=
label
,
weight
=
weight
)
e
(
name
=
e
.
__name__
,
input
=
input
,
label
=
label
,
weight
=
weight
,
top_k
=
top_k
)
if
not
isinstance
(
evaluator
,
collections
.
Sequence
):
if
not
isinstance
(
evaluator
,
collections
.
Sequence
):
evaluator
=
[
evaluator
]
evaluator
=
[
evaluator
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录