提交 e768721c 编写于 作者: L Liang Zhao

fix calculating totalScore2_ bug

上级 413cbb84
...@@ -39,6 +39,13 @@ void Evaluator::eval(const NeuralNetwork& nn) { ...@@ -39,6 +39,13 @@ void Evaluator::eval(const NeuralNetwork& nn) {
*/ */
class ClassificationErrorEvaluator : public Evaluator { class ClassificationErrorEvaluator : public Evaluator {
public: public:
// Zero-initialize totalScore2_, the accumulator for the top-k error sum,
// so the evaluator starts from a clean state before any batch is processed.
ClassificationErrorEvaluator() : totalScore2_(0) {}
// Reset evaluator state at the beginning of a new evaluation pass.
// Evaluator::start() presumably resets the base-class accumulators
// (totalScore_ / numSamples_ are updated elsewhere in this class) —
// TODO confirm against the base class, which is not visible here.
// totalScore2_ must also be cleared here: this commit changes evalImp to
// accumulate into it (`totalScore2_ += ...`) instead of overwriting it,
// so a stale value would carry over between passes without this reset.
virtual void start() {
Evaluator::start();
totalScore2_ = 0;  // clear the accumulated top-k error sum
}
virtual void updateSamplesNum(const std::vector<Argument>& arguments) { virtual void updateSamplesNum(const std::vector<Argument>& arguments) {
if (3 == arguments.size()) { if (3 == arguments.size()) {
numSamples_ += arguments[2].value->getSum(); numSamples_ += arguments[2].value->getSum();
...@@ -85,14 +92,15 @@ public: ...@@ -85,14 +92,15 @@ public:
if (label != nullptr) { if (label != nullptr) {
errorMat->classificationError(*output, *label); // top-1 error errorMat->classificationError(*output, *label); // top-1 error
if (config_.top_k() > 1) {
size_t height = output->getHeight(); size_t height = output->getHeight();
size_t width = 5; size_t width = config_.top_k();
IVector::resizeOrCreate( IVector::resizeOrCreate(
maxIds_, height * width, useGpu(arguments[0].deviceId)); maxIds_, height * width, useGpu(arguments[0].deviceId));
Matrix::resizeOrCreate( Matrix::resizeOrCreate(
maxValues_, height, width, false, useGpu(arguments[0].deviceId)); maxValues_, height, width, false, useGpu(arguments[0].deviceId));
output->rowMax(*maxIds_, *maxValues_); // top-5 values output->rowMax(*maxIds_, *maxValues_); // top-k values
int* ids = nullptr; int* ids = nullptr;
int* lbl = nullptr; int* lbl = nullptr;
...@@ -115,15 +123,16 @@ public: ...@@ -115,15 +123,16 @@ public:
real* result2 = errorMat2->getData(); real* result2 = errorMat2->getData();
for (size_t i = 0; i < height; ++i) { for (size_t i = 0; i < height; ++i) {
result2[i] = (ids[i * width] != lbl[i]); // initialize top-5 error result2[i] = (ids[i * width] != lbl[i]); // initialize top-k error
for (size_t j = 1; j < width; ++j) { for (size_t j = 1; j < width; ++j) {
if (result2[i] == 0.0) { if (result2[i] == 0.0) {
break; break;
} }
result2[i] = (ids[i * width + j] != lbl[i]); // top-5 error result2[i] = (ids[i * width + j] != lbl[i]); // top-k error
} }
} }
totalScore2_ = errorMat2->getSum(); totalScore2_ += errorMat2->getSum();
}
} else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) || } else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) ||
dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) { dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) {
errorMat->classificationErrorMulti( errorMat->classificationErrorMulti(
...@@ -140,8 +149,14 @@ public: ...@@ -140,8 +149,14 @@ public:
} }
void printStats(std::ostream& os) const { void printStats(std::ostream& os) const {
if (config_.top_k() == 1) {
os << config_.name() << "="
<< (numSamples_ ? totalScore_ / numSamples_ : 0);
} else {
os << "top_1_error=" << (numSamples_ ? totalScore_ / numSamples_ : 0) os << "top_1_error=" << (numSamples_ ? totalScore_ / numSamples_ : 0)
<< " top_5_error=" << (numSamples_ ? totalScore2_ / numSamples_ : 0); << " top_" << config_.top_k()
<< "_error=" << (numSamples_ ? totalScore2_ / numSamples_ : 0);
}
} }
virtual real evalImp(std::vector<Argument>& arguments) { virtual real evalImp(std::vector<Argument>& arguments) {
...@@ -150,7 +165,11 @@ public: ...@@ -150,7 +165,11 @@ public:
} }
virtual void distributeEval(ParameterClient2* client) { virtual void distributeEval(ParameterClient2* client) {
mergeResultsOfAllClients(client); double data[3] = {totalScore_, totalScore2_, numSamples_};
client->reduce(data, data, 3, FLAGS_trainer_id, 0);
totalScore_ = data[0];
totalScore2_ = data[1];
numSamples_ = data[2];
} }
private: private:
......
...@@ -129,6 +129,7 @@ void testEvaluatorAll(TestConfig testConf, ...@@ -129,6 +129,7 @@ void testEvaluatorAll(TestConfig testConf,
TEST(Evaluator, classification_error) { TEST(Evaluator, classification_error) {
TestConfig config; TestConfig config;
config.evaluatorConfig.set_type("classification_error"); config.evaluatorConfig.set_type("classification_error");
config.evaluatorConfig.set_top_k(5);
config.inputDefs.push_back({INPUT_DATA, "output", 50}); config.inputDefs.push_back({INPUT_DATA, "output", 50});
config.inputDefs.push_back({INPUT_LABEL, "label", 50}); config.inputDefs.push_back({INPUT_LABEL, "label", 50});
......
...@@ -475,6 +475,10 @@ message EvaluatorConfig { ...@@ -475,6 +475,10 @@ message EvaluatorConfig {
// Used by ChunkEvaluator // Used by ChunkEvaluator
// chunk of these types are not counted // chunk of these types are not counted
repeated int32 excluded_chunk_types = 12; repeated int32 excluded_chunk_types = 12;
// Used by ClassificationErrorEvaluator
// top # classification error
optional int32 top_k = 13 [default = 1];
} }
message LinkConfig { message LinkConfig {
......
...@@ -1253,6 +1253,7 @@ def Evaluator( ...@@ -1253,6 +1253,7 @@ def Evaluator(
dict_file=None, dict_file=None,
result_file=None, result_file=None,
num_results=None, num_results=None,
top_k=None,
delimited=None, delimited=None,
excluded_chunk_types=None, ): excluded_chunk_types=None, ):
evaluator = g_config.model_config.evaluators.add() evaluator = g_config.model_config.evaluators.add()
...@@ -1280,6 +1281,8 @@ def Evaluator( ...@@ -1280,6 +1281,8 @@ def Evaluator(
evaluator.result_file = result_file evaluator.result_file = result_file
if num_results is not None: if num_results is not None:
evaluator.num_results = num_results evaluator.num_results = num_results
if top_k is not None:
evaluator.top_k = top_k
if delimited is not None: if delimited is not None:
evaluator.delimited = delimited evaluator.delimited = delimited
......
...@@ -71,6 +71,7 @@ def evaluator_base( ...@@ -71,6 +71,7 @@ def evaluator_base(
result_file=None, result_file=None,
num_results=None, num_results=None,
delimited=None, delimited=None,
top_k=None,
excluded_chunk_types=None, ): excluded_chunk_types=None, ):
""" """
Evaluator will evaluate the network status while training/testing. Evaluator will evaluate the network status while training/testing.
...@@ -104,12 +105,15 @@ def evaluator_base( ...@@ -104,12 +105,15 @@ def evaluator_base(
:param weight: An input layer which is a weight for each sample. :param weight: An input layer which is a weight for each sample.
Each evaluator may calculate differently to use this weight. Each evaluator may calculate differently to use this weight.
:type weight: LayerOutput. :type weight: LayerOutput.
:param top_k: number k in top-k error rate
:type top_k: int
""" """
# inputs type assertions. # inputs type assertions.
assert classification_threshold is None or isinstance( assert classification_threshold is None or isinstance(
classification_threshold, float) classification_threshold, float)
assert positive_label is None or isinstance(positive_label, int) assert positive_label is None or isinstance(positive_label, int)
assert num_results is None or isinstance(num_results, int) assert num_results is None or isinstance(num_results, int)
assert top_k is None or isinstance(top_k, int)
if not isinstance(input, list): if not isinstance(input, list):
input = [input] input = [input]
...@@ -130,6 +134,8 @@ def evaluator_base( ...@@ -130,6 +134,8 @@ def evaluator_base(
dict_file=dict_file, dict_file=dict_file,
result_file=result_file, result_file=result_file,
delimited=delimited, delimited=delimited,
num_results=num_results,
top_k=top_k,
excluded_chunk_types=excluded_chunk_types, ) excluded_chunk_types=excluded_chunk_types, )
...@@ -139,6 +145,7 @@ def classification_error_evaluator(input, ...@@ -139,6 +145,7 @@ def classification_error_evaluator(input,
label, label,
name=None, name=None,
weight=None, weight=None,
top_k=None,
threshold=None): threshold=None):
""" """
Classification Error Evaluator. It will print error rate for classification. Classification Error Evaluator. It will print error rate for classification.
...@@ -167,6 +174,8 @@ def classification_error_evaluator(input, ...@@ -167,6 +174,8 @@ def classification_error_evaluator(input,
then means not set weight. The larger weight it is, the more then means not set weight. The larger weight it is, the more
important this sample is. important this sample is.
:type weight: LayerOutput :type weight: LayerOutput
:param top_k: number k in top-k error rate
:type top_k: int
:param threshold: The classification threshold. :param threshold: The classification threshold.
:type threshold: float :type threshold: float
:return: None. :return: None.
...@@ -178,6 +187,7 @@ def classification_error_evaluator(input, ...@@ -178,6 +187,7 @@ def classification_error_evaluator(input,
input=input, input=input,
label=label, label=label,
weight=weight, weight=weight,
top_k=top_k,
classification_threshold=threshold, ) classification_threshold=threshold, )
......
...@@ -3536,6 +3536,7 @@ def classification_cost(input, ...@@ -3536,6 +3536,7 @@ def classification_cost(input,
label, label,
weight=None, weight=None,
name=None, name=None,
top_k=None,
evaluator=classification_error_evaluator, evaluator=classification_error_evaluator,
layer_attr=None): layer_attr=None):
""" """
...@@ -3550,6 +3551,8 @@ def classification_cost(input, ...@@ -3550,6 +3551,8 @@ def classification_cost(input,
:param weight: The weight affects the cost, namely the scale of cost. :param weight: The weight affects the cost, namely the scale of cost.
It is an optional argument. It is an optional argument.
:type weight: LayerOutput :type weight: LayerOutput
:param top_k: number k in top-k error rate
:type top_k: int
:param evaluator: Evaluator method. :param evaluator: Evaluator method.
:param layer_attr: layer's extra attribute. :param layer_attr: layer's extra attribute.
:type layer_attr: ExtraLayerAttribute :type layer_attr: ExtraLayerAttribute
...@@ -3577,7 +3580,7 @@ def classification_cost(input, ...@@ -3577,7 +3580,7 @@ def classification_cost(input,
assert isinstance(e.for_classification, bool) assert isinstance(e.for_classification, bool)
assert e.for_classification assert e.for_classification
e(name=e.__name__, input=input, label=label, weight=weight) e(name=e.__name__, input=input, label=label, weight=weight, top_k=top_k)
if not isinstance(evaluator, collections.Sequence): if not isinstance(evaluator, collections.Sequence):
evaluator = [evaluator] evaluator = [evaluator]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册