提交 c4519574 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #1375 from reyoung/feature/EvaluatorValue

Feature/evaluator value
......@@ -20,7 +20,7 @@ namespace paddle {
/**
* calculate sequence-to-sequence edit distance
*/
class CTCErrorEvaluator : public Evaluator {
class CTCErrorEvaluator : public NotGetableEvaluator {
private:
MatrixPtr outActivations_;
int numTimes_, numClasses_, numSequences_, blank_;
......
......@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/gserver/evaluators/Evaluator.h"
#include "paddle/utils/Stat.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/StringUtil.h"
DECLARE_int32(trainer_id);
......@@ -122,6 +122,10 @@ public:
virtual void distributeEval(ParameterClient2* client) {
mergeResultsOfAllClients(client);
}
// Evaluator interface
protected:
/// Evaluator-interface hook: the type string of this evaluator.
std::string getTypeImpl() const {
  return std::string("classification_error");
}
};
/**
......@@ -160,6 +164,10 @@ public:
virtual void distributeEval(ParameterClient2* client) {
mergeResultsOfAllClients(client);
}
// Evaluator interface
protected:
/// Evaluator-interface hook: the type string of this evaluator.
std::string getTypeImpl() const {
  return std::string("seq_classification_error");
}
};
REGISTER_EVALUATOR(seq_classification_error,
SequenceClassificationErrorEvaluator);
......@@ -250,6 +258,10 @@ public:
private:
IVectorPtr cpuLabel_;
MatrixPtr cpuWeight_;
// Evaluator interface
protected:
/// Evaluator-interface hook: the type string of this evaluator.
std::string getTypeImpl() const {
  return std::string("sum");
}
};
/**
* @brief column sum Evaluator
......@@ -357,10 +369,18 @@ public:
}
private:
ColumnSumEvaluator() {}
int32_t colIdx_;
size_t colNum_;
MatrixPtr sum_; /* cpu matrix */
// Evaluator interface
protected:
/// Evaluator-interface hook: yields "last-column-sum" when colIdx_ is -1 and
/// "column-sum" for every other column index.
std::string getTypeImpl() const {
  return colIdx_ == -1 ? "last-column-sum" : "column-sum";
}
};
void AucEvaluator::start() {
......@@ -469,6 +489,16 @@ double AucEvaluator::calcAuc() const {
}
}
/// The single metric of this evaluator is the result of calcAuc().
real AucEvaluator::getValueImpl() const {
  const double auc = calcAuc();
  return auc;
}
/// Yields "last-column-auc" when colIdx_ is -1 and "auc" for any other
/// column index.
std::string AucEvaluator::getTypeImpl() const {
  return colIdx_ == -1 ? "last-column-auc" : "auc";
}
// class RankAucEvaluator
REGISTER_EVALUATOR(rankauc, RankAucEvaluator);
......@@ -548,12 +578,15 @@ double RankAucEvaluator::calcRankAuc(real* outputData,
: aucTmp / (clickSum * noClickSum);
}
/// Type string of the rank-AUC evaluator.
std::string RankAucEvaluator::getTypeImpl() const {
  return std::string("rankauc");
}
// class PrecisionRecallEvaluator
REGISTER_EVALUATOR(precision_recall, PrecisionRecallEvaluator);
void PrecisionRecallEvaluator::start() {
Evaluator::start();
statsInfo_.clear();
values_.clear();
}
real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
......@@ -614,52 +647,23 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
}
/**
 * @brief Writes the precision/recall statistics to os.
 *
 * When getStatsInfo() reports that only the positive-label statistics were
 * computed (it returns false), the `positive_label=... precision=...` line
 * is written. Otherwise the macro-averaged and micro-averaged statistics are
 * written. The two groups are never written together, so no field that
 * getStatsInfo() did not compute is printed and no separator is missing
 * between the groups.
 *
 * @param os stream receiving the formatted statistics.
 */
void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
  // Value-initialised so that no field is indeterminate, whichever branch
  // getStatsInfo() takes.
  PrintStatsInfo info{};
  const bool containMacroMicroInfo = getStatsInfo(&info);

  if (!containMacroMicroInfo) {
    os << "positive_label=" << config_.positive_label()
       << " precision=" << info.precision << " recall=" << info.recall
       << " F1-score=" << info.f1;
    return;
  }

  os << "macro-average-precision=" << info.macroAvgPrecision
     << " macro-average-recall=" << info.macroAvgRecall
     << " macro-average-F1-score=" << info.macroAvgF1Score;

  // Without multi-binary labels precision and recall are equal, so only the
  // micro-averaged precision is reported in that case.
  os << " micro-average-precision=" << info.microAvgPrecision;
  if (isMultiBinaryLabel_) {
    os << " micro-average-recall=" << info.microAvgRecall
       << " micro-average-F1-score=" << info.microAvgF1Score;
  }
}
......@@ -741,6 +745,60 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
}
}
/**
 * @brief Fills the values_ cache with the statistics from getStatsInfo().
 *
 * Nothing is recomputed while values_ is non-empty; start() clears the
 * cache. The macro/micro entries are stored only when getStatsInfo() reports
 * that it computed them.
 */
void PrecisionRecallEvaluator::storeLocalValues() const {
  if (!values_.empty()) {
    return;
  }

  // Value-initialised so the per-label fields are 0 instead of indeterminate
  // when getStatsInfo() only fills the macro/micro fields.
  PrintStatsInfo info{};
  const bool containMacroMicroInfo = getStatsInfo(&info);

  values_["precision"] = info.precision;
  values_["recall"] = info.recall;  // key was misspelled as "recal"
  values_["F1-score"] = info.f1;
  if (!containMacroMicroInfo) {
    return;
  }

  values_["macro-average-precision"] = info.macroAvgPrecision;
  values_["macro-average-recall"] = info.macroAvgRecall;
  values_["macro-average-F1-score"] = info.macroAvgF1Score;

  // Without multi-binary labels precision and recall are equal, so only the
  // micro-averaged precision is stored in that case.
  values_["micro-average-precision"] = info.microAvgPrecision;
  if (isMultiBinaryLabel_) {
    values_["micro-average-recall"] = info.microAvgRecall;
    values_["micro-average-F1-score"] = info.microAvgF1Score;
  }
}
void PrecisionRecallEvaluator::getNames(std::vector<std::string>* names) {
this->storeLocalValues();
names->reserve(this->values_.size());
for (auto it = this->values_.begin(); it != this->values_.end(); ++it) {
names->push_back(this->config_.name() + "." + it->first);
}
}
/**
 * @brief Looks up one cached statistic by field name.
 *
 * Only the last '.'-separated segment of name is used as the lookup key.
 *
 * @param name field name, e.g. as produced by getNames().
 * @param err [out] set to an error when the key is unknown; may be nullptr.
 * @return the cached value, or 0 when the key is unknown.
 */
real PrecisionRecallEvaluator::getValue(const std::string& name,
                                        Error* err) const {
  storeLocalValues();
  std::vector<std::string> buffers;
  paddle::str::split(name, '.', &buffers);

  // Guard against an empty split result before taking the last segment.
  auto it = buffers.empty() ? values_.end() : values_.find(buffers.back());
  if (it == values_.end()) {  // not found
    if (err != nullptr) {
      *err = Error("No such key %s", name.c_str());
    }
    return .0f;
  }
  return it->second;
}
std::string PrecisionRecallEvaluator::getType(const std::string& name,
Error* err) const {
this->getValue(name, err);
if (!err->isOK()) {
return "";
}
return "precision_recall";
}
void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
size_t size = 4 * statsInfo_.size();
double* buf = new double[size];
......@@ -760,6 +818,47 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
delete[] buf;
}
/**
 * @brief Computes the statistics that printStats() and the value cache use.
 *
 * All fields of *info are reset to 0 first. When a positive label is
 * configured (positive_label != -1) only precision, recall and f1 of that
 * label are filled in. Otherwise the macro- and micro-averaged fields are
 * filled in.
 *
 * @param info [out] receives the statistics.
 * @return true when the macro/micro fields were computed, false when only
 *         the positive-label fields were computed.
 */
bool PrecisionRecallEvaluator::getStatsInfo(
    PrecisionRecallEvaluator::PrintStatsInfo* info) const {
  // Reset everything so callers never read an indeterminate field from the
  // branch that was not taken.
  *info = PrintStatsInfo();

  const int label = config_.positive_label();
  if (label != -1) {
    CHECK(label >= 0 && label < (int)statsInfo_.size())
        << "positive_label [" << label << "] should be in range [0, "
        << statsInfo_.size() << ")";
    const auto& stats = statsInfo_[label];
    info->precision = calcPrecision(stats.TP, stats.FP);
    info->recall = calcRecall(stats.TP, stats.FN);
    info->f1 = calcF1Score(info->precision, info->recall);
    return false;
  }

  // micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
  // macro average method: precision = (precision1+precision2)/2
  double microTotalTP = 0;
  double microTotalFP = 0;
  double microTotalFN = 0;
  const size_t numLabels = statsInfo_.size();
  for (size_t i = 0; i < numLabels; ++i) {
    const auto& stats = statsInfo_[i];
    microTotalTP += stats.TP;
    microTotalFP += stats.FP;
    microTotalFN += stats.FN;
    info->macroAvgPrecision += calcPrecision(stats.TP, stats.FP);
    info->macroAvgRecall += calcRecall(stats.TP, stats.FN);
  }
  info->macroAvgPrecision /= numLabels;
  info->macroAvgRecall /= numLabels;
  info->macroAvgF1Score =
      calcF1Score(info->macroAvgPrecision, info->macroAvgRecall);

  info->microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
  // Recall is computed from TP and FN, so use the recall helper here.
  info->microAvgRecall = calcRecall(microTotalTP, microTotalFN);
  info->microAvgF1Score =
      calcF1Score(info->microAvgPrecision, info->microAvgRecall);
  return true;
}
REGISTER_EVALUATOR(pnpair, PnpairEvaluator);
void PnpairEvaluator::start() {
Evaluator::start();
......@@ -884,6 +983,8 @@ void PnpairEvaluator::calc(std::vector<PredictionResult>& predictArray) {
<< " calc total special pair: " << special;
}
/// Type string of the positive/negative pair evaluator.
std::string PnpairEvaluator::getTypeImpl() const {
  return std::string("pnpair");
}
ClassRegistrar<Evaluator> Evaluator::registrar_;
Evaluator* Evaluator::create(const EvaluatorConfig& config) {
Evaluator* evaluator = registrar_.createByType(config.type());
......@@ -905,7 +1006,7 @@ static InitFunction __reg_type_auc_sum__([]() {
*
* The config file api is value_printer_evaluator.
*/
class ValuePrinter : public Evaluator {
class ValuePrinter : public NotGetableEvaluator {
public:
virtual void eval(const NeuralNetwork& nn) {
for (const std::string& name : config_.input_layers()) {
......@@ -919,12 +1020,13 @@ public:
virtual real evalImp(std::vector<Argument>& arguments) { return 0; }
};
REGISTER_EVALUATOR(value_printer, ValuePrinter);
/**
* @brief print gradient of each layer.
*
* The config file api is gradient_printer_evaluator.
*/
class GradientPrinter : public Evaluator {
class GradientPrinter : public NotGetableEvaluator {
public:
virtual void eval(const NeuralNetwork& nn) {
for (const std::string& name : config_.input_layers()) {
......@@ -947,7 +1049,7 @@ REGISTER_EVALUATOR(gradient_printer, GradientPrinter);
*
* The config file api is maxid_printer_evaluator.
*/
class MaxIdPrinter : public Evaluator {
class MaxIdPrinter : public NotGetableEvaluator {
private:
IVectorPtr maxIds_;
MatrixPtr maxValues_;
......@@ -989,7 +1091,7 @@ REGISTER_EVALUATOR(max_id_printer, MaxIdPrinter);
*
* The config file api is maxframe_printer_evaluator.
*/
class MaxFramePrinter : public Evaluator {
class MaxFramePrinter : public NotGetableEvaluator {
private:
IVectorPtr maxIds_;
MatrixPtr maxValues_;
......@@ -1076,7 +1178,7 @@ REGISTER_EVALUATOR(max_frame_printer, MaxFramePrinter);
* The config file api is seqtext_printer_evaluator.
*
*/
class SequenceTextPrinter : public Evaluator {
class SequenceTextPrinter : public NotGetableEvaluator {
private:
/// dict_file, which contains a list of tokens
std::vector<std::string> dict_;
......@@ -1243,4 +1345,6 @@ public:
};
REGISTER_EVALUATOR(classification_error_printer, ClassificationErrorPrinter);
/// Type string of the dummy evaluator.
std::string DummyEvaluator::getTypeImpl() const {
  return std::string("dummy");
}
} // namespace paddle
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/parameter/Argument.h"
#include "paddle/pserver/ParameterClient2.h"
#include "paddle/utils/ClassRegistrar.h"
#include "paddle/utils/Error.h"
namespace paddle {
......@@ -117,12 +118,105 @@ public:
static ClassRegistrar<Evaluator> registrar_;
/**
 * @brief Appends every field name exposed by this evaluator to names.
 *
 * Field names follow the pattern `evaluator_name.evaluator_field`. An
 * evaluator with several fields reports one name per field; for example
 * PrecisionRecallEvaluator reports names such as
 * `precision_recall_evaluator.precision`. A combined evaluator reports the
 * names of all evaluators it contains.
 *
 * This base implementation appends only the configured evaluator name.
 *
 * @param names [out]: receives the field names of this evaluator.
 * @note Implementations must never clear names; they only append to it.
 */
virtual void getNames(std::vector<std::string>* names) {
  names->emplace_back(config_.name());
}
/**
 * @brief getValue returns the current evaluation value of one field.
 *
 * This base implementation knows a single field, named after the evaluator
 * itself, whose value comes from getValueImpl().
 *
 * @param name: The field name of current evaluator.
 * @param err [out]: The error state. It is set when name is unknown;
 *            nullptr means the caller does not care.
 *
 * @return The evaluate value (metric), or 0 when name is unknown.
 */
virtual real getValue(const std::string& name, Error* err) const {
  if (name == config_.name()) {
    return this->getValueImpl();
  }
  if (err != nullptr) {
    *err = Error("no such name of evaluator %s", name.c_str());
  }
  return .0f;
}
/**
 * @brief getType returns the evaluator type for a field name.
 *
 * The evaluator type is a string such as 'auc' or 'precision_recall'. In a
 * combined evaluator different names may map to different types, because
 * each name can belong to a different inner evaluator.
 *
 * This base implementation knows a single field, named after the evaluator
 * itself, whose type comes from getTypeImpl().
 *
 * @param name: The field name of current Evaluator.
 * @param err: The error state. nullptr means don't care.
 * @return the evaluator type string, or an empty string when name is
 *         unknown.
 */
virtual std::string getType(const std::string& name, Error* err) const {
  if (name == config_.name()) {
    return this->getTypeImpl();
  }
  // The documented contract allows a null err, so never dereference it
  // unchecked.
  if (err != nullptr) {
    *err = Error("no such name of evaluator %s", name.c_str());
  }
  return std::string();
}
protected:
/**
 * @brief getValueImpl is the simplest hook for defining the getValue result.
 * An evaluator that has a single field and reports no errors only needs to
 * override this method. The default is totalScore_ divided by numSamples_,
 * or 0 when numSamples_ is 0.
 * @return Evaluate result (metric).
 */
virtual real getValueImpl() const {
  if (numSamples_ == .0) {
    return .0;
  }
  return totalScore_ / numSamples_;
}
/**
 * @brief getTypeImpl is the simplest hook for defining the getType result.
 * An evaluator that does not combine other evaluators only needs to return
 * its own type here. The default type string is "base".
 * @return Evaluator type.
 */
virtual std::string getTypeImpl() const {
  return std::string("base");
}
protected:
EvaluatorConfig config_;
double numSamples_;
double totalScore_;
};
/**
 * @brief NotGetableEvaluator is the base class of evaluators whose value
 * cannot be queried at runtime. Most NotGetableEvaluators are printer
 * evaluators, which are only used to debug the network configuration.
 *
 * getNames() reports no field, and getValue()/getType() always fail with a
 * "Not implemented" error.
 */
class NotGetableEvaluator : public Evaluator {
  // Evaluator interface
public:
  /// Exposes no field: names is left untouched.
  void getNames(std::vector<std::string>* names) {}

  /// Always fails; err (when not nullptr) receives "Not implemented".
  real getValue(const std::string& name, Error* err) const {
    if (err != nullptr) {
      *err = Error("Not implemented");
    }
    return .0f;
  }

  /// Always fails; err (when not nullptr) receives "Not implemented".
  std::string getType(const std::string& name, Error* err) const {
    if (err != nullptr) {
      *err = Error("Not implemented");
    }
    return "";
  }
};
class DummyEvaluator : public Evaluator {
public:
DummyEvaluator() {}
......@@ -135,6 +229,10 @@ public:
}
virtual void finish() {}
virtual void printStats(std::ostream&) const {}
// Evaluator interface
protected:
std::string getTypeImpl() const;
};
/**
* @brief evaluate AUC using colIdx-th column as prediction.
......@@ -191,6 +289,11 @@ private:
}
double calcAuc() const;
// Evaluator interface
protected:
real getValueImpl() const;
std::string getTypeImpl() const;
};
/**
......@@ -223,6 +326,10 @@ private:
real* clickData,
real* pvData,
size_t size);
// Evaluator interface
protected:
std::string getTypeImpl() const;
};
/**
* @brief precision, recall and f1 score Evaluator
......@@ -272,6 +379,20 @@ private:
IVectorPtr cpuLabel_;
MatrixPtr cpuWeight_;
/// Statistics bundle filled by getStatsInfo(). Every field defaults to 0 so
/// that a field the producer did not compute is never indeterminate.
struct PrintStatsInfo {
  // Statistics of the configured positive label.
  double precision = 0.0;
  double recall = 0.0;
  double f1 = 0.0;
  // Macro-averaged statistics.
  double macroAvgPrecision = 0.0;
  double macroAvgRecall = 0.0;
  double macroAvgF1Score = 0.0;
  // Micro-averaged statistics.
  double microAvgPrecision = 0.0;
  double microAvgRecall = 0.0;
  double microAvgF1Score = 0.0;
};
bool getStatsInfo(PrintStatsInfo* info) const;
void calcStatsInfo(const MatrixPtr& output,
const IVectorPtr& label,
const MatrixPtr& weight);
......@@ -303,6 +424,15 @@ private:
return 0;
}
}
mutable std::unordered_map<std::string, real> values_;
void storeLocalValues() const;
// Evaluator interface
public:
void getNames(std::vector<std::string>* names);
real getValue(const std::string& name, Error* err) const;
std::string getType(const std::string& name, Error* err) const;
};
/*
......@@ -349,8 +479,7 @@ public:
virtual void finish() { calc(predictArray_); }
virtual void printStats(std::ostream& os) const {
os << " pos/neg"
<< "=" << pairArray_[0] / ((pairArray_[1] <= 0) ? 1.0 : pairArray_[1]);
os << " pos/neg=" << this->getValueImpl();
}
virtual void distributeEval(ParameterClient2* client) {
......@@ -366,6 +495,13 @@ private:
IVectorPtr cpuLabel_;
IVectorPtr cpuInfo_;
MatrixPtr cpuWeight_;
// Evaluator interface
protected:
/// Ratio pairArray_[0] / pairArray_[1]; the divisor is replaced by 1.0
/// when pairArray_[1] is not positive.
real getValueImpl() const {
  const auto divisor = (pairArray_[1] <= 0) ? 1.0 : pairArray_[1];
  return pairArray_[0] / divisor;
}
std::string getTypeImpl() const;
};
} // namespace paddle
......@@ -306,7 +306,6 @@ void NeuralNetwork::onPassEnd() {
class CombinedEvaluator : public Evaluator {
public:
CombinedEvaluator() {}
void addEvaluator(std::unique_ptr<Evaluator>&& evaluator) {
evaluators_.emplace_back(std::move(evaluator));
}
......@@ -346,6 +345,55 @@ public:
protected:
std::vector<std::unique_ptr<Evaluator>> evaluators_;
// Evaluator interface
public:
/**
 * @brief getNames collects the names reported by every inner evaluator.
 * @param names [out]: receives the names; existing entries are kept.
 */
void getNames(std::vector<std::string>* names) {
  for (auto it = evaluators_.begin(); it != evaluators_.end(); ++it) {
    (*it)->getNames(names);
  }
}
/**
 * @brief getValue forwards the query to the inner evaluator that reports
 * the given name.
 */
real getValue(const std::string& name, Error* err) const {
  auto fetchValue = [&name, err](const std::unique_ptr<Evaluator>& eval) {
    return eval->getValue(name, err);
  };
  return this->getMethodHelper<real>(name, err, fetchValue);
}
/**
 * @brief getType forwards the query to the inner evaluator that reports
 * the given name.
 */
std::string getType(const std::string& name, Error* err) const {
  auto fetchType = [&name, err](const std::unique_ptr<Evaluator>& eval) {
    return eval->getType(name, err);
  };
  return this->getMethodHelper<std::string>(name, err, fetchType);
}
private:
/**
 * @brief Runs callback on the first inner evaluator that reports name.
 *
 * @param name field name to look for.
 * @param err [out] set to "No such key" when no inner evaluator reports
 *            name; may be nullptr.
 * @param callback invoked with the matching evaluator; its result is
 *                 returned.
 * @return the callback result, or a value-initialised T when nothing
 *         matches.
 */
template <typename T>
T getMethodHelper(const std::string& name,
                  Error* err,
                  const std::function<T(const std::unique_ptr<Evaluator>&)>&
                      callback) const {
  // One buffer reused for every evaluator. getNames() only appends, so the
  // buffer is cleared before each query.
  std::vector<std::string> names;
  for (const auto& eval : evaluators_) {
    names.clear();
    eval->getNames(&names);
    if (std::find(names.begin(), names.end(), name) != names.end()) {
      return callback(eval);
    }
  }
  if (err != nullptr) {
    *err = Error("No such key %s", name.c_str());
  }
  return T();
}
};
Evaluator* NeuralNetwork::makeEvaluator() const {
......
......@@ -110,6 +110,18 @@ void testEvaluator(TestConfig testConf,
testEvaluator->finish();
LOG(INFO) << *testEvaluator;
std::vector<std::string> names;
testEvaluator->getNames(&names);
paddle::Error err;
for (auto& name : names) {
auto value = testEvaluator->getValue(name, &err);
ASSERT_TRUE(err.isOK());
LOG(INFO) << name << " " << value;
auto tp = testEvaluator->getType(name, &err);
ASSERT_TRUE(err.isOK());
ASSERT_EQ(testConf.evaluatorConfig.type(), tp);
}
double totalScore2 = 0.0;
if (testConf.testAccumulate) {
testEvaluator->start();
......
......@@ -37,10 +37,10 @@ namespace paddle {
*
* Error __must_check bar() {
* // do something.
* Status s = foo(); // invoke other method return status.
* if (!s) return s;
* Error err = foo(); // invoke other method return status.
* if (err) return err;
* // do something else.
* return Status();
* return Error();
* }
* @endcode{cpp}
*
......@@ -53,8 +53,8 @@ namespace paddle {
*
* int foo(Error* error) {
* // Do something.
* Error s = bar();
* if (!s) {
* Error err = bar();
* if (err) {
* *error = err;
* return 0;
* }
......@@ -68,10 +68,10 @@ namespace paddle {
* }
*
* Error foobar() {
* Error s;
* Error err;
* // do something.
* foo(&s);
* if (!s) return s;
* foo(&err);
* if (err) return err;
* }
* @endcode{cpp}
*
......@@ -112,16 +112,22 @@ public:
}
/**
* @brief operator bool, return True if there is no error.
* @brief operator bool, return True if there is an error.
*/
operator bool() const { return msg_ == nullptr; }
operator bool() const { return !this->isOK(); }
/**
* @brief isOK return True if there is no error.
* @return True if no error.
*/
bool isOK() const { return msg_ == nullptr; }
/**
* @brief check this status by glog.
* @note It is a temp method used during cleaning Paddle code. It will be
* removed later.
*/
void check() const { CHECK(*this) << msg(); }
void check() const { CHECK(this->isOK()) << msg(); }
private:
std::shared_ptr<std::string> msg_;
......
......@@ -18,17 +18,17 @@ limitations under the License. */
TEST(Error, testAll) {
paddle::Error error;
ASSERT_TRUE(error);
error = paddle::Error("I'm the error");
ASSERT_FALSE(error);
error = paddle::Error("I'm the error");
ASSERT_TRUE(error);
ASSERT_STREQ("I'm the error", error.msg());
error = paddle::Error("error2");
ASSERT_FALSE(error);
ASSERT_TRUE(error);
ASSERT_STREQ("error2", error.msg());
int i = 3;
auto error3 = paddle::Error("error%d", i);
ASSERT_FALSE(error3);
ASSERT_TRUE(error3);
ASSERT_STREQ("error3", error3.msg());
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册