提交 5f924d5d 编写于 作者: Y yangyaming

Follow comments.

上级 22076592
......@@ -99,3 +99,12 @@ value_printer
.. automodule:: paddle.v2.evaluator
:members: value_printer
:noindex:
Detection
=====
detection_map
-------------
.. automodule:: paddle.v2.evaluator
:members: detection_map
:noindex:
......@@ -80,21 +80,20 @@ public:
allGTBBoxes.push_back(bboxes);
}
size_t imgId = 0;
for (size_t n = 0; n < cpuOutput_->getHeight();) {
size_t n = 0;
const real* cpuOutputData = cpuOutput_->getData();
for (size_t imgId = 0; imgId < batchSize; ++imgId) {
map<size_t, vector<pair<real, NormalizedBBox>>> bboxes;
while (cpuOutput_->getData()[n * 7] == imgId &&
n < cpuOutput_->getHeight()) {
size_t curImgId = static_cast<size_t>((cpuOutputData + n * 7)[0]);
while (curImgId == imgId && n < cpuOutput_->getHeight()) {
vector<real> label;
vector<real> score;
vector<NormalizedBBox> bbox;
getBBoxFromDetectData(
cpuOutput_->getData() + n * 7, 1, label, score, bbox);
getBBoxFromDetectData(cpuOutputData + n * 7, 1, label, score, bbox);
bboxes[label[0]].push_back(make_pair(score[0], bbox[0]));
++n;
curImgId = static_cast<size_t>((cpuOutputData + n * 7)[0]);
}
++imgId;
if (imgId > batchSize) break;
allDetectBBoxes.push_back(bboxes);
}
......@@ -119,15 +118,14 @@ public:
}
// calcTFPos
calcTFPos(
batchSize, allGTBBoxes, allDetectBBoxes, &allTruePos_, &allFalsePos_);
calcTFPos(batchSize, allGTBBoxes, allDetectBBoxes);
return 0;
}
virtual void printStats(std::ostream& os) const {
real mAP = calcMAP();
os << "Detection mAP=" << mAP * 100;
os << "Detection mAP=" << mAP;
}
virtual void distributeEval(ParameterClient2* client) {
......@@ -138,9 +136,7 @@ protected:
void calcTFPos(const size_t batchSize,
const vector<map<size_t, vector<NormalizedBBox>>>& allGTBBoxes,
const vector<map<size_t, vector<pair<real, NormalizedBBox>>>>&
allDetectBBoxes,
map<size_t, vector<pair<real, size_t>>>* allTruePos,
map<size_t, vector<pair<real, size_t>>>* allFalsePos) {
allDetectBBoxes) {
for (size_t n = 0; n < allDetectBBoxes.size(); ++n) {
if (allGTBBoxes[n].size() == 0) {
for (map<size_t, vector<pair<real, NormalizedBBox>>>::const_iterator
......@@ -149,8 +145,8 @@ protected:
++it) {
size_t label = it->first;
for (size_t i = 0; i < it->second.size(); ++i) {
(*allTruePos)[label].push_back(make_pair(it->second[i].first, 0));
(*allFalsePos)[label].push_back(make_pair(it->second[i].first, 1));
allTruePos_[label].push_back(make_pair(it->second[i].first, 0));
allFalsePos_[label].push_back(make_pair(it->second[i].first, 1));
}
}
} else {
......@@ -162,9 +158,8 @@ protected:
vector<pair<real, NormalizedBBox>> predBBoxes = it->second;
if (allGTBBoxes[n].find(label) == allGTBBoxes[n].end()) {
for (size_t i = 0; i < predBBoxes.size(); ++i) {
(*allTruePos)[label].push_back(make_pair(predBBoxes[i].first, 0));
(*allFalsePos)[label].push_back(
make_pair(predBBoxes[i].first, 1));
allTruePos_[label].push_back(make_pair(predBBoxes[i].first, 0));
allFalsePos_[label].push_back(make_pair(predBBoxes[i].first, 1));
}
} else {
vector<NormalizedBBox> gtBBoxes =
......@@ -189,22 +184,21 @@ protected:
if (evaluateDifficult_ ||
(!evaluateDifficult_ && !gtBBoxes[maxIdx].isDifficult)) {
if (!visited[maxIdx]) {
(*allTruePos)[label].push_back(
allTruePos_[label].push_back(
make_pair(predBBoxes[i].first, 1));
(*allFalsePos)[label].push_back(
allFalsePos_[label].push_back(
make_pair(predBBoxes[i].first, 0));
visited[maxIdx] = true;
} else {
(*allTruePos)[label].push_back(
allTruePos_[label].push_back(
make_pair(predBBoxes[i].first, 0));
(*allFalsePos)[label].push_back(
allFalsePos_[label].push_back(
make_pair(predBBoxes[i].first, 1));
}
}
} else {
(*allTruePos)[label].push_back(
make_pair(predBBoxes[i].first, 0));
(*allFalsePos)[label].push_back(
allTruePos_[label].push_back(make_pair(predBBoxes[i].first, 0));
allFalsePos_[label].push_back(
make_pair(predBBoxes[i].first, 1));
}
}
......@@ -274,7 +268,7 @@ protected:
}
}
if (count != 0) mAP /= count;
return mAP;
return mAP * 100;
}
void getAccumulation(vector<pair<real, size_t>> inPairs,
......@@ -291,20 +285,22 @@ protected:
std::string getTypeImpl() const { return "detection_map"; }
real getValueImpl() const { return calcMAP() * 100; }
real getValueImpl() const { return calcMAP(); }
private:
real overlapThreshold_;
bool evaluateDifficult_;
size_t backgroundId_;
std::string apType_;
real overlapThreshold_; // overlap threshold when determining whether matched
bool evaluateDifficult_; // whether evaluate difficult ground truth
size_t backgroundId_; // class index of background
std::string apType_; // how to calculate mAP (Integral or 11point)
MatrixPtr cpuOutput_;
MatrixPtr cpuLabel_;
map<size_t, size_t> numPos_;
map<size_t, vector<pair<real, size_t>>> allTruePos_;
map<size_t, vector<pair<real, size_t>>> allFalsePos_;
map<size_t, size_t> numPos_; // counts of true objects each classification
map<size_t, vector<pair<real, size_t>>>
allTruePos_; // true positive prediction
map<size_t, vector<pair<real, size_t>>>
allFalsePos_; // false positive prediction
};
REGISTER_EVALUATOR(detection_map, DetectionMAPEvaluator);
......
......@@ -166,9 +166,9 @@ def detection_map_evaluator(input,
ap_type="11point",
name=None):
"""
Detection mAP Evaluator. It will print mean Average Precision for detection.
Detection mAP Evaluator. It will print mean Average Precision (mAP) for detection.
The detection mAP Evaluator according to the detection_output's output count
The detection mAP Evaluator based on the output of detection_output layer counts
the true positive and the false positive bbox and integral them to get the
mAP.
......@@ -186,7 +186,7 @@ def detection_map_evaluator(input,
:type overlap_threshold: float
:param background_id: The background class index.
:type background_id: int
:param evaluate_difficult: Wether evaluate a difficult ground truth.
:param evaluate_difficult: Whether evaluate a difficult ground truth.
:type evaluate_difficult: bool
"""
if not isinstance(input, list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册