Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
b1ab8b56
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b1ab8b56
编写于
2月 23, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use plain C++ 03 to implement getStatsInfo.
上级
bb751b7f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
129 addition
and
116 deletion
+129
-116
paddle/gserver/evaluators/Evaluator.cpp
paddle/gserver/evaluators/Evaluator.cpp
+99
-97
paddle/gserver/evaluators/Evaluator.h
paddle/gserver/evaluators/Evaluator.h
+10
-3
paddle/utils/Error.h
paddle/utils/Error.h
+16
-12
paddle/utils/tests/test_Error.cpp
paddle/utils/tests/test_Error.cpp
+4
-4
未找到文件。
paddle/gserver/evaluators/Evaluator.cpp
浏览文件 @
b1ab8b56
...
...
@@ -626,78 +626,34 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
return
0
;
}
template
<
typename
T1
,
typename
T2
>
void
PrecisionRecallEvaluator
::
printStatsHelper
(
T1
labelCallback
,
T2
microAvgCallback
)
const
{
int
label
=
config_
.
positive_label
();
if
(
label
!=
-
1
)
{
CHECK
(
label
>=
0
&&
label
<
(
int
)
statsInfo_
.
size
())
<<
"positive_label ["
<<
label
<<
"] should be in range [0, "
<<
statsInfo_
.
size
()
<<
")"
;
double
precision
=
calcPrecision
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FP
);
double
recall
=
calcRecall
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FN
);
labelCallback
(
label
,
precision
,
recall
,
calcF1Score
(
precision
,
recall
));
return
;
}
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
// macro average method: precision = (precision1+precision2)/2
double
microTotalTP
=
0
;
double
microTotalFP
=
0
;
double
microTotalFN
=
0
;
double
macroAvgPrecision
=
0
;
double
macroAvgRecall
=
0
;
size_t
numLabels
=
statsInfo_
.
size
();
for
(
size_t
i
=
0
;
i
<
numLabels
;
++
i
)
{
microTotalTP
+=
statsInfo_
[
i
].
TP
;
microTotalFP
+=
statsInfo_
[
i
].
FP
;
microTotalFN
+=
statsInfo_
[
i
].
FN
;
macroAvgPrecision
+=
calcPrecision
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FP
);
macroAvgRecall
+=
calcRecall
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FN
);
}
macroAvgPrecision
/=
numLabels
;
macroAvgRecall
/=
numLabels
;
double
macroAvgF1Score
=
calcF1Score
(
macroAvgPrecision
,
macroAvgRecall
);
double
microAvgPrecision
=
calcPrecision
(
microTotalTP
,
microTotalFP
);
double
microAvgRecall
=
calcPrecision
(
microTotalTP
,
microTotalFN
);
double
microAvgF1Score
=
calcF1Score
(
microAvgPrecision
,
microAvgRecall
);
microAvgCallback
(
macroAvgPrecision
,
macroAvgRecall
,
macroAvgF1Score
,
isMultiBinaryLabel_
,
microAvgPrecision
,
microAvgRecall
,
microAvgF1Score
);
}
void
PrecisionRecallEvaluator
::
printStats
(
std
::
ostream
&
os
)
const
{
this
->
printStatsHelper
(
[
&
os
](
int
label
,
double
precision
,
double
recall
,
double
f1
)
{
os
<<
"positive_label="
<<
label
<<
" precision="
<<
precision
<<
" recall="
<<
recall
<<
" F1-score="
<<
f1
;
},
[
&
os
](
double
macroAvgPrecision
,
double
macroAvgRecall
,
double
macroAvgF1Score
,
bool
isMultiBinaryLabel
,
double
microAvgPrecision
,
double
microAvgRecall
,
double
microAvgF1Score
)
{
os
<<
"macro-average-precision="
<<
macroAvgPrecision
<<
" macro-average-recall="
<<
macroAvgRecall
<<
" macro-average-F1-score="
<<
macroAvgF1Score
;
if
(
!
isMultiBinaryLabel
)
{
// precision and recall are equal in this case
os
<<
" micro-average-precision="
<<
microAvgPrecision
;
}
else
{
os
<<
" micro-average-precision="
<<
microAvgPrecision
<<
" micro-average-recall="
<<
microAvgRecall
<<
" micro-average-F1-score="
<<
microAvgF1Score
;
}
});
double
precision
,
recall
,
f1
,
macroAvgPrecision
,
macroAvgRecall
,
macroAvgF1Score
,
microAvgPrecision
,
microAvgRecall
,
microAvgF1Score
;
bool
containMacroMicroInfo
=
getStatsInfo
(
&
precision
,
&
recall
,
&
f1
,
&
macroAvgPrecision
,
&
macroAvgRecall
,
&
macroAvgF1Score
,
&
microAvgPrecision
,
&
microAvgRecall
,
&
microAvgF1Score
);
os
<<
"positive_label="
<<
config_
.
positive_label
()
<<
" precision="
<<
precision
<<
" recall="
<<
recall
<<
" F1-score="
<<
f1
;
if
(
containMacroMicroInfo
)
{
os
<<
"macro-average-precision="
<<
macroAvgPrecision
<<
" macro-average-recall="
<<
macroAvgRecall
<<
" macro-average-F1-score="
<<
macroAvgF1Score
;
if
(
!
isMultiBinaryLabel_
)
{
// precision and recall are equal in this case
os
<<
" micro-average-precision="
<<
microAvgPrecision
;
}
else
{
os
<<
" micro-average-precision="
<<
microAvgPrecision
<<
" micro-average-recall="
<<
microAvgRecall
<<
" micro-average-F1-score="
<<
microAvgF1Score
;
}
};
}
void
PrecisionRecallEvaluator
::
calcStatsInfo
(
const
MatrixPtr
&
output
,
...
...
@@ -780,32 +736,33 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
void
PrecisionRecallEvaluator
::
storeLocalValues
()
const
{
if
(
this
->
values_
.
size
()
==
0
)
{
this
->
printStatsHelper
(
[
this
](
int
label
,
double
precision
,
double
recall
,
double
f1
)
{
values_
[
"positive_label"
]
=
(
double
)
label
;
values_
[
"precision"
]
=
precision
;
values_
[
"recal"
]
=
recall
;
values_
[
"F1-score"
]
=
f1
;
},
[
this
](
double
macroAvgPrecision
,
double
macroAvgRecall
,
double
macroAvgF1Score
,
bool
isMultiBinaryLabel
,
double
microAvgPrecision
,
double
microAvgRecall
,
double
microAvgF1Score
)
{
values_
[
"macro-average-precision"
]
=
macroAvgPrecision
;
values_
[
"macro-average-recall"
]
=
macroAvgRecall
;
values_
[
"macro-average-F1-score"
]
=
macroAvgF1Score
;
if
(
!
isMultiBinaryLabel
)
{
// precision and recall are equal in this case
values_
[
"micro-average-precision"
]
=
microAvgPrecision
;
}
else
{
values_
[
"micro-average-precision"
]
=
microAvgPrecision
;
values_
[
"micro-average-recall"
]
=
microAvgRecall
;
values_
[
"micro-average-F1-score"
]
=
microAvgF1Score
;
}
});
double
precision
,
recall
,
f1
,
macroAvgPrecision
,
macroAvgRecall
,
macroAvgF1Score
,
microAvgPrecision
,
microAvgRecall
,
microAvgF1Score
;
bool
containMacroMicroInfo
=
getStatsInfo
(
&
precision
,
&
recall
,
&
f1
,
&
macroAvgPrecision
,
&
macroAvgRecall
,
&
macroAvgF1Score
,
&
microAvgPrecision
,
&
microAvgRecall
,
&
microAvgF1Score
);
values_
[
"precision"
]
=
precision
;
values_
[
"recal"
]
=
recall
;
values_
[
"F1-score"
]
=
f1
;
if
(
containMacroMicroInfo
)
{
values_
[
"macro-average-precision"
]
=
macroAvgPrecision
;
values_
[
"macro-average-recall"
]
=
macroAvgRecall
;
values_
[
"macro-average-F1-score"
]
=
macroAvgF1Score
;
if
(
!
isMultiBinaryLabel_
)
{
// precision and recall are equal in this case
values_
[
"micro-average-precision"
]
=
microAvgPrecision
;
}
else
{
values_
[
"micro-average-precision"
]
=
microAvgPrecision
;
values_
[
"micro-average-recall"
]
=
microAvgRecall
;
values_
[
"micro-average-F1-score"
]
=
microAvgF1Score
;
}
}
}
}
...
...
@@ -865,6 +822,51 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
delete
[]
buf
;
}
bool
PrecisionRecallEvaluator
::
getStatsInfo
(
double
*
precision
,
double
*
recall
,
double
*
f1
,
double
*
macroAvgPrecision
,
double
*
macroAvgRecall
,
double
*
macroAvgF1Score
,
double
*
microAvgPrecision
,
double
*
microAvgRecall
,
double
*
microAvgF1Score
)
const
{
int
label
=
config_
.
positive_label
();
if
(
label
!=
-
1
)
{
CHECK
(
label
>=
0
&&
label
<
(
int
)
statsInfo_
.
size
())
<<
"positive_label ["
<<
label
<<
"] should be in range [0, "
<<
statsInfo_
.
size
()
<<
")"
;
*
precision
=
calcPrecision
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FP
);
*
recall
=
calcRecall
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FN
);
*
f1
=
calcF1Score
(
*
precision
,
*
recall
);
return
false
;
}
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
// macro average method: precision = (precision1+precision2)/2
double
microTotalTP
=
0
;
double
microTotalFP
=
0
;
double
microTotalFN
=
0
;
*
macroAvgPrecision
=
0
;
*
macroAvgRecall
=
0
;
size_t
numLabels
=
statsInfo_
.
size
();
for
(
size_t
i
=
0
;
i
<
numLabels
;
++
i
)
{
microTotalTP
+=
statsInfo_
[
i
].
TP
;
microTotalFP
+=
statsInfo_
[
i
].
FP
;
microTotalFN
+=
statsInfo_
[
i
].
FN
;
*
macroAvgPrecision
+=
calcPrecision
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FP
);
*
macroAvgRecall
+=
calcRecall
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FN
);
}
*
macroAvgPrecision
/=
numLabels
;
*
macroAvgRecall
/=
numLabels
;
*
macroAvgF1Score
=
calcF1Score
(
*
macroAvgPrecision
,
*
macroAvgRecall
);
*
microAvgPrecision
=
calcPrecision
(
microTotalTP
,
microTotalFP
);
*
microAvgRecall
=
calcPrecision
(
microTotalTP
,
microTotalFN
);
*
microAvgF1Score
=
calcF1Score
(
*
microAvgPrecision
,
*
microAvgRecall
);
return
true
;
}
REGISTER_EVALUATOR
(
pnpair
,
PnpairEvaluator
);
void
PnpairEvaluator
::
start
()
{
Evaluator
::
start
();
...
...
paddle/gserver/evaluators/Evaluator.h
浏览文件 @
b1ab8b56
...
...
@@ -125,7 +125,7 @@ public:
* has multiple field, the name could be `evaluator_name.field1`. For example
* the PrecisionRecallEvaluator contains `precision`, `recall` fields. The get
* names will return `precision_recall_evaluator.precision`,
* `precision_recall.recal`, etc.
* `precision_recall
_evaluator
.recal`, etc.
*
* Also, if current Evaluator is a combined evaluator. getNames will return
* all names of all evaluators inside the combined evaluator.
...
...
@@ -387,8 +387,15 @@ private:
IVectorPtr
cpuLabel_
;
MatrixPtr
cpuWeight_
;
template
<
typename
T1
,
typename
T2
>
void
printStatsHelper
(
T1
labelCallback
,
T2
microAvgCallback
)
const
;
bool
getStatsInfo
(
double
*
precision
,
double
*
recall
,
double
*
f1
,
double
*
macroAvgPrecision
,
double
*
macroAvgRecall
,
double
*
macroAvgF1Score
,
double
*
microAvgPrecision
,
double
*
microAvgRecall
,
double
*
microAvgF1Score
)
const
;
void
calcStatsInfo
(
const
MatrixPtr
&
output
,
const
IVectorPtr
&
label
,
...
...
paddle/utils/Error.h
浏览文件 @
b1ab8b56
...
...
@@ -37,10 +37,10 @@ namespace paddle {
*
* Error __must_check bar() {
* // do something.
*
Status s
= foo(); // invoke other method return status.
* if (
!s) return s
;
*
Error err
= foo(); // invoke other method return status.
* if (
err) return err
;
* // do something else.
* return
Status
();
* return
Error
();
* }
* @endcode{cpp}
*
...
...
@@ -53,8 +53,8 @@ namespace paddle {
*
* int foo(Error* error) {
* // Do something.
* Error
s
= bar();
* if (
!s
) {
* Error
err
= bar();
* if (
err
) {
* *error = s;
* return 0;
* }
...
...
@@ -68,10 +68,10 @@ namespace paddle {
* }
*
* Error foobar() {
* Error
s
;
* Error
err
;
* // do something.
* foo(&
s
);
* if (
!s) return s
;
* foo(&
err
);
* if (
err) return err
;
* }
* @endcode{cpp}
*
...
...
@@ -112,18 +112,22 @@ public:
}
/**
* @brief operator bool, return True if there is
no
error.
* @brief operator bool, return True if there is
something
error.
*/
operator
bool
()
const
{
return
msg_
==
nullptr
;
}
operator
bool
()
const
{
return
!
this
->
isOK
()
;
}
bool
isOK
()
const
{
return
*
this
;
}
/**
* @brief isOK return True if there is no error.
* @return True if no error.
*/
bool
isOK
()
const
{
return
msg_
==
nullptr
;
}
/**
* @brief check this status by glog.
* @note It is a temp method used during cleaning Paddle code. It will be
* removed later.
*/
void
check
()
const
{
CHECK
(
*
this
)
<<
msg
();
}
void
check
()
const
{
CHECK
(
this
->
isOK
()
)
<<
msg
();
}
private:
std
::
shared_ptr
<
std
::
string
>
msg_
;
...
...
paddle/utils/tests/test_Error.cpp
浏览文件 @
b1ab8b56
...
...
@@ -18,17 +18,17 @@ limitations under the License. */
TEST
(
Error
,
testAll
)
{
paddle
::
Error
error
;
ASSERT_TRUE
(
error
);
error
=
paddle
::
Error
(
"I'm the error"
);
ASSERT_FALSE
(
error
);
error
=
paddle
::
Error
(
"I'm the error"
);
ASSERT_TRUE
(
error
);
ASSERT_STREQ
(
"I'm the error"
,
error
.
msg
());
error
=
paddle
::
Error
(
"error2"
);
ASSERT_
FALS
E
(
error
);
ASSERT_
TRU
E
(
error
);
ASSERT_STREQ
(
"error2"
,
error
.
msg
());
int
i
=
3
;
auto
error3
=
paddle
::
Error
(
"error%d"
,
i
);
ASSERT_
FALS
E
(
error3
);
ASSERT_
TRU
E
(
error3
);
ASSERT_STREQ
(
"error3"
,
error3
.
msg
());
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录