Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
108f195d
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
108f195d
编写于
2月 24, 2017
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into topology
上级
12cac800
c4519574
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
379 addition
and
73 deletion
+379
-73
paddle/gserver/evaluators/CTCErrorEvaluator.cpp
paddle/gserver/evaluators/CTCErrorEvaluator.cpp
+1
-1
paddle/gserver/evaluators/Evaluator.cpp
paddle/gserver/evaluators/Evaluator.cpp
+158
-54
paddle/gserver/evaluators/Evaluator.h
paddle/gserver/evaluators/Evaluator.h
+138
-2
paddle/gserver/gradientmachines/NeuralNetwork.cpp
paddle/gserver/gradientmachines/NeuralNetwork.cpp
+49
-1
paddle/gserver/tests/test_Evaluator.cpp
paddle/gserver/tests/test_Evaluator.cpp
+12
-0
paddle/utils/Error.h
paddle/utils/Error.h
+17
-11
paddle/utils/tests/test_Error.cpp
paddle/utils/tests/test_Error.cpp
+4
-4
未找到文件。
paddle/gserver/evaluators/CTCErrorEvaluator.cpp
浏览文件 @
108f195d
...
...
@@ -20,7 +20,7 @@ namespace paddle {
/**
* calculate sequence-to-sequence edit distance
*/
class
CTCErrorEvaluator
:
public
Evaluator
{
class
CTCErrorEvaluator
:
public
NotGetable
Evaluator
{
private:
MatrixPtr
outActivations_
;
int
numTimes_
,
numClasses_
,
numSequences_
,
blank_
;
...
...
paddle/gserver/evaluators/Evaluator.cpp
浏览文件 @
108f195d
...
...
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/gserver/evaluators/Evaluator.h"
#include "paddle/utils/Stat.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/StringUtil.h"
DECLARE_int32
(
trainer_id
);
...
...
@@ -122,6 +122,10 @@ public:
virtual
void
distributeEval
(
ParameterClient2
*
client
)
{
mergeResultsOfAllClients
(
client
);
}
// Evaluator interface
protected:
std
::
string
getTypeImpl
()
const
{
return
"classification_error"
;
}
};
/**
...
...
@@ -160,6 +164,10 @@ public:
virtual
void
distributeEval
(
ParameterClient2
*
client
)
{
mergeResultsOfAllClients
(
client
);
}
// Evaluator interface
protected:
std
::
string
getTypeImpl
()
const
{
return
"seq_classification_error"
;
}
};
REGISTER_EVALUATOR
(
seq_classification_error
,
SequenceClassificationErrorEvaluator
);
...
...
@@ -250,6 +258,10 @@ public:
private:
IVectorPtr
cpuLabel_
;
MatrixPtr
cpuWeight_
;
// Evaluator interface
protected:
std
::
string
getTypeImpl
()
const
{
return
"sum"
;
}
};
/**
* @brief column sum Evaluator
...
...
@@ -357,10 +369,18 @@ public:
}
private:
ColumnSumEvaluator
()
{}
int32_t
colIdx_
;
size_t
colNum_
;
MatrixPtr
sum_
;
/* cpu matrix */
// Evaluator interface
protected:
std
::
string
getTypeImpl
()
const
{
if
(
colIdx_
==
-
1
)
return
"last-column-sum"
;
else
return
"column-sum"
;
}
};
void
AucEvaluator
::
start
()
{
...
...
@@ -469,6 +489,16 @@ double AucEvaluator::calcAuc() const {
}
}
real
AucEvaluator
::
getValueImpl
()
const
{
return
calcAuc
();
}
std
::
string
AucEvaluator
::
getTypeImpl
()
const
{
if
(
colIdx_
==
-
1
)
{
return
"last-column-auc"
;
}
else
{
return
"auc"
;
}
}
// class RankAucEvaluator
REGISTER_EVALUATOR
(
rankauc
,
RankAucEvaluator
);
...
...
@@ -548,12 +578,15 @@ double RankAucEvaluator::calcRankAuc(real* outputData,
:
aucTmp
/
(
clickSum
*
noClickSum
);
}
std
::
string
RankAucEvaluator
::
getTypeImpl
()
const
{
return
"rankauc"
;
}
// class PrecisionRecallEvaluator
REGISTER_EVALUATOR
(
precision_recall
,
PrecisionRecallEvaluator
);
void
PrecisionRecallEvaluator
::
start
()
{
Evaluator
::
start
();
statsInfo_
.
clear
();
values_
.
clear
();
}
real
PrecisionRecallEvaluator
::
evalImp
(
std
::
vector
<
Argument
>&
arguments
)
{
...
...
@@ -614,52 +647,23 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
}
void
PrecisionRecallEvaluator
::
printStats
(
std
::
ostream
&
os
)
const
{
int
label
=
config_
.
positive_label
();
if
(
label
!=
-
1
)
{
CHECK
(
label
>=
0
&&
label
<
(
int
)
statsInfo_
.
size
())
<<
"positive_label ["
<<
label
<<
"] should be in range [0, "
<<
statsInfo_
.
size
()
<<
")"
;
double
precision
=
calcPrecision
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FP
);
double
recall
=
calcRecall
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FN
);
os
<<
"positive_label="
<<
label
<<
" precision="
<<
precision
<<
" recall="
<<
recall
<<
" F1-score="
<<
calcF1Score
(
precision
,
recall
);
return
;
}
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
// macro average method: precision = (precision1+precision2)/2
double
microTotalTP
=
0
;
double
microTotalFP
=
0
;
double
microTotalFN
=
0
;
double
macroAvgPrecision
=
0
;
double
macroAvgRecall
=
0
;
size_t
numLabels
=
statsInfo_
.
size
();
for
(
size_t
i
=
0
;
i
<
numLabels
;
++
i
)
{
microTotalTP
+=
statsInfo_
[
i
].
TP
;
microTotalFP
+=
statsInfo_
[
i
].
FP
;
microTotalFN
+=
statsInfo_
[
i
].
FN
;
macroAvgPrecision
+=
calcPrecision
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FP
);
macroAvgRecall
+=
calcRecall
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FN
);
}
macroAvgPrecision
/=
numLabels
;
macroAvgRecall
/=
numLabels
;
double
macroAvgF1Score
=
calcF1Score
(
macroAvgPrecision
,
macroAvgRecall
);
os
<<
"macro-average-precision="
<<
macroAvgPrecision
<<
" macro-average-recall="
<<
macroAvgRecall
<<
" macro-average-F1-score="
<<
macroAvgF1Score
;
double
microAvgPrecision
=
calcPrecision
(
microTotalTP
,
microTotalFP
);
double
microAvgRecall
=
calcPrecision
(
microTotalTP
,
microTotalFN
);
double
microAvgF1Score
=
calcF1Score
(
microAvgPrecision
,
microAvgRecall
);
PrintStatsInfo
info
;
bool
containMacroMicroInfo
=
getStatsInfo
(
&
info
);
os
<<
"positive_label="
<<
config_
.
positive_label
()
<<
" precision="
<<
info
.
precision
<<
" recall="
<<
info
.
recall
<<
" F1-score="
<<
info
.
f1
;
if
(
containMacroMicroInfo
)
{
os
<<
"macro-average-precision="
<<
info
.
macroAvgPrecision
<<
" macro-average-recall="
<<
info
.
macroAvgRecall
<<
" macro-average-F1-score="
<<
info
.
macroAvgF1Score
;
if
(
!
isMultiBinaryLabel_
)
{
// precision and recall are equal in this case
os
<<
" micro-average-precision="
<<
microAvgPrecision
;
os
<<
" micro-average-precision="
<<
info
.
microAvgPrecision
;
}
else
{
os
<<
" micro-average-precision="
<<
microAvgPrecision
<<
" micro-average-recall="
<<
microAvgRecall
<<
" micro-average-F1-score="
<<
microAvgF1Score
;
os
<<
" micro-average-precision="
<<
info
.
microAvgPrecision
<<
" micro-average-recall="
<<
info
.
microAvgRecall
<<
" micro-average-F1-score="
<<
info
.
microAvgF1Score
;
}
}
}
...
...
@@ -741,6 +745,60 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
}
}
void
PrecisionRecallEvaluator
::
storeLocalValues
()
const
{
if
(
this
->
values_
.
size
()
==
0
)
{
PrintStatsInfo
info
;
bool
containMacroMicroInfo
=
getStatsInfo
(
&
info
);
values_
[
"precision"
]
=
info
.
precision
;
values_
[
"recal"
]
=
info
.
recall
;
values_
[
"F1-score"
]
=
info
.
f1
;
if
(
containMacroMicroInfo
)
{
values_
[
"macro-average-precision"
]
=
info
.
macroAvgPrecision
;
values_
[
"macro-average-recall"
]
=
info
.
macroAvgRecall
;
values_
[
"macro-average-F1-score"
]
=
info
.
macroAvgF1Score
;
if
(
!
isMultiBinaryLabel_
)
{
// precision and recall are equal in this case
values_
[
"micro-average-precision"
]
=
info
.
microAvgPrecision
;
}
else
{
values_
[
"micro-average-precision"
]
=
info
.
microAvgPrecision
;
values_
[
"micro-average-recall"
]
=
info
.
microAvgRecall
;
values_
[
"micro-average-F1-score"
]
=
info
.
microAvgF1Score
;
}
}
}
}
void
PrecisionRecallEvaluator
::
getNames
(
std
::
vector
<
std
::
string
>*
names
)
{
this
->
storeLocalValues
();
names
->
reserve
(
this
->
values_
.
size
());
for
(
auto
it
=
this
->
values_
.
begin
();
it
!=
this
->
values_
.
end
();
++
it
)
{
names
->
push_back
(
this
->
config_
.
name
()
+
"."
+
it
->
first
);
}
}
real
PrecisionRecallEvaluator
::
getValue
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
this
->
storeLocalValues
();
std
::
vector
<
std
::
string
>
buffers
;
paddle
::
str
::
split
(
name
,
'.'
,
&
buffers
);
auto
it
=
this
->
values_
.
find
(
buffers
[
buffers
.
size
()
-
1
]);
if
(
it
==
this
->
values_
.
end
())
{
// not found
*
err
=
Error
(
"No such key %s"
,
name
.
c_str
());
return
.0
f
;
}
return
it
->
second
;
}
std
::
string
PrecisionRecallEvaluator
::
getType
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
this
->
getValue
(
name
,
err
);
if
(
!
err
->
isOK
())
{
return
""
;
}
return
"precision_recall"
;
}
void
PrecisionRecallEvaluator
::
distributeEval
(
ParameterClient2
*
client
)
{
size_t
size
=
4
*
statsInfo_
.
size
();
double
*
buf
=
new
double
[
size
];
...
...
@@ -760,6 +818,47 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
delete
[]
buf
;
}
bool
PrecisionRecallEvaluator
::
getStatsInfo
(
PrecisionRecallEvaluator
::
PrintStatsInfo
*
info
)
const
{
int
label
=
config_
.
positive_label
();
if
(
label
!=
-
1
)
{
CHECK
(
label
>=
0
&&
label
<
(
int
)
statsInfo_
.
size
())
<<
"positive_label ["
<<
label
<<
"] should be in range [0, "
<<
statsInfo_
.
size
()
<<
")"
;
info
->
precision
=
calcPrecision
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FP
);
info
->
recall
=
calcRecall
(
statsInfo_
[
label
].
TP
,
statsInfo_
[
label
].
FN
);
info
->
f1
=
calcF1Score
(
info
->
precision
,
info
->
recall
);
return
false
;
}
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
// macro average method: precision = (precision1+precision2)/2
double
microTotalTP
=
0
;
double
microTotalFP
=
0
;
double
microTotalFN
=
0
;
info
->
macroAvgPrecision
=
0
;
info
->
macroAvgRecall
=
0
;
size_t
numLabels
=
statsInfo_
.
size
();
for
(
size_t
i
=
0
;
i
<
numLabels
;
++
i
)
{
microTotalTP
+=
statsInfo_
[
i
].
TP
;
microTotalFP
+=
statsInfo_
[
i
].
FP
;
microTotalFN
+=
statsInfo_
[
i
].
FN
;
info
->
macroAvgPrecision
+=
calcPrecision
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FP
);
info
->
macroAvgRecall
+=
calcRecall
(
statsInfo_
[
i
].
TP
,
statsInfo_
[
i
].
FN
);
}
info
->
macroAvgPrecision
/=
numLabels
;
info
->
macroAvgRecall
/=
numLabels
;
info
->
macroAvgF1Score
=
calcF1Score
(
info
->
macroAvgPrecision
,
info
->
macroAvgRecall
);
info
->
microAvgPrecision
=
calcPrecision
(
microTotalTP
,
microTotalFP
);
info
->
microAvgRecall
=
calcPrecision
(
microTotalTP
,
microTotalFN
);
info
->
microAvgF1Score
=
calcF1Score
(
info
->
microAvgPrecision
,
info
->
microAvgRecall
);
return
true
;
}
REGISTER_EVALUATOR
(
pnpair
,
PnpairEvaluator
);
void
PnpairEvaluator
::
start
()
{
Evaluator
::
start
();
...
...
@@ -884,6 +983,8 @@ void PnpairEvaluator::calc(std::vector<PredictionResult>& predictArray) {
<<
" calc total special pair: "
<<
special
;
}
std
::
string
PnpairEvaluator
::
getTypeImpl
()
const
{
return
"pnpair"
;
}
ClassRegistrar
<
Evaluator
>
Evaluator
::
registrar_
;
Evaluator
*
Evaluator
::
create
(
const
EvaluatorConfig
&
config
)
{
Evaluator
*
evaluator
=
registrar_
.
createByType
(
config
.
type
());
...
...
@@ -905,7 +1006,7 @@ static InitFunction __reg_type_auc_sum__([]() {
*
* The config file api is value_printer_evaluator.
*/
class
ValuePrinter
:
public
Evaluator
{
class
ValuePrinter
:
public
NotGetable
Evaluator
{
public:
virtual
void
eval
(
const
NeuralNetwork
&
nn
)
{
for
(
const
std
::
string
&
name
:
config_
.
input_layers
())
{
...
...
@@ -919,12 +1020,13 @@ public:
virtual
real
evalImp
(
std
::
vector
<
Argument
>&
arguments
)
{
return
0
;
}
};
REGISTER_EVALUATOR
(
value_printer
,
ValuePrinter
);
/**
* @brief print gradient of each layer.
*
* The config file api is gradient_printer_evaluator.
*/
class
GradientPrinter
:
public
Evaluator
{
class
GradientPrinter
:
public
NotGetable
Evaluator
{
public:
virtual
void
eval
(
const
NeuralNetwork
&
nn
)
{
for
(
const
std
::
string
&
name
:
config_
.
input_layers
())
{
...
...
@@ -947,7 +1049,7 @@ REGISTER_EVALUATOR(gradient_printer, GradientPrinter);
*
* The config file api is maxid_printer_evaluator.
*/
class
MaxIdPrinter
:
public
Evaluator
{
class
MaxIdPrinter
:
public
NotGetable
Evaluator
{
private:
IVectorPtr
maxIds_
;
MatrixPtr
maxValues_
;
...
...
@@ -989,7 +1091,7 @@ REGISTER_EVALUATOR(max_id_printer, MaxIdPrinter);
*
* The config file api is maxframe_printer_evaluator.
*/
class
MaxFramePrinter
:
public
Evaluator
{
class
MaxFramePrinter
:
public
NotGetable
Evaluator
{
private:
IVectorPtr
maxIds_
;
MatrixPtr
maxValues_
;
...
...
@@ -1076,7 +1178,7 @@ REGISTER_EVALUATOR(max_frame_printer, MaxFramePrinter);
* The config file api is seqtext_printer_evaluator.
*
*/
class
SequenceTextPrinter
:
public
Evaluator
{
class
SequenceTextPrinter
:
public
NotGetable
Evaluator
{
private:
/// dict_file, which contains a list of tokens
std
::
vector
<
std
::
string
>
dict_
;
...
...
@@ -1243,4 +1345,6 @@ public:
};
REGISTER_EVALUATOR
(
classification_error_printer
,
ClassificationErrorPrinter
);
std
::
string
DummyEvaluator
::
getTypeImpl
()
const
{
return
"dummy"
;
}
}
// namespace paddle
paddle/gserver/evaluators/Evaluator.h
浏览文件 @
108f195d
...
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/parameter/Argument.h"
#include "paddle/pserver/ParameterClient2.h"
#include "paddle/utils/ClassRegistrar.h"
#include "paddle/utils/Error.h"
namespace
paddle
{
...
...
@@ -117,12 +118,105 @@ public:
static
ClassRegistrar
<
Evaluator
>
registrar_
;
/**
* @brief getNames will return all field names of current evaluator.
*
* The format of name is `evaluator_name.evaluator_fields`. If the evaluator
* has multiple field, the name could be `evaluator_name.field1`. For example
* the PrecisionRecallEvaluator contains `precision`, `recall` fields. The get
* names will return `precision_recall_evaluator.precision`,
* `precision_recall_evaluator.recal`, etc.
*
* Also, if current Evaluator is a combined evaluator. getNames will return
* all names of all evaluators inside the combined evaluator.
*
* @param names [out]: the field names of current evaluator.
* @note Never clear the names parameter inside getNames.
*/
virtual
void
getNames
(
std
::
vector
<
std
::
string
>*
names
)
{
names
->
push_back
(
config_
.
name
());
}
/**
* @brief getValue will return the current evaluate value of one field.
*
* @param name: The field name of current evaluator.
* @param err [out]: The error state.
*
* @return The evaluate value(metric).
*/
virtual
real
getValue
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
if
(
name
!=
config_
.
name
())
{
*
err
=
Error
(
"no such name of evaluator %s"
,
name
.
c_str
());
return
.0
f
;
}
return
this
->
getValueImpl
();
}
/**
* @brief getType will return the evaluator type by field name.
*
* Evaluate Type is the current type of evaluator in string. Such as 'auc',
* 'precision_recall'. In combined evaluator, different name may get different
* evaluate type because it could be evaluated by different evaluator inside.
*
* @param name: The field name of current Evaluator.
* @param err: The error state. nullptr means don't care.
* @return the evaluator type string.
*/
virtual
std
::
string
getType
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
if
(
name
!=
config_
.
name
())
{
*
err
=
Error
(
"no such name of evaluator %s"
,
name
.
c_str
());
return
std
::
string
();
}
return
this
->
getTypeImpl
();
}
protected:
/**
* @brief getValueImpl The simplest way to define getValue result. If this
* evaluator doesn't contain multiple fields, and do not throw any error, just
* implemented this method to get the evaluate result(metric).
* @return Evaluate result(metric).
*/
virtual
real
getValueImpl
()
const
{
return
numSamples_
!=
.0
?
totalScore_
/
numSamples_
:
.0
;
}
/**
* @brief getTypeImpl The simplest way to define getType result. If this
* evaluator doesn't combine many evaluators, the get type should only return
* itself type.
* @return Evaluator type.
*/
virtual
std
::
string
getTypeImpl
()
const
{
return
"base"
;
}
protected:
EvaluatorConfig
config_
;
double
numSamples_
;
double
totalScore_
;
};
/**
* @brief The NotGetableEvaluator class is the base class of evaluator that
* cannot get value in runtime. The most NotGetableEvaluator is Printer
* Evaluator, which is only used to debug network configuration.
*/
class
NotGetableEvaluator
:
public
Evaluator
{
// Evaluator interface
public:
void
getNames
(
std
::
vector
<
std
::
string
>*
names
)
{}
real
getValue
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
*
err
=
Error
(
"Not implemented"
);
return
.0
f
;
}
std
::
string
getType
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
*
err
=
Error
(
"Not implemented"
);
return
""
;
}
};
class
DummyEvaluator
:
public
Evaluator
{
public:
DummyEvaluator
()
{}
...
...
@@ -135,6 +229,10 @@ public:
}
virtual
void
finish
()
{}
virtual
void
printStats
(
std
::
ostream
&
)
const
{}
// Evaluator interface
protected:
std
::
string
getTypeImpl
()
const
;
};
/**
* @brief evaluate AUC using colIdx-th column as prediction.
...
...
@@ -191,6 +289,11 @@ private:
}
double
calcAuc
()
const
;
// Evaluator interface
protected:
real
getValueImpl
()
const
;
std
::
string
getTypeImpl
()
const
;
};
/**
...
...
@@ -223,6 +326,10 @@ private:
real
*
clickData
,
real
*
pvData
,
size_t
size
);
// Evaluator interface
protected:
std
::
string
getTypeImpl
()
const
;
};
/**
* @brief precision, recall and f1 score Evaluator
...
...
@@ -272,6 +379,20 @@ private:
IVectorPtr
cpuLabel_
;
MatrixPtr
cpuWeight_
;
struct
PrintStatsInfo
{
double
precision
;
double
recall
;
double
f1
;
double
macroAvgPrecision
;
double
macroAvgRecall
;
double
macroAvgF1Score
;
double
microAvgPrecision
;
double
microAvgRecall
;
double
microAvgF1Score
;
};
bool
getStatsInfo
(
PrintStatsInfo
*
info
)
const
;
void
calcStatsInfo
(
const
MatrixPtr
&
output
,
const
IVectorPtr
&
label
,
const
MatrixPtr
&
weight
);
...
...
@@ -303,6 +424,15 @@ private:
return
0
;
}
}
mutable
std
::
unordered_map
<
std
::
string
,
real
>
values_
;
void
storeLocalValues
()
const
;
// Evaluator interface
public:
void
getNames
(
std
::
vector
<
std
::
string
>*
names
);
real
getValue
(
const
std
::
string
&
name
,
Error
*
err
)
const
;
std
::
string
getType
(
const
std
::
string
&
name
,
Error
*
err
)
const
;
};
/*
...
...
@@ -349,8 +479,7 @@ public:
virtual
void
finish
()
{
calc
(
predictArray_
);
}
virtual
void
printStats
(
std
::
ostream
&
os
)
const
{
os
<<
" pos/neg"
<<
"="
<<
pairArray_
[
0
]
/
((
pairArray_
[
1
]
<=
0
)
?
1.0
:
pairArray_
[
1
]);
os
<<
" pos/neg="
<<
this
->
getValueImpl
();
}
virtual
void
distributeEval
(
ParameterClient2
*
client
)
{
...
...
@@ -366,6 +495,13 @@ private:
IVectorPtr
cpuLabel_
;
IVectorPtr
cpuInfo_
;
MatrixPtr
cpuWeight_
;
// Evaluator interface
protected:
real
getValueImpl
()
const
{
return
pairArray_
[
0
]
/
((
pairArray_
[
1
]
<=
0
)
?
1.0
:
pairArray_
[
1
]);
}
std
::
string
getTypeImpl
()
const
;
};
}
// namespace paddle
paddle/gserver/gradientmachines/NeuralNetwork.cpp
浏览文件 @
108f195d
...
...
@@ -306,7 +306,6 @@ void NeuralNetwork::onPassEnd() {
class
CombinedEvaluator
:
public
Evaluator
{
public:
CombinedEvaluator
()
{}
void
addEvaluator
(
std
::
unique_ptr
<
Evaluator
>&&
evaluator
)
{
evaluators_
.
emplace_back
(
std
::
move
(
evaluator
));
}
...
...
@@ -346,6 +345,55 @@ public:
protected:
std
::
vector
<
std
::
unique_ptr
<
Evaluator
>>
evaluators_
;
// Evaluator interface
public:
/**
* @brief getNames will return all inside evaluators' names.
* @param names [out]: return names.
*/
void
getNames
(
std
::
vector
<
std
::
string
>*
names
)
{
for
(
auto
&
eval
:
evaluators_
)
{
eval
->
getNames
(
names
);
}
}
/**
* @brief getValue could get all inside evaluators' value.
*/
real
getValue
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
return
this
->
getMethodHelper
<
real
>
(
name
,
err
,
[
&
name
,
err
](
const
std
::
unique_ptr
<
Evaluator
>&
eval
)
{
return
eval
->
getValue
(
name
,
err
);
});
}
/**
* @brief getType could get all inside evaluators' type.
*/
std
::
string
getType
(
const
std
::
string
&
name
,
Error
*
err
)
const
{
return
this
->
getMethodHelper
<
std
::
string
>
(
name
,
err
,
[
&
name
,
err
](
const
std
::
unique_ptr
<
Evaluator
>&
eval
)
{
return
eval
->
getType
(
name
,
err
);
});
}
private:
template
<
typename
T
>
T
getMethodHelper
(
const
std
::
string
&
name
,
Error
*
err
,
const
std
::
function
<
T
(
const
std
::
unique_ptr
<
Evaluator
>&
)
>&
callback
)
const
{
for
(
auto
&
eval
:
evaluators_
)
{
std
::
vector
<
std
::
string
>
names
;
eval
->
getNames
(
&
names
);
if
(
std
::
find
(
names
.
begin
(),
names
.
end
(),
name
)
!=
names
.
end
())
{
return
callback
(
eval
);
}
}
*
err
=
Error
(
"No such key %s"
,
name
.
c_str
());
return
T
();
}
};
Evaluator
*
NeuralNetwork
::
makeEvaluator
()
const
{
...
...
paddle/gserver/tests/test_Evaluator.cpp
浏览文件 @
108f195d
...
...
@@ -110,6 +110,18 @@ void testEvaluator(TestConfig testConf,
testEvaluator
->
finish
();
LOG
(
INFO
)
<<
*
testEvaluator
;
std
::
vector
<
std
::
string
>
names
;
testEvaluator
->
getNames
(
&
names
);
paddle
::
Error
err
;
for
(
auto
&
name
:
names
)
{
auto
value
=
testEvaluator
->
getValue
(
name
,
&
err
);
ASSERT_TRUE
(
err
.
isOK
());
LOG
(
INFO
)
<<
name
<<
" "
<<
value
;
auto
tp
=
testEvaluator
->
getType
(
name
,
&
err
);
ASSERT_TRUE
(
err
.
isOK
());
ASSERT_EQ
(
testConf
.
evaluatorConfig
.
type
(),
tp
);
}
double
totalScore2
=
0.0
;
if
(
testConf
.
testAccumulate
)
{
testEvaluator
->
start
();
...
...
paddle/utils/Error.h
浏览文件 @
108f195d
...
...
@@ -37,10 +37,10 @@ namespace paddle {
*
* Error __must_check bar() {
* // do something.
*
Status s
= foo(); // invoke other method return status.
* if (
!s) return s
;
*
Error err
= foo(); // invoke other method return status.
* if (
err) return err
;
* // do something else.
* return
Status
();
* return
Error
();
* }
* @endcode{cpp}
*
...
...
@@ -53,8 +53,8 @@ namespace paddle {
*
* int foo(Error* error) {
* // Do something.
* Error
s
= bar();
* if (
!s
) {
* Error
err
= bar();
* if (
err
) {
* *error = s;
* return 0;
* }
...
...
@@ -68,10 +68,10 @@ namespace paddle {
* }
*
* Error foobar() {
* Error
s
;
* Error
err
;
* // do something.
* foo(&
s
);
* if (
!s) return s
;
* foo(&
err
);
* if (
err) return err
;
* }
* @endcode{cpp}
*
...
...
@@ -112,16 +112,22 @@ public:
}
/**
* @brief operator bool, return True if there is
no
error.
* @brief operator bool, return True if there is
something
error.
*/
operator
bool
()
const
{
return
msg_
==
nullptr
;
}
operator
bool
()
const
{
return
!
this
->
isOK
();
}
/**
* @brief isOK return True if there is no error.
* @return True if no error.
*/
bool
isOK
()
const
{
return
msg_
==
nullptr
;
}
/**
* @brief check this status by glog.
* @note It is a temp method used during cleaning Paddle code. It will be
* removed later.
*/
void
check
()
const
{
CHECK
(
*
this
)
<<
msg
();
}
void
check
()
const
{
CHECK
(
this
->
isOK
()
)
<<
msg
();
}
private:
std
::
shared_ptr
<
std
::
string
>
msg_
;
...
...
paddle/utils/tests/test_Error.cpp
浏览文件 @
108f195d
...
...
@@ -18,17 +18,17 @@ limitations under the License. */
TEST
(
Error
,
testAll
)
{
paddle
::
Error
error
;
ASSERT_TRUE
(
error
);
error
=
paddle
::
Error
(
"I'm the error"
);
ASSERT_FALSE
(
error
);
error
=
paddle
::
Error
(
"I'm the error"
);
ASSERT_TRUE
(
error
);
ASSERT_STREQ
(
"I'm the error"
,
error
.
msg
());
error
=
paddle
::
Error
(
"error2"
);
ASSERT_
FALS
E
(
error
);
ASSERT_
TRU
E
(
error
);
ASSERT_STREQ
(
"error2"
,
error
.
msg
());
int
i
=
3
;
auto
error3
=
paddle
::
Error
(
"error%d"
,
i
);
ASSERT_
FALS
E
(
error3
);
ASSERT_
TRU
E
(
error3
);
ASSERT_STREQ
(
"error3"
,
error3
.
msg
());
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录