Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
b1ab8b56
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 2 年 前同步成功
通知
708
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b1ab8b56
编写于
2月 23, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use plain C++ 03 to implement getStatsInfo.
上级
bb751b7f
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
129 addition
and
116 deletion
+129
-116
paddle/gserver/evaluators/Evaluator.cpp
paddle/gserver/evaluators/Evaluator.cpp
+99
-97
paddle/gserver/evaluators/Evaluator.h
paddle/gserver/evaluators/Evaluator.h
+10
-3
paddle/utils/Error.h
paddle/utils/Error.h
+16
-12
paddle/utils/tests/test_Error.cpp
paddle/utils/tests/test_Error.cpp
+4
-4
未找到文件。
paddle/gserver/evaluators/Evaluator.cpp
浏览文件 @
b1ab8b56
...
@@ -626,70 +626,26 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
...
@@ -626,70 +626,26 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
return
0
;
return
0
;
}
}
template
<
typename
T1
,
typename
T2
>
void
PrecisionRecallEvaluator
::
printStatsHelper
(
T1
labelCallback
,
T2
microAvgCallback
)
const
{
int
label
=
config_
.
positive_label
();
if
(
label
!=
-
1
)
{
CHECK
(
label
>=
0
&&
label
<
(
int
)
statsInfo_
.
size
())
<<
"positive_label ["
<<
label
<<
"] should be in range [0, "
<<
statsInfo_
.
size
()
<<
")"
;
double
precision
=
calcPrecision
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FP
);
double
recall
=
calcRecall
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FN
);
labelCallback
(
label
,
precision
,
recall
,
calcF1Score
(
precision
,
recall
));
return
;
}
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
// macro average method: precision = (precision1+precision2)/2
double
microTotalTP
=
0
;
double
microTotalFP
=
0
;
double
microTotalFN
=
0
;
double
macroAvgPrecision
=
0
;
double
macroAvgRecall
=
0
;
size_t
numLabels
=
statsInfo_
.
size
();
for
(
size_t
i
=
0
;
i
<
numLabels
;
++
i
)
{
microTotalTP
+=
statsInfo_
[
i
].
TP
;
microTotalFP
+=
statsInfo_
[
i
].
FP
;
microTotalFN
+=
statsInfo_
[
i
].
FN
;
macroAvgPrecision
+=
calcPrecision
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FP
);
macroAvgRecall
+=
calcRecall
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FN
);
}
macroAvgPrecision
/=
numLabels
;
macroAvgRecall
/=
numLabels
;
double
macroAvgF1Score
=
calcF1Score
(
macroAvgPrecision
,
macroAvgRecall
);
double
microAvgPrecision
=
calcPrecision
(
microTotalTP
,
microTotalFP
);
double
microAvgRecall
=
calcPrecision
(
microTotalTP
,
microTotalFN
);
double
microAvgF1Score
=
calcF1Score
(
microAvgPrecision
,
microAvgRecall
);
microAvgCallback
(
macroAvgPrecision
,
macroAvgRecall
,
macroAvgF1Score
,
isMultiBinaryLabel_
,
microAvgPrecision
,
microAvgRecall
,
microAvgF1Score
);
}
void
PrecisionRecallEvaluator
::
printStats
(
std
::
ostream
&
os
)
const
{
void
PrecisionRecallEvaluator
::
printStats
(
std
::
ostream
&
os
)
const
{
this
->
printStatsHelper
(
double
precision
,
recall
,
f1
,
macroAvgPrecision
,
macroAvgRecall
,
[
&
os
](
int
label
,
double
precision
,
double
recall
,
double
f1
)
{
macroAvgF1Score
,
microAvgPrecision
,
microAvgRecall
,
microAvgF1Score
;
os
<<
"positive_label="
<<
label
<<
" precision="
<<
precision
bool
containMacroMicroInfo
=
getStatsInfo
(
&
precision
,
<<
" recall="
<<
recall
<<
" F1-score="
<<
f1
;
&
recall
,
},
&
f1
,
[
&
os
](
double
macroAvgPrecision
,
&
macroAvgPrecision
,
double
macroAvgRecall
,
&
macroAvgRecall
,
double
macroAvgF1Score
,
&
macroAvgF1Score
,
bool
isMultiBinaryLabel
,
&
microAvgPrecision
,
double
microAvgPrecision
,
&
microAvgRecall
,
double
microAvgRecall
,
&
microAvgF1Score
);
double
microAvgF1Score
)
{
os
<<
"positive_label="
<<
config_
.
positive_label
()
<<
" precision="
<<
precision
<<
" recall="
<<
recall
<<
" F1-score="
<<
f1
;
if
(
containMacroMicroInfo
)
{
os
<<
"macro-average-precision="
<<
macroAvgPrecision
os
<<
"macro-average-precision="
<<
macroAvgPrecision
<<
" macro-average-recall="
<<
macroAvgRecall
<<
" macro-average-recall="
<<
macroAvgRecall
<<
" macro-average-F1-score="
<<
macroAvgF1Score
;
<<
" macro-average-F1-score="
<<
macroAvgF1Score
;
if
(
!
isMultiBinaryLabel
)
{
if
(
!
isMultiBinaryLabel_
)
{
// precision and recall are equal in this case
// precision and recall are equal in this case
os
<<
" micro-average-precision="
<<
microAvgPrecision
;
os
<<
" micro-average-precision="
<<
microAvgPrecision
;
}
else
{
}
else
{
...
@@ -697,7 +653,7 @@ void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
...
@@ -697,7 +653,7 @@ void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
<<
" micro-average-recall="
<<
microAvgRecall
<<
" micro-average-recall="
<<
microAvgRecall
<<
" micro-average-F1-score="
<<
microAvgF1Score
;
<<
" micro-average-F1-score="
<<
microAvgF1Score
;
}
}
})
;
}
;
}
}
void
PrecisionRecallEvaluator
::
calcStatsInfo
(
const
MatrixPtr
&
output
,
void
PrecisionRecallEvaluator
::
calcStatsInfo
(
const
MatrixPtr
&
output
,
...
@@ -780,24 +736,25 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
...
@@ -780,24 +736,25 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
void
PrecisionRecallEvaluator
::
storeLocalValues
()
const
{
void
PrecisionRecallEvaluator
::
storeLocalValues
()
const
{
if
(
this
->
values_
.
size
()
==
0
)
{
if
(
this
->
values_
.
size
()
==
0
)
{
this
->
printStatsHelper
(
double
precision
,
recall
,
f1
,
macroAvgPrecision
,
macroAvgRecall
,
[
this
](
int
label
,
double
precision
,
double
recall
,
double
f1
)
{
macroAvgF1Score
,
microAvgPrecision
,
microAvgRecall
,
microAvgF1Score
;
values_
[
"positive_label"
]
=
(
double
)
label
;
bool
containMacroMicroInfo
=
getStatsInfo
(
&
precision
,
&
recall
,
&
f1
,
&
macroAvgPrecision
,
&
macroAvgRecall
,
&
macroAvgF1Score
,
&
microAvgPrecision
,
&
microAvgRecall
,
&
microAvgF1Score
);
values_
[
"precision"
]
=
precision
;
values_
[
"precision"
]
=
precision
;
values_
[
"recal"
]
=
recall
;
values_
[
"recal"
]
=
recall
;
values_
[
"F1-score"
]
=
f1
;
values_
[
"F1-score"
]
=
f1
;
},
if
(
containMacroMicroInfo
)
{
[
this
](
double
macroAvgPrecision
,
double
macroAvgRecall
,
double
macroAvgF1Score
,
bool
isMultiBinaryLabel
,
double
microAvgPrecision
,
double
microAvgRecall
,
double
microAvgF1Score
)
{
values_
[
"macro-average-precision"
]
=
macroAvgPrecision
;
values_
[
"macro-average-precision"
]
=
macroAvgPrecision
;
values_
[
"macro-average-recall"
]
=
macroAvgRecall
;
values_
[
"macro-average-recall"
]
=
macroAvgRecall
;
values_
[
"macro-average-F1-score"
]
=
macroAvgF1Score
;
values_
[
"macro-average-F1-score"
]
=
macroAvgF1Score
;
if
(
!
isMultiBinaryLabel
)
{
if
(
!
isMultiBinaryLabel_
)
{
// precision and recall are equal in this case
// precision and recall are equal in this case
values_
[
"micro-average-precision"
]
=
microAvgPrecision
;
values_
[
"micro-average-precision"
]
=
microAvgPrecision
;
}
else
{
}
else
{
...
@@ -805,7 +762,7 @@ void PrecisionRecallEvaluator::storeLocalValues() const {
...
@@ -805,7 +762,7 @@ void PrecisionRecallEvaluator::storeLocalValues() const {
values_
[
"micro-average-recall"
]
=
microAvgRecall
;
values_
[
"micro-average-recall"
]
=
microAvgRecall
;
values_
[
"micro-average-F1-score"
]
=
microAvgF1Score
;
values_
[
"micro-average-F1-score"
]
=
microAvgF1Score
;
}
}
});
}
}
}
}
}
...
@@ -865,6 +822,51 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
...
@@ -865,6 +822,51 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
delete
[]
buf
;
delete
[]
buf
;
}
}
bool
PrecisionRecallEvaluator
::
getStatsInfo
(
double
*
precision
,
double
*
recall
,
double
*
f1
,
double
*
macroAvgPrecision
,
double
*
macroAvgRecall
,
double
*
macroAvgF1Score
,
double
*
microAvgPrecision
,
double
*
microAvgRecall
,
double
*
microAvgF1Score
)
const
{
int
label
=
config_
.
positive_label
();
if
(
label
!=
-
1
)
{
CHECK
(
label
>=
0
&&
label
<
(
int
)
statsInfo_
.
size
())
<<
"positive_label ["
<<
label
<<
"] should be in range [0, "
<<
statsInfo_
.
size
()
<<
")"
;
*
precision
=
calcPrecision
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FP
);
*
recall
=
calcRecall
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FN
);
*
f1
=
calcF1Score
(
*
precision
,
*
recall
);
return
false
;
}
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
// macro average method: precision = (precision1+precision2)/2
double
microTotalTP
=
0
;
double
microTotalFP
=
0
;
double
microTotalFN
=
0
;
*
macroAvgPrecision
=
0
;
*
macroAvgRecall
=
0
;
size_t
numLabels
=
statsInfo_
.
size
();
for
(
size_t
i
=
0
;
i
<
numLabels
;
++
i
)
{
microTotalTP
+=
statsInfo_
[
i
].
TP
;
microTotalFP
+=
statsInfo_
[
i
].
FP
;
microTotalFN
+=
statsInfo_
[
i
].
FN
;
*
macroAvgPrecision
+=
calcPrecision
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FP
);
*
macroAvgRecall
+=
calcRecall
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FN
);
}
*
macroAvgPrecision
/=
numLabels
;
*
macroAvgRecall
/=
numLabels
;
*
macroAvgF1Score
=
calcF1Score
(
*
macroAvgPrecision
,
*
macroAvgRecall
);
*
microAvgPrecision
=
calcPrecision
(
microTotalTP
,
microTotalFP
);
*
microAvgRecall
=
calcPrecision
(
microTotalTP
,
microTotalFN
);
*
microAvgF1Score
=
calcF1Score
(
*
microAvgPrecision
,
*
microAvgRecall
);
return
true
;
}
REGISTER_EVALUATOR
(
pnpair
,
PnpairEvaluator
);
REGISTER_EVALUATOR
(
pnpair
,
PnpairEvaluator
);
void
PnpairEvaluator
::
start
()
{
void
PnpairEvaluator
::
start
()
{
Evaluator
::
start
();
Evaluator
::
start
();
...
...
paddle/gserver/evaluators/Evaluator.h
浏览文件 @
b1ab8b56
...
@@ -125,7 +125,7 @@ public:
...
@@ -125,7 +125,7 @@ public:
* has multiple field, the name could be `evaluator_name.field1`. For example
* has multiple field, the name could be `evaluator_name.field1`. For example
* the PrecisionRecallEvaluator contains `precision`, `recall` fields. The get
* the PrecisionRecallEvaluator contains `precision`, `recall` fields. The get
* names will return `precision_recall_evaluator.precision`,
* names will return `precision_recall_evaluator.precision`,
* `precision_recall.recal`, etc.
* `precision_recall
_evaluator
.recal`, etc.
*
*
* Also, if current Evaluator is a combined evaluator. getNames will return
* Also, if current Evaluator is a combined evaluator. getNames will return
* all names of all evaluators inside the combined evaluator.
* all names of all evaluators inside the combined evaluator.
...
@@ -387,8 +387,15 @@ private:
...
@@ -387,8 +387,15 @@ private:
IVectorPtr
cpuLabel_
;
IVectorPtr
cpuLabel_
;
MatrixPtr
cpuWeight_
;
MatrixPtr
cpuWeight_
;
template
<
typename
T1
,
typename
T2
>
bool
getStatsInfo
(
double
*
precision
,
void
printStatsHelper
(
T1
labelCallback
,
T2
microAvgCallback
)
const
;
double
*
recall
,
double
*
f1
,
double
*
macroAvgPrecision
,
double
*
macroAvgRecall
,
double
*
macroAvgF1Score
,
double
*
microAvgPrecision
,
double
*
microAvgRecall
,
double
*
microAvgF1Score
)
const
;
void
calcStatsInfo
(
const
MatrixPtr
&
output
,
void
calcStatsInfo
(
const
MatrixPtr
&
output
,
const
IVectorPtr
&
label
,
const
IVectorPtr
&
label
,
...
...
paddle/utils/Error.h
浏览文件 @
b1ab8b56
...
@@ -37,10 +37,10 @@ namespace paddle {
...
@@ -37,10 +37,10 @@ namespace paddle {
*
*
* Error __must_check bar() {
* Error __must_check bar() {
* // do something.
* // do something.
*
Status s
= foo(); // invoke other method return status.
*
Error err
= foo(); // invoke other method return status.
* if (
!s) return s
;
* if (
err) return err
;
* // do something else.
* // do something else.
* return
Status
();
* return
Error
();
* }
* }
* @endcode{cpp}
* @endcode{cpp}
*
*
...
@@ -53,8 +53,8 @@ namespace paddle {
...
@@ -53,8 +53,8 @@ namespace paddle {
*
*
* int foo(Error* error) {
* int foo(Error* error) {
* // Do something.
* // Do something.
* Error
s
= bar();
* Error
err
= bar();
* if (
!s
) {
* if (
err
) {
* *error = s;
* *error = s;
* return 0;
* return 0;
* }
* }
...
@@ -68,10 +68,10 @@ namespace paddle {
...
@@ -68,10 +68,10 @@ namespace paddle {
* }
* }
*
*
* Error foobar() {
* Error foobar() {
* Error
s
;
* Error
err
;
* // do something.
* // do something.
* foo(&
s
);
* foo(&
err
);
* if (
!s) return s
;
* if (
err) return err
;
* }
* }
* @endcode{cpp}
* @endcode{cpp}
*
*
...
@@ -112,18 +112,22 @@ public:
...
@@ -112,18 +112,22 @@ public:
}
}
/**
/**
* @brief operator bool, return True if there is
no
error.
* @brief operator bool, return True if there is
something
error.
*/
*/
operator
bool
()
const
{
return
msg_
==
nullptr
;
}
operator
bool
()
const
{
return
!
this
->
isOK
()
;
}
bool
isOK
()
const
{
return
*
this
;
}
/**
* @brief isOK return True if there is no error.
* @return True if no error.
*/
bool
isOK
()
const
{
return
msg_
==
nullptr
;
}
/**
/**
* @brief check this status by glog.
* @brief check this status by glog.
* @note It is a temp method used during cleaning Paddle code. It will be
* @note It is a temp method used during cleaning Paddle code. It will be
* removed later.
* removed later.
*/
*/
void
check
()
const
{
CHECK
(
*
this
)
<<
msg
();
}
void
check
()
const
{
CHECK
(
this
->
isOK
()
)
<<
msg
();
}
private:
private:
std
::
shared_ptr
<
std
::
string
>
msg_
;
std
::
shared_ptr
<
std
::
string
>
msg_
;
...
...
paddle/utils/tests/test_Error.cpp
浏览文件 @
b1ab8b56
...
@@ -18,17 +18,17 @@ limitations under the License. */
...
@@ -18,17 +18,17 @@ limitations under the License. */
TEST
(
Error
,
testAll
)
{
TEST
(
Error
,
testAll
)
{
paddle
::
Error
error
;
paddle
::
Error
error
;
ASSERT_TRUE
(
error
);
error
=
paddle
::
Error
(
"I'm the error"
);
ASSERT_FALSE
(
error
);
ASSERT_FALSE
(
error
);
error
=
paddle
::
Error
(
"I'm the error"
);
ASSERT_TRUE
(
error
);
ASSERT_STREQ
(
"I'm the error"
,
error
.
msg
());
ASSERT_STREQ
(
"I'm the error"
,
error
.
msg
());
error
=
paddle
::
Error
(
"error2"
);
error
=
paddle
::
Error
(
"error2"
);
ASSERT_
FALS
E
(
error
);
ASSERT_
TRU
E
(
error
);
ASSERT_STREQ
(
"error2"
,
error
.
msg
());
ASSERT_STREQ
(
"error2"
,
error
.
msg
());
int
i
=
3
;
int
i
=
3
;
auto
error3
=
paddle
::
Error
(
"error%d"
,
i
);
auto
error3
=
paddle
::
Error
(
"error%d"
,
i
);
ASSERT_
FALS
E
(
error3
);
ASSERT_
TRU
E
(
error3
);
ASSERT_STREQ
(
"error3"
,
error3
.
msg
());
ASSERT_STREQ
(
"error3"
,
error3
.
msg
());
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录