Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
c4519574
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c4519574
编写于
2月 23, 2017
作者:
Y
Yu Yang
提交者:
GitHub
2月 23, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1375 from reyoung/feature/EvaluatorValue
Feature/evaluator value
上级
f25c9c5f
9087e3a9
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
379 addition
and
73 deletion
+379
-73
paddle/gserver/evaluators/CTCErrorEvaluator.cpp
paddle/gserver/evaluators/CTCErrorEvaluator.cpp
+1
-1
paddle/gserver/evaluators/Evaluator.cpp
paddle/gserver/evaluators/Evaluator.cpp
+158
-54
paddle/gserver/evaluators/Evaluator.h
paddle/gserver/evaluators/Evaluator.h
+138
-2
paddle/gserver/gradientmachines/NeuralNetwork.cpp
paddle/gserver/gradientmachines/NeuralNetwork.cpp
+49
-1
paddle/gserver/tests/test_Evaluator.cpp
paddle/gserver/tests/test_Evaluator.cpp
+12
-0
paddle/utils/Error.h
paddle/utils/Error.h
+17
-11
paddle/utils/tests/test_Error.cpp
paddle/utils/tests/test_Error.cpp
+4
-4
未找到文件。
paddle/gserver/evaluators/CTCErrorEvaluator.cpp
浏览文件 @
c4519574
...
@@ -20,7 +20,7 @@ namespace paddle {
...
@@ -20,7 +20,7 @@ namespace paddle {
/**
/**
* calculate sequence-to-sequence edit distance
* calculate sequence-to-sequence edit distance
*/
*/
class
CTCErrorEvaluator
:
public
Evaluator
{
class
CTCErrorEvaluator
:
public
NotGetable
Evaluator
{
private:
private:
MatrixPtr
outActivations_
;
MatrixPtr
outActivations_
;
int
numTimes_
,
numClasses_
,
numSequences_
,
blank_
;
int
numTimes_
,
numClasses_
,
numSequences_
,
blank_
;
...
...
paddle/gserver/evaluators/Evaluator.cpp
浏览文件 @
c4519574
...
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
...
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/gserver/evaluators/Evaluator.h"
#include "paddle/gserver/evaluators/Evaluator.h"
#include "paddle/utils/Stat.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/StringUtil.h"
DECLARE_int32
(
trainer_id
);
DECLARE_int32
(
trainer_id
);
...
@@ -122,6 +122,10 @@ public:
...
@@ -122,6 +122,10 @@ public:
virtual
void
distributeEval
(
ParameterClient2
*
client
)
{
virtual
void
distributeEval
(
ParameterClient2
*
client
)
{
mergeResultsOfAllClients
(
client
);
mergeResultsOfAllClients
(
client
);
}
}
// Evaluator interface
protected:
std
::
string
getTypeImpl
()
const
{
return
"classification_error"
;
}
};
};
/**
/**
...
@@ -160,6 +164,10 @@ public:
...
@@ -160,6 +164,10 @@ public:
virtual
void
distributeEval
(
ParameterClient2
*
client
)
{
virtual
void
distributeEval
(
ParameterClient2
*
client
)
{
mergeResultsOfAllClients
(
client
);
mergeResultsOfAllClients
(
client
);
}
}
// Evaluator interface
protected:
std
::
string
getTypeImpl
()
const
{
return
"seq_classification_error"
;
}
};
};
REGISTER_EVALUATOR
(
seq_classification_error
,
REGISTER_EVALUATOR
(
seq_classification_error
,
SequenceClassificationErrorEvaluator
);
SequenceClassificationErrorEvaluator
);
...
@@ -250,6 +258,10 @@ public:
...
@@ -250,6 +258,10 @@ public:
private:
private:
IVectorPtr
cpuLabel_
;
IVectorPtr
cpuLabel_
;
MatrixPtr
cpuWeight_
;
MatrixPtr
cpuWeight_
;
// Evaluator interface
protected:
std
::
string
getTypeImpl
()
const
{
return
"sum"
;
}
};
};
/**
/**
* @brief column sum Evaluator
* @brief column sum Evaluator
...
@@ -357,10 +369,18 @@ public:
...
@@ -357,10 +369,18 @@ public:
}
}
private:
private:
ColumnSumEvaluator
()
{}
int32_t
colIdx_
;
int32_t
colIdx_
;
size_t
colNum_
;
size_t
colNum_
;
MatrixPtr
sum_
;
/* cpu matrix */
MatrixPtr
sum_
;
/* cpu matrix */
// Evaluator interface
protected:
std
::
string
getTypeImpl
()
const
{
if
(
colIdx_
==
-
1
)
return
"last-column-sum"
;
else
return
"column-sum"
;
}
};
};
void
AucEvaluator
::
start
()
{
void
AucEvaluator
::
start
()
{
...
@@ -469,6 +489,16 @@ double AucEvaluator::calcAuc() const {
...
@@ -469,6 +489,16 @@ double AucEvaluator::calcAuc() const {
}
}
}
}
real
AucEvaluator
::
getValueImpl
()
const
{
return
calcAuc
();
}
std
::
string
AucEvaluator
::
getTypeImpl
()
const
{
if
(
colIdx_
==
-
1
)
{
return
"last-column-auc"
;
}
else
{
return
"auc"
;
}
}
// class RankAucEvaluator
// class RankAucEvaluator
REGISTER_EVALUATOR
(
rankauc
,
RankAucEvaluator
);
REGISTER_EVALUATOR
(
rankauc
,
RankAucEvaluator
);
...
@@ -548,12 +578,15 @@ double RankAucEvaluator::calcRankAuc(real* outputData,
...
@@ -548,12 +578,15 @@ double RankAucEvaluator::calcRankAuc(real* outputData,
:
aucTmp
/
(
clickSum
*
noClickSum
);
:
aucTmp
/
(
clickSum
*
noClickSum
);
}
}
std
::
string
RankAucEvaluator
::
getTypeImpl
()
const
{
return
"rankauc"
;
}
// class PrecisionRecallEvaluator
// class PrecisionRecallEvaluator
REGISTER_EVALUATOR
(
precision_recall
,
PrecisionRecallEvaluator
);
REGISTER_EVALUATOR
(
precision_recall
,
PrecisionRecallEvaluator
);
void
PrecisionRecallEvaluator
::
start
()
{
void
PrecisionRecallEvaluator
::
start
()
{
Evaluator
::
start
();
Evaluator
::
start
();
statsInfo_
.
clear
();
statsInfo_
.
clear
();
values_
.
clear
();
}
}
real
PrecisionRecallEvaluator
::
evalImp
(
std
::
vector
<
Argument
>&
arguments
)
{
real
PrecisionRecallEvaluator
::
evalImp
(
std
::
vector
<
Argument
>&
arguments
)
{
...
@@ -614,52 +647,23 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
...
@@ -614,52 +647,23 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
}
}
void
PrecisionRecallEvaluator
::
printStats
(
std
::
ostream
&
os
)
const
{
void
PrecisionRecallEvaluator
::
printStats
(
std
::
ostream
&
os
)
const
{
int
label
=
config_
.
positive_label
();
PrintStatsInfo
info
;
if
(
label
!=
-
1
)
{
bool
containMacroMicroInfo
=
getStatsInfo
(
&
info
);
CHECK
(
label
>=
0
&&
label
<
(
int
)
statsInfo_
.
size
())
os
<<
"positive_label="
<<
config_
.
positive_label
()
<<
"positive_label ["
<<
label
<<
"] should be in range [0, "
<<
" precision="
<<
info
.
precision
<<
" recall="
<<
info
.
recall
<<
statsInfo_
.
size
()
<<
")"
;
<<
" F1-score="
<<
info
.
f1
;
double
precision
=
if
(
containMacroMicroInfo
)
{
calcPrecision
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FP
);
os
<<
"macro-average-precision="
<<
info
.
macroAvgPrecision
double
recall
=
calcRecall
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FN
);
<<
" macro-average-recall="
<<
info
.
macroAvgRecall
os
<<
"positive_label="
<<
label
<<
" precision="
<<
precision
<<
" macro-average-F1-score="
<<
info
.
macroAvgF1Score
;
<<
" recall="
<<
recall
if
(
!
isMultiBinaryLabel_
)
{
<<
" F1-score="
<<
calcF1Score
(
precision
,
recall
);
// precision and recall are equal in this case
return
;
os
<<
" micro-average-precision="
<<
info
.
microAvgPrecision
;
}
}
else
{
os
<<
" micro-average-precision="
<<
info
.
microAvgPrecision
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
<<
" micro-average-recall="
<<
info
.
microAvgRecall
// macro average method: precision = (precision1+precision2)/2
<<
" micro-average-F1-score="
<<
info
.
microAvgF1Score
;
double
microTotalTP
=
0
;
}
double
microTotalFP
=
0
;
double
microTotalFN
=
0
;
double
macroAvgPrecision
=
0
;
double
macroAvgRecall
=
0
;
size_t
numLabels
=
statsInfo_
.
size
();
for
(
size_t
i
=
0
;
i
<
numLabels
;
++
i
)
{
microTotalTP
+=
statsInfo_
[
i
].
TP
;
microTotalFP
+=
statsInfo_
[
i
].
FP
;
microTotalFN
+=
statsInfo_
[
i
].
FN
;
macroAvgPrecision
+=
calcPrecision
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FP
);
macroAvgRecall
+=
calcRecall
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FN
);
}
macroAvgPrecision
/=
numLabels
;
macroAvgRecall
/=
numLabels
;
double
macroAvgF1Score
=
calcF1Score
(
macroAvgPrecision
,
macroAvgRecall
);
os
<<
"macro-average-precision="
<<
macroAvgPrecision
<<
" macro-average-recall="
<<
macroAvgRecall
<<
" macro-average-F1-score="
<<
macroAvgF1Score
;
double
microAvgPrecision
=
calcPrecision
(
microTotalTP
,
microTotalFP
);
double
microAvgRecall
=
calcPrecision
(
microTotalTP
,
microTotalFN
);
double
microAvgF1Score
=
calcF1Score
(
microAvgPrecision
,
microAvgRecall
);
if
(
!
isMultiBinaryLabel_
)
{
// precision and recall are equal in this case
os
<<
" micro-average-precision="
<<
microAvgPrecision
;
}
else
{
os
<<
" micro-average-precision="
<<
microAvgPrecision
<<
" micro-average-recall="
<<
microAvgRecall
<<
" micro-average-F1-score="
<<
microAvgF1Score
;
}
}
}
}
...
@@ -741,6 +745,60 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
...
@@ -741,6 +745,60 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
}
}
}
}
void
PrecisionRecallEvaluator
::
storeLocalValues
()
const
{
if
(
this
->
values_
.
size
()
==
0
)
{
PrintStatsInfo
info
;
bool
containMacroMicroInfo
=
getStatsInfo
(
&
info
);
values_
[
"precision"
]
=
info
.
precision
;
values_
[
"recal"
]
=
info
.
recall
;
values_
[
"F1-score"
]
=
info
.
f1
;
if
(
containMacroMicroInfo
)
{
values_
[
"macro-average-precision"
]
=
info
.
macroAvgPrecision
;
values_
[
"macro-average-recall"
]
=
info
.
macroAvgRecall
;
values_
[
"macro-average-F1-score"
]
=
info
.
macroAvgF1Score
;
if
(
!
isMultiBinaryLabel_
)
{
// precision and recall are equal in this case
values_
[
"micro-average-precision"
]
=
info
.
microAvgPrecision
;
}
else
{
values_
[
"micro-average-precision"
]
=
info
.
microAvgPrecision
;
values_
[
"micro-average-recall"
]
=
info
.
microAvgRecall
;
values_
[
"micro-average-F1-score"
]
=
info
.
microAvgF1Score
;
}
}
}
}
void
PrecisionRecallEvaluator
::
getNames
(
std
::
vector
<
std
::
string
>*
names
)
{
this
->
storeLocalValues
();
names
->
reserve
(
this
->
values_
.
size
());
for
(
auto
it
=
this
->
values_
.
begin
();
it
!=
this
->
values_
.
end
();
++
it
)
{
names
->
push_back
(
this
->
config_
.
name
()
+
"."
+
it
->
first
);
}
}
real
PrecisionRecallEvaluator
::
getValue
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
this
->
storeLocalValues
();
std
::
vector
<
std
::
string
>
buffers
;
paddle
::
str
::
split
(
name
,
'.'
,
&
buffers
);
auto
it
=
this
->
values_
.
find
(
buffers
[
buffers
.
size
()
-
1
]);
if
(
it
==
this
->
values_
.
end
())
{
// not found
*
err
=
Error
(
"No such key %s"
,
name
.
c_str
());
return
.0
f
;
}
return
it
->
second
;
}
std
::
string
PrecisionRecallEvaluator
::
getType
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
this
->
getValue
(
name
,
err
);
if
(
!
err
->
isOK
())
{
return
""
;
}
return
"precision_recall"
;
}
void
PrecisionRecallEvaluator
::
distributeEval
(
ParameterClient2
*
client
)
{
void
PrecisionRecallEvaluator
::
distributeEval
(
ParameterClient2
*
client
)
{
size_t
size
=
4
*
statsInfo_
.
size
();
size_t
size
=
4
*
statsInfo_
.
size
();
double
*
buf
=
new
double
[
size
];
double
*
buf
=
new
double
[
size
];
...
@@ -760,6 +818,47 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
...
@@ -760,6 +818,47 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
delete
[]
buf
;
delete
[]
buf
;
}
}
bool
PrecisionRecallEvaluator
::
getStatsInfo
(
PrecisionRecallEvaluator
::
PrintStatsInfo
*
info
)
const
{
int
label
=
config_
.
positive_label
();
if
(
label
!=
-
1
)
{
CHECK
(
label
>=
0
&&
label
<
(
int
)
statsInfo_
.
size
())
<<
"positive_label ["
<<
label
<<
"] should be in range [0, "
<<
statsInfo_
.
size
()
<<
")"
;
info
->
precision
=
calcPrecision
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FP
);
info
->
recall
=
calcRecall
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FN
);
info
->
f1
=
calcF1Score
(
info
->
precision
,
info
->
recall
);
return
false
;
}
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
// macro average method: precision = (precision1+precision2)/2
double
microTotalTP
=
0
;
double
microTotalFP
=
0
;
double
microTotalFN
=
0
;
info
->
macroAvgPrecision
=
0
;
info
->
macroAvgRecall
=
0
;
size_t
numLabels
=
statsInfo_
.
size
();
for
(
size_t
i
=
0
;
i
<
numLabels
;
++
i
)
{
microTotalTP
+=
statsInfo_
[
i
].
TP
;
microTotalFP
+=
statsInfo_
[
i
].
FP
;
microTotalFN
+=
statsInfo_
[
i
].
FN
;
info
->
macroAvgPrecision
+=
calcPrecision
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FP
);
info
->
macroAvgRecall
+=
calcRecall
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FN
);
}
info
->
macroAvgPrecision
/=
numLabels
;
info
->
macroAvgRecall
/=
numLabels
;
info
->
macroAvgF1Score
=
calcF1Score
(
info
->
macroAvgPrecision
,
info
->
macroAvgRecall
);
info
->
microAvgPrecision
=
calcPrecision
(
microTotalTP
,
microTotalFP
);
info
->
microAvgRecall
=
calcPrecision
(
microTotalTP
,
microTotalFN
);
info
->
microAvgF1Score
=
calcF1Score
(
info
->
microAvgPrecision
,
info
->
microAvgRecall
);
return
true
;
}
REGISTER_EVALUATOR
(
pnpair
,
PnpairEvaluator
);
REGISTER_EVALUATOR
(
pnpair
,
PnpairEvaluator
);
void
PnpairEvaluator
::
start
()
{
void
PnpairEvaluator
::
start
()
{
Evaluator
::
start
();
Evaluator
::
start
();
...
@@ -884,6 +983,8 @@ void PnpairEvaluator::calc(std::vector<PredictionResult>& predictArray) {
...
@@ -884,6 +983,8 @@ void PnpairEvaluator::calc(std::vector<PredictionResult>& predictArray) {
<<
" calc total special pair: "
<<
special
;
<<
" calc total special pair: "
<<
special
;
}
}
std
::
string
PnpairEvaluator
::
getTypeImpl
()
const
{
return
"pnpair"
;
}
ClassRegistrar
<
Evaluator
>
Evaluator
::
registrar_
;
ClassRegistrar
<
Evaluator
>
Evaluator
::
registrar_
;
Evaluator
*
Evaluator
::
create
(
const
EvaluatorConfig
&
config
)
{
Evaluator
*
Evaluator
::
create
(
const
EvaluatorConfig
&
config
)
{
Evaluator
*
evaluator
=
registrar_
.
createByType
(
config
.
type
());
Evaluator
*
evaluator
=
registrar_
.
createByType
(
config
.
type
());
...
@@ -905,7 +1006,7 @@ static InitFunction __reg_type_auc_sum__([]() {
...
@@ -905,7 +1006,7 @@ static InitFunction __reg_type_auc_sum__([]() {
*
*
* The config file api is value_printer_evaluator.
* The config file api is value_printer_evaluator.
*/
*/
class
ValuePrinter
:
public
Evaluator
{
class
ValuePrinter
:
public
NotGetable
Evaluator
{
public:
public:
virtual
void
eval
(
const
NeuralNetwork
&
nn
)
{
virtual
void
eval
(
const
NeuralNetwork
&
nn
)
{
for
(
const
std
::
string
&
name
:
config_
.
input_layers
())
{
for
(
const
std
::
string
&
name
:
config_
.
input_layers
())
{
...
@@ -919,12 +1020,13 @@ public:
...
@@ -919,12 +1020,13 @@ public:
virtual
real
evalImp
(
std
::
vector
<
Argument
>&
arguments
)
{
return
0
;
}
virtual
real
evalImp
(
std
::
vector
<
Argument
>&
arguments
)
{
return
0
;
}
};
};
REGISTER_EVALUATOR
(
value_printer
,
ValuePrinter
);
REGISTER_EVALUATOR
(
value_printer
,
ValuePrinter
);
/**
/**
* @brief print gradient of each layer.
* @brief print gradient of each layer.
*
*
* The config file api is gradient_printer_evaluator.
* The config file api is gradient_printer_evaluator.
*/
*/
class
GradientPrinter
:
public
Evaluator
{
class
GradientPrinter
:
public
NotGetable
Evaluator
{
public:
public:
virtual
void
eval
(
const
NeuralNetwork
&
nn
)
{
virtual
void
eval
(
const
NeuralNetwork
&
nn
)
{
for
(
const
std
::
string
&
name
:
config_
.
input_layers
())
{
for
(
const
std
::
string
&
name
:
config_
.
input_layers
())
{
...
@@ -947,7 +1049,7 @@ REGISTER_EVALUATOR(gradient_printer, GradientPrinter);
...
@@ -947,7 +1049,7 @@ REGISTER_EVALUATOR(gradient_printer, GradientPrinter);
*
*
* The config file api is maxid_printer_evaluator.
* The config file api is maxid_printer_evaluator.
*/
*/
class
MaxIdPrinter
:
public
Evaluator
{
class
MaxIdPrinter
:
public
NotGetable
Evaluator
{
private:
private:
IVectorPtr
maxIds_
;
IVectorPtr
maxIds_
;
MatrixPtr
maxValues_
;
MatrixPtr
maxValues_
;
...
@@ -989,7 +1091,7 @@ REGISTER_EVALUATOR(max_id_printer, MaxIdPrinter);
...
@@ -989,7 +1091,7 @@ REGISTER_EVALUATOR(max_id_printer, MaxIdPrinter);
*
*
* The config file api is maxframe_printer_evaluator.
* The config file api is maxframe_printer_evaluator.
*/
*/
class
MaxFramePrinter
:
public
Evaluator
{
class
MaxFramePrinter
:
public
NotGetable
Evaluator
{
private:
private:
IVectorPtr
maxIds_
;
IVectorPtr
maxIds_
;
MatrixPtr
maxValues_
;
MatrixPtr
maxValues_
;
...
@@ -1076,7 +1178,7 @@ REGISTER_EVALUATOR(max_frame_printer, MaxFramePrinter);
...
@@ -1076,7 +1178,7 @@ REGISTER_EVALUATOR(max_frame_printer, MaxFramePrinter);
* The config file api is seqtext_printer_evaluator.
* The config file api is seqtext_printer_evaluator.
*
*
*/
*/
class
SequenceTextPrinter
:
public
Evaluator
{
class
SequenceTextPrinter
:
public
NotGetable
Evaluator
{
private:
private:
/// dict_file, which contains a list of tokens
/// dict_file, which contains a list of tokens
std
::
vector
<
std
::
string
>
dict_
;
std
::
vector
<
std
::
string
>
dict_
;
...
@@ -1243,4 +1345,6 @@ public:
...
@@ -1243,4 +1345,6 @@ public:
};
};
REGISTER_EVALUATOR
(
classification_error_printer
,
ClassificationErrorPrinter
);
REGISTER_EVALUATOR
(
classification_error_printer
,
ClassificationErrorPrinter
);
std
::
string
DummyEvaluator
::
getTypeImpl
()
const
{
return
"dummy"
;
}
}
// namespace paddle
}
// namespace paddle
paddle/gserver/evaluators/Evaluator.h
浏览文件 @
c4519574
...
@@ -19,6 +19,7 @@ limitations under the License. */
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/parameter/Argument.h"
#include "paddle/parameter/Argument.h"
#include "paddle/pserver/ParameterClient2.h"
#include "paddle/pserver/ParameterClient2.h"
#include "paddle/utils/ClassRegistrar.h"
#include "paddle/utils/ClassRegistrar.h"
#include "paddle/utils/Error.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -117,12 +118,105 @@ public:
...
@@ -117,12 +118,105 @@ public:
static
ClassRegistrar
<
Evaluator
>
registrar_
;
static
ClassRegistrar
<
Evaluator
>
registrar_
;
/**
* @brief getNames will return all field names of current evaluator.
*
* The format of name is `evaluator_name.evaluator_fields`. If the evaluator
* has multiple field, the name could be `evaluator_name.field1`. For example
* the PrecisionRecallEvaluator contains `precision`, `recall` fields. The get
* names will return `precision_recall_evaluator.precision`,
* `precision_recall_evaluator.recal`, etc.
*
* Also, if current Evaluator is a combined evaluator. getNames will return
* all names of all evaluators inside the combined evaluator.
*
* @param names [out]: the field names of current evaluator.
* @note Never clear the names parameter inside getNames.
*/
virtual
void
getNames
(
std
::
vector
<
std
::
string
>*
names
)
{
names
->
push_back
(
config_
.
name
());
}
/**
* @brief getValue will return the current evaluate value of one field.
*
* @param name: The field name of current evaluator.
* @param err [out]: The error state.
*
* @return The evaluate value(metric).
*/
virtual
real
getValue
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
if
(
name
!=
config_
.
name
())
{
*
err
=
Error
(
"no such name of evaluator %s"
,
name
.
c_str
());
return
.0
f
;
}
return
this
->
getValueImpl
();
}
/**
* @brief getType will return the evaluator type by field name.
*
* Evaluate Type is the current type of evaluator in string. Such as 'auc',
* 'precision_recall'. In combined evaluator, different name may get different
* evaluate type because it could be evaluated by different evaluator inside.
*
* @param name: The field name of current Evaluator.
* @param err: The error state. nullptr means don't care.
* @return the evaluator type string.
*/
virtual
std
::
string
getType
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
if
(
name
!=
config_
.
name
())
{
*
err
=
Error
(
"no such name of evaluator %s"
,
name
.
c_str
());
return
std
::
string
();
}
return
this
->
getTypeImpl
();
}
protected:
/**
* @brief getValueImpl The simplest way to define getValue result. If this
* evaluator doesn't contain multiple fields, and do not throw any error, just
* implemented this method to get the evaluate result(metric).
* @return Evaluate result(metric).
*/
virtual
real
getValueImpl
()
const
{
return
numSamples_
!=
.0
?
totalScore_
/
numSamples_
:
.0
;
}
/**
* @brief getTypeImpl The simplest way to define getType result. If this
* evaluator doesn't combine many evaluators, the get type should only return
* itself type.
* @return Evaluator type.
*/
virtual
std
::
string
getTypeImpl
()
const
{
return
"base"
;
}
protected:
protected:
EvaluatorConfig
config_
;
EvaluatorConfig
config_
;
double
numSamples_
;
double
numSamples_
;
double
totalScore_
;
double
totalScore_
;
};
};
/**
* @brief The NotGetableEvaluator class is the base class of evaluator that
* cannot get value in runtime. The most NotGetableEvaluator is Printer
* Evaluator, which is only used to debug network configuration.
*/
class
NotGetableEvaluator
:
public
Evaluator
{
// Evaluator interface
public:
void
getNames
(
std
::
vector
<
std
::
string
>*
names
)
{}
real
getValue
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
*
err
=
Error
(
"Not implemented"
);
return
.0
f
;
}
std
::
string
getType
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
*
err
=
Error
(
"Not implemented"
);
return
""
;
}
};
class
DummyEvaluator
:
public
Evaluator
{
class
DummyEvaluator
:
public
Evaluator
{
public:
public:
DummyEvaluator
()
{}
DummyEvaluator
()
{}
...
@@ -135,6 +229,10 @@ public:
...
@@ -135,6 +229,10 @@ public:
}
}
virtual
void
finish
()
{}
virtual
void
finish
()
{}
virtual
void
printStats
(
std
::
ostream
&
)
const
{}
virtual
void
printStats
(
std
::
ostream
&
)
const
{}
// Evaluator interface
protected:
std
::
string
getTypeImpl
()
const
;
};
};
/**
/**
* @brief evaluate AUC using colIdx-th column as prediction.
* @brief evaluate AUC using colIdx-th column as prediction.
...
@@ -191,6 +289,11 @@ private:
...
@@ -191,6 +289,11 @@ private:
}
}
double
calcAuc
()
const
;
double
calcAuc
()
const
;
// Evaluator interface
protected:
real
getValueImpl
()
const
;
std
::
string
getTypeImpl
()
const
;
};
};
/**
/**
...
@@ -223,6 +326,10 @@ private:
...
@@ -223,6 +326,10 @@ private:
real
*
clickData
,
real
*
clickData
,
real
*
pvData
,
real
*
pvData
,
size_t
size
);
size_t
size
);
// Evaluator interface
protected:
std
::
string
getTypeImpl
()
const
;
};
};
/**
/**
* @brief precision, recall and f1 score Evaluator
* @brief precision, recall and f1 score Evaluator
...
@@ -272,6 +379,20 @@ private:
...
@@ -272,6 +379,20 @@ private:
IVectorPtr
cpuLabel_
;
IVectorPtr
cpuLabel_
;
MatrixPtr
cpuWeight_
;
MatrixPtr
cpuWeight_
;
struct
PrintStatsInfo
{
double
precision
;
double
recall
;
double
f1
;
double
macroAvgPrecision
;
double
macroAvgRecall
;
double
macroAvgF1Score
;
double
microAvgPrecision
;
double
microAvgRecall
;
double
microAvgF1Score
;
};
bool
getStatsInfo
(
PrintStatsInfo
*
info
)
const
;
void
calcStatsInfo
(
const
MatrixPtr
&
output
,
void
calcStatsInfo
(
const
MatrixPtr
&
output
,
const
IVectorPtr
&
label
,
const
IVectorPtr
&
label
,
const
MatrixPtr
&
weight
);
const
MatrixPtr
&
weight
);
...
@@ -303,6 +424,15 @@ private:
...
@@ -303,6 +424,15 @@ private:
return
0
;
return
0
;
}
}
}
}
mutable
std
::
unordered_map
<
std
::
string
,
real
>
values_
;
void
storeLocalValues
()
const
;
// Evaluator interface
public:
void
getNames
(
std
::
vector
<
std
::
string
>*
names
);
real
getValue
(
const
std
::
string
&
name
,
Error
*
err
)
const
;
std
::
string
getType
(
const
std
::
string
&
name
,
Error
*
err
)
const
;
};
};
/*
/*
...
@@ -349,8 +479,7 @@ public:
...
@@ -349,8 +479,7 @@ public:
virtual
void
finish
()
{
calc
(
predictArray_
);
}
virtual
void
finish
()
{
calc
(
predictArray_
);
}
virtual
void
printStats
(
std
::
ostream
&
os
)
const
{
virtual
void
printStats
(
std
::
ostream
&
os
)
const
{
os
<<
" pos/neg"
os
<<
" pos/neg="
<<
this
->
getValueImpl
();
<<
"="
<<
pairArray_
[
0
]
/
((
pairArray_
[
1
]
<=
0
)
?
1.0
:
pairArray_
[
1
]);
}
}
virtual
void
distributeEval
(
ParameterClient2
*
client
)
{
virtual
void
distributeEval
(
ParameterClient2
*
client
)
{
...
@@ -366,6 +495,13 @@ private:
...
@@ -366,6 +495,13 @@ private:
IVectorPtr
cpuLabel_
;
IVectorPtr
cpuLabel_
;
IVectorPtr
cpuInfo_
;
IVectorPtr
cpuInfo_
;
MatrixPtr
cpuWeight_
;
MatrixPtr
cpuWeight_
;
// Evaluator interface
protected:
real
getValueImpl
()
const
{
return
pairArray_
[
0
]
/
((
pairArray_
[
1
]
<=
0
)
?
1.0
:
pairArray_
[
1
]);
}
std
::
string
getTypeImpl
()
const
;
};
};
}
// namespace paddle
}
// namespace paddle
paddle/gserver/gradientmachines/NeuralNetwork.cpp
浏览文件 @
c4519574
...
@@ -306,7 +306,6 @@ void NeuralNetwork::onPassEnd() {
...
@@ -306,7 +306,6 @@ void NeuralNetwork::onPassEnd() {
class
CombinedEvaluator
:
public
Evaluator
{
class
CombinedEvaluator
:
public
Evaluator
{
public:
public:
CombinedEvaluator
()
{}
void
addEvaluator
(
std
::
unique_ptr
<
Evaluator
>&&
evaluator
)
{
void
addEvaluator
(
std
::
unique_ptr
<
Evaluator
>&&
evaluator
)
{
evaluators_
.
emplace_back
(
std
::
move
(
evaluator
));
evaluators_
.
emplace_back
(
std
::
move
(
evaluator
));
}
}
...
@@ -346,6 +345,55 @@ public:
...
@@ -346,6 +345,55 @@ public:
protected:
protected:
std
::
vector
<
std
::
unique_ptr
<
Evaluator
>>
evaluators_
;
std
::
vector
<
std
::
unique_ptr
<
Evaluator
>>
evaluators_
;
// Evaluator interface
public:
/**
* @brief getNames will return all inside evaluators' names.
* @param names [out]: return names.
*/
void
getNames
(
std
::
vector
<
std
::
string
>*
names
)
{
for
(
auto
&
eval
:
evaluators_
)
{
eval
->
getNames
(
names
);
}
}
/**
* @brief getValue could get all inside evaluators' value.
*/
real
getValue
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
return
this
->
getMethodHelper
<
real
>
(
name
,
err
,
[
&
name
,
err
](
const
std
::
unique_ptr
<
Evaluator
>&
eval
)
{
return
eval
->
getValue
(
name
,
err
);
});
}
/**
* @brief getType could get all inside evaluators' type.
*/
std
::
string
getType
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
return
this
->
getMethodHelper
<
std
::
string
>
(
name
,
err
,
[
&
name
,
err
](
const
std
::
unique_ptr
<
Evaluator
>&
eval
)
{
return
eval
->
getType
(
name
,
err
);
});
}
private:
template
<
typename
T
>
T
getMethodHelper
(
const
std
::
string
&
name
,
Error
*
err
,
const
std
::
function
<
T
(
const
std
::
unique_ptr
<
Evaluator
>&
)
>&
callback
)
const
{
for
(
auto
&
eval
:
evaluators_
)
{
std
::
vector
<
std
::
string
>
names
;
eval
->
getNames
(
&
names
);
if
(
std
::
find
(
names
.
begin
(),
names
.
end
(),
name
)
!=
names
.
end
())
{
return
callback
(
eval
);
}
}
*
err
=
Error
(
"No such key %s"
,
name
.
c_str
());
return
T
();
}
};
};
Evaluator
*
NeuralNetwork
::
makeEvaluator
()
const
{
Evaluator
*
NeuralNetwork
::
makeEvaluator
()
const
{
...
...
paddle/gserver/tests/test_Evaluator.cpp
浏览文件 @
c4519574
...
@@ -110,6 +110,18 @@ void testEvaluator(TestConfig testConf,
...
@@ -110,6 +110,18 @@ void testEvaluator(TestConfig testConf,
testEvaluator
->
finish
();
testEvaluator
->
finish
();
LOG
(
INFO
)
<<
*
testEvaluator
;
LOG
(
INFO
)
<<
*
testEvaluator
;
std
::
vector
<
std
::
string
>
names
;
testEvaluator
->
getNames
(
&
names
);
paddle
::
Error
err
;
for
(
auto
&
name
:
names
)
{
auto
value
=
testEvaluator
->
getValue
(
name
,
&
err
);
ASSERT_TRUE
(
err
.
isOK
());
LOG
(
INFO
)
<<
name
<<
" "
<<
value
;
auto
tp
=
testEvaluator
->
getType
(
name
,
&
err
);
ASSERT_TRUE
(
err
.
isOK
());
ASSERT_EQ
(
testConf
.
evaluatorConfig
.
type
(),
tp
);
}
double
totalScore2
=
0.0
;
double
totalScore2
=
0.0
;
if
(
testConf
.
testAccumulate
)
{
if
(
testConf
.
testAccumulate
)
{
testEvaluator
->
start
();
testEvaluator
->
start
();
...
...
paddle/utils/Error.h
浏览文件 @
c4519574
...
@@ -37,10 +37,10 @@ namespace paddle {
...
@@ -37,10 +37,10 @@ namespace paddle {
*
*
* Error __must_check bar() {
* Error __must_check bar() {
* // do something.
* // do something.
*
Status s
= foo(); // invoke other method return status.
*
Error err
= foo(); // invoke other method return status.
* if (
!s) return s
;
* if (
err) return err
;
* // do something else.
* // do something else.
* return
Status
();
* return
Error
();
* }
* }
* @endcode{cpp}
* @endcode{cpp}
*
*
...
@@ -53,8 +53,8 @@ namespace paddle {
...
@@ -53,8 +53,8 @@ namespace paddle {
*
*
* int foo(Error* error) {
* int foo(Error* error) {
* // Do something.
* // Do something.
* Error
s
= bar();
* Error
err
= bar();
* if (
!s
) {
* if (
err
) {
* *error = s;
* *error = s;
* return 0;
* return 0;
* }
* }
...
@@ -68,10 +68,10 @@ namespace paddle {
...
@@ -68,10 +68,10 @@ namespace paddle {
* }
* }
*
*
* Error foobar() {
* Error foobar() {
* Error
s
;
* Error
err
;
* // do something.
* // do something.
* foo(&
s
);
* foo(&
err
);
* if (
!s) return s
;
* if (
err) return err
;
* }
* }
* @endcode{cpp}
* @endcode{cpp}
*
*
...
@@ -112,16 +112,22 @@ public:
...
@@ -112,16 +112,22 @@ public:
}
}
/**
/**
* @brief operator bool, return True if there is
no
error.
* @brief operator bool, return True if there is
something
error.
*/
*/
operator
bool
()
const
{
return
msg_
==
nullptr
;
}
operator
bool
()
const
{
return
!
this
->
isOK
();
}
/**
* @brief isOK return True if there is no error.
* @return True if no error.
*/
bool
isOK
()
const
{
return
msg_
==
nullptr
;
}
/**
/**
* @brief check this status by glog.
* @brief check this status by glog.
* @note It is a temp method used during cleaning Paddle code. It will be
* @note It is a temp method used during cleaning Paddle code. It will be
* removed later.
* removed later.
*/
*/
void
check
()
const
{
CHECK
(
*
this
)
<<
msg
();
}
void
check
()
const
{
CHECK
(
this
->
isOK
()
)
<<
msg
();
}
private:
private:
std
::
shared_ptr
<
std
::
string
>
msg_
;
std
::
shared_ptr
<
std
::
string
>
msg_
;
...
...
paddle/utils/tests/test_Error.cpp
浏览文件 @
c4519574
...
@@ -18,17 +18,17 @@ limitations under the License. */
...
@@ -18,17 +18,17 @@ limitations under the License. */
TEST
(
Error
,
testAll
)
{
TEST
(
Error
,
testAll
)
{
paddle
::
Error
error
;
paddle
::
Error
error
;
ASSERT_TRUE
(
error
);
error
=
paddle
::
Error
(
"I'm the error"
);
ASSERT_FALSE
(
error
);
ASSERT_FALSE
(
error
);
error
=
paddle
::
Error
(
"I'm the error"
);
ASSERT_TRUE
(
error
);
ASSERT_STREQ
(
"I'm the error"
,
error
.
msg
());
ASSERT_STREQ
(
"I'm the error"
,
error
.
msg
());
error
=
paddle
::
Error
(
"error2"
);
error
=
paddle
::
Error
(
"error2"
);
ASSERT_
FALS
E
(
error
);
ASSERT_
TRU
E
(
error
);
ASSERT_STREQ
(
"error2"
,
error
.
msg
());
ASSERT_STREQ
(
"error2"
,
error
.
msg
());
int
i
=
3
;
int
i
=
3
;
auto
error3
=
paddle
::
Error
(
"error%d"
,
i
);
auto
error3
=
paddle
::
Error
(
"error%d"
,
i
);
ASSERT_
FALS
E
(
error3
);
ASSERT_
TRU
E
(
error3
);
ASSERT_STREQ
(
"error3"
,
error3
.
msg
());
ASSERT_STREQ
(
"error3"
,
error3
.
msg
());
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录