提交 e768721c 编写于 作者: L Liang Zhao

Fix bug in calculating totalScore2_

上级 413cbb84
......@@ -39,6 +39,13 @@ void Evaluator::eval(const NeuralNetwork& nn) {
*/
class ClassificationErrorEvaluator : public Evaluator {
public:
// Default constructor: initializes the totalScore2_ accumulator to 0.
ClassificationErrorEvaluator() : totalScore2_(0) {}
// Resets evaluator state: delegates to the base-class Evaluator::start() and
// then zeroes totalScore2_, so a previously accumulated value does not carry
// over.
virtual void start() {
Evaluator::start();
totalScore2_ = 0;
}
virtual void updateSamplesNum(const std::vector<Argument>& arguments) {
if (3 == arguments.size()) {
numSamples_ += arguments[2].value->getSum();
......@@ -85,14 +92,15 @@ public:
if (label != nullptr) {
errorMat->classificationError(*output, *label); // top-1 error
if (config_.top_k() > 1) {
size_t height = output->getHeight();
size_t width = 5;
size_t width = config_.top_k();
IVector::resizeOrCreate(
maxIds_, height * width, useGpu(arguments[0].deviceId));
Matrix::resizeOrCreate(
maxValues_, height, width, false, useGpu(arguments[0].deviceId));
output->rowMax(*maxIds_, *maxValues_); // top-5 values
output->rowMax(*maxIds_, *maxValues_); // top-k values
int* ids = nullptr;
int* lbl = nullptr;
......@@ -115,15 +123,16 @@ public:
real* result2 = errorMat2->getData();
for (size_t i = 0; i < height; ++i) {
result2[i] = (ids[i * width] != lbl[i]); // initialize top-5 error
result2[i] = (ids[i * width] != lbl[i]); // initialize top-k error
for (size_t j = 1; j < width; ++j) {
if (result2[i] == 0.0) {
break;
}
result2[i] = (ids[i * width + j] != lbl[i]); // top-5 error
result2[i] = (ids[i * width + j] != lbl[i]); // top-k error
}
}
totalScore2_ = errorMat2->getSum();
totalScore2_ += errorMat2->getSum();
}
} else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) ||
dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) {
errorMat->classificationErrorMulti(
......@@ -140,8 +149,14 @@ public:
}
// Writes the evaluator's error statistics to `os`.
//  - When config_.top_k() == 1: prints "<config_.name()>=<ratio>", where the
//    ratio is totalScore_ / numSamples_.
//  - Otherwise: prints "top_1_error=<ratio>" followed by a second ratio
//    computed from totalScore2_ / numSamples_.
// Every ratio is printed as 0 when numSamples_ is zero, which avoids a
// division by zero.
void printStats(std::ostream& os) const {
if (config_.top_k() == 1) {
os << config_.name() << "="
<< (numSamples_ ? totalScore_ / numSamples_ : 0);
} else {
os << "top_1_error=" << (numSamples_ ? totalScore_ / numSamples_ : 0)
// NOTE(review): the next line (hard-coded " top_5_error=") looks like the
// pre-change side of this diff hunk, and the two lines after it appear to be
// its replacement, which builds the label from config_.top_k(). Confirm
// against the raw patch; only one of the two variants can be compiled.
<< " top_5_error=" << (numSamples_ ? totalScore2_ / numSamples_ : 0);
<< " top_" << config_.top_k()
<< "_error=" << (numSamples_ ? totalScore2_ / numSamples_ : 0);
}
}
virtual real evalImp(std::vector<Argument>& arguments) {
......@@ -150,7 +165,11 @@ public:
}
// Combines this evaluator's accumulators through the given parameter client.
// totalScore_, totalScore2_ and numSamples_ are packed into one 3-element
// double buffer, passed to client->reduce() as both input and output, and
// then copied back into the members.
// NOTE(review): presumably reduce() aggregates the values across trainers
// onto trainer 0 (last argument) — confirm against ParameterClient2::reduce.
virtual void distributeEval(ParameterClient2* client) {
// NOTE(review): the hunk header (-150,7 +165,11) suggests this call is the
// removed side of the diff, replaced by the explicit 3-value reduce below,
// which also covers totalScore2_. Confirm against the raw patch.
mergeResultsOfAllClients(client);
double data[3] = {totalScore_, totalScore2_, numSamples_};
client->reduce(data, data, 3, FLAGS_trainer_id, 0);
totalScore_ = data[0];
totalScore2_ = data[1];
numSamples_ = data[2];
}
private:
......
......@@ -129,6 +129,7 @@ void testEvaluatorAll(TestConfig testConf,
TEST(Evaluator, classification_error) {
TestConfig config;
config.evaluatorConfig.set_type("classification_error");
config.evaluatorConfig.set_top_k(5);
config.inputDefs.push_back({INPUT_DATA, "output", 50});
config.inputDefs.push_back({INPUT_LABEL, "label", 50});
......
......@@ -475,6 +475,10 @@ message EvaluatorConfig {
// Used by ChunkEvaluator
// chunk of these types are not counted
repeated int32 excluded_chunk_types = 12;
// Used by ClassificationErrorEvaluator
// k for the top-k classification error (default 1 = plain top-1 error)
optional int32 top_k = 13 [default = 1];
}
message LinkConfig {
......
......@@ -1253,6 +1253,7 @@ def Evaluator(
dict_file=None,
result_file=None,
num_results=None,
top_k=None,
delimited=None,
excluded_chunk_types=None, ):
evaluator = g_config.model_config.evaluators.add()
......@@ -1280,6 +1281,8 @@ def Evaluator(
evaluator.result_file = result_file
if num_results is not None:
evaluator.num_results = num_results
if top_k is not None:
evaluator.top_k = top_k
if delimited is not None:
evaluator.delimited = delimited
......
......@@ -71,6 +71,7 @@ def evaluator_base(
result_file=None,
num_results=None,
delimited=None,
top_k=None,
excluded_chunk_types=None, ):
"""
Evaluator will evaluate the network status while training/testing.
......@@ -104,12 +105,15 @@ def evaluator_base(
:param weight: An input layer which is a weight for each sample.
Each evaluator may calculate differently to use this weight.
:type weight: LayerOutput.
:param top_k: number k in top-k error rate
:type top_k: int
"""
# inputs type assertions.
assert classification_threshold is None or isinstance(
classification_threshold, float)
assert positive_label is None or isinstance(positive_label, int)
assert num_results is None or isinstance(num_results, int)
assert top_k is None or isinstance(top_k, int)
if not isinstance(input, list):
input = [input]
......@@ -130,6 +134,8 @@ def evaluator_base(
dict_file=dict_file,
result_file=result_file,
delimited=delimited,
num_results=num_results,
top_k=top_k,
excluded_chunk_types=excluded_chunk_types, )
......@@ -139,6 +145,7 @@ def classification_error_evaluator(input,
label,
name=None,
weight=None,
top_k=None,
threshold=None):
"""
Classification Error Evaluator. It will print error rate for classification.
......@@ -167,6 +174,8 @@ def classification_error_evaluator(input,
then means not set weight. The larger weight it is, the more
important this sample is.
:type weight: LayerOutput
:param top_k: number k in top-k error rate
:type top_k: int
:param threshold: The classification threshold.
:type threshold: float
:return: None.
......@@ -178,6 +187,7 @@ def classification_error_evaluator(input,
input=input,
label=label,
weight=weight,
top_k=top_k,
classification_threshold=threshold, )
......
......@@ -3536,6 +3536,7 @@ def classification_cost(input,
label,
weight=None,
name=None,
top_k=None,
evaluator=classification_error_evaluator,
layer_attr=None):
"""
......@@ -3550,6 +3551,8 @@ def classification_cost(input,
:param weight: The weight affects the cost, namely the scale of cost.
It is an optional argument.
:type weight: LayerOutput
:param top_k: number k in top-k error rate
:type top_k: int
:param evaluator: Evaluator method.
:param layer_attr: layer's extra attribute.
:type layer_attr: ExtraLayerAttribute
......@@ -3577,7 +3580,7 @@ def classification_cost(input,
assert isinstance(e.for_classification, bool)
assert e.for_classification
e(name=e.__name__, input=input, label=label, weight=weight)
e(name=e.__name__, input=input, label=label, weight=weight, top_k=top_k)
if not isinstance(evaluator, collections.Sequence):
evaluator = [evaluator]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册