提交 5f924d5d 编写于 作者: Y yangyaming

Follow comments.

上级 22076592
...@@ -99,3 +99,12 @@ value_printer ...@@ -99,3 +99,12 @@ value_printer
.. automodule:: paddle.v2.evaluator .. automodule:: paddle.v2.evaluator
:members: value_printer :members: value_printer
:noindex: :noindex:
Detection
=====
detection_map
-------------
.. automodule:: paddle.v2.evaluator
:members: detection_map
:noindex:
...@@ -80,21 +80,20 @@ public: ...@@ -80,21 +80,20 @@ public:
allGTBBoxes.push_back(bboxes); allGTBBoxes.push_back(bboxes);
} }
size_t imgId = 0; size_t n = 0;
for (size_t n = 0; n < cpuOutput_->getHeight();) { const real* cpuOutputData = cpuOutput_->getData();
for (size_t imgId = 0; imgId < batchSize; ++imgId) {
map<size_t, vector<pair<real, NormalizedBBox>>> bboxes; map<size_t, vector<pair<real, NormalizedBBox>>> bboxes;
while (cpuOutput_->getData()[n * 7] == imgId && size_t curImgId = static_cast<size_t>((cpuOutputData + n * 7)[0]);
n < cpuOutput_->getHeight()) { while (curImgId == imgId && n < cpuOutput_->getHeight()) {
vector<real> label; vector<real> label;
vector<real> score; vector<real> score;
vector<NormalizedBBox> bbox; vector<NormalizedBBox> bbox;
getBBoxFromDetectData( getBBoxFromDetectData(cpuOutputData + n * 7, 1, label, score, bbox);
cpuOutput_->getData() + n * 7, 1, label, score, bbox);
bboxes[label[0]].push_back(make_pair(score[0], bbox[0])); bboxes[label[0]].push_back(make_pair(score[0], bbox[0]));
++n; ++n;
curImgId = static_cast<size_t>((cpuOutputData + n * 7)[0]);
} }
++imgId;
if (imgId > batchSize) break;
allDetectBBoxes.push_back(bboxes); allDetectBBoxes.push_back(bboxes);
} }
...@@ -119,15 +118,14 @@ public: ...@@ -119,15 +118,14 @@ public:
} }
// calcTFPos // calcTFPos
calcTFPos( calcTFPos(batchSize, allGTBBoxes, allDetectBBoxes);
batchSize, allGTBBoxes, allDetectBBoxes, &allTruePos_, &allFalsePos_);
return 0; return 0;
} }
virtual void printStats(std::ostream& os) const { virtual void printStats(std::ostream& os) const {
real mAP = calcMAP(); real mAP = calcMAP();
os << "Detection mAP=" << mAP * 100; os << "Detection mAP=" << mAP;
} }
virtual void distributeEval(ParameterClient2* client) { virtual void distributeEval(ParameterClient2* client) {
...@@ -138,9 +136,7 @@ protected: ...@@ -138,9 +136,7 @@ protected:
void calcTFPos(const size_t batchSize, void calcTFPos(const size_t batchSize,
const vector<map<size_t, vector<NormalizedBBox>>>& allGTBBoxes, const vector<map<size_t, vector<NormalizedBBox>>>& allGTBBoxes,
const vector<map<size_t, vector<pair<real, NormalizedBBox>>>>& const vector<map<size_t, vector<pair<real, NormalizedBBox>>>>&
allDetectBBoxes, allDetectBBoxes) {
map<size_t, vector<pair<real, size_t>>>* allTruePos,
map<size_t, vector<pair<real, size_t>>>* allFalsePos) {
for (size_t n = 0; n < allDetectBBoxes.size(); ++n) { for (size_t n = 0; n < allDetectBBoxes.size(); ++n) {
if (allGTBBoxes[n].size() == 0) { if (allGTBBoxes[n].size() == 0) {
for (map<size_t, vector<pair<real, NormalizedBBox>>>::const_iterator for (map<size_t, vector<pair<real, NormalizedBBox>>>::const_iterator
...@@ -149,8 +145,8 @@ protected: ...@@ -149,8 +145,8 @@ protected:
++it) { ++it) {
size_t label = it->first; size_t label = it->first;
for (size_t i = 0; i < it->second.size(); ++i) { for (size_t i = 0; i < it->second.size(); ++i) {
(*allTruePos)[label].push_back(make_pair(it->second[i].first, 0)); allTruePos_[label].push_back(make_pair(it->second[i].first, 0));
(*allFalsePos)[label].push_back(make_pair(it->second[i].first, 1)); allFalsePos_[label].push_back(make_pair(it->second[i].first, 1));
} }
} }
} else { } else {
...@@ -162,9 +158,8 @@ protected: ...@@ -162,9 +158,8 @@ protected:
vector<pair<real, NormalizedBBox>> predBBoxes = it->second; vector<pair<real, NormalizedBBox>> predBBoxes = it->second;
if (allGTBBoxes[n].find(label) == allGTBBoxes[n].end()) { if (allGTBBoxes[n].find(label) == allGTBBoxes[n].end()) {
for (size_t i = 0; i < predBBoxes.size(); ++i) { for (size_t i = 0; i < predBBoxes.size(); ++i) {
(*allTruePos)[label].push_back(make_pair(predBBoxes[i].first, 0)); allTruePos_[label].push_back(make_pair(predBBoxes[i].first, 0));
(*allFalsePos)[label].push_back( allFalsePos_[label].push_back(make_pair(predBBoxes[i].first, 1));
make_pair(predBBoxes[i].first, 1));
} }
} else { } else {
vector<NormalizedBBox> gtBBoxes = vector<NormalizedBBox> gtBBoxes =
...@@ -189,22 +184,21 @@ protected: ...@@ -189,22 +184,21 @@ protected:
if (evaluateDifficult_ || if (evaluateDifficult_ ||
(!evaluateDifficult_ && !gtBBoxes[maxIdx].isDifficult)) { (!evaluateDifficult_ && !gtBBoxes[maxIdx].isDifficult)) {
if (!visited[maxIdx]) { if (!visited[maxIdx]) {
(*allTruePos)[label].push_back( allTruePos_[label].push_back(
make_pair(predBBoxes[i].first, 1)); make_pair(predBBoxes[i].first, 1));
(*allFalsePos)[label].push_back( allFalsePos_[label].push_back(
make_pair(predBBoxes[i].first, 0)); make_pair(predBBoxes[i].first, 0));
visited[maxIdx] = true; visited[maxIdx] = true;
} else { } else {
(*allTruePos)[label].push_back( allTruePos_[label].push_back(
make_pair(predBBoxes[i].first, 0)); make_pair(predBBoxes[i].first, 0));
(*allFalsePos)[label].push_back( allFalsePos_[label].push_back(
make_pair(predBBoxes[i].first, 1)); make_pair(predBBoxes[i].first, 1));
} }
} }
} else { } else {
(*allTruePos)[label].push_back( allTruePos_[label].push_back(make_pair(predBBoxes[i].first, 0));
make_pair(predBBoxes[i].first, 0)); allFalsePos_[label].push_back(
(*allFalsePos)[label].push_back(
make_pair(predBBoxes[i].first, 1)); make_pair(predBBoxes[i].first, 1));
} }
} }
...@@ -274,7 +268,7 @@ protected: ...@@ -274,7 +268,7 @@ protected:
} }
} }
if (count != 0) mAP /= count; if (count != 0) mAP /= count;
return mAP; return mAP * 100;
} }
void getAccumulation(vector<pair<real, size_t>> inPairs, void getAccumulation(vector<pair<real, size_t>> inPairs,
...@@ -291,20 +285,22 @@ protected: ...@@ -291,20 +285,22 @@ protected:
std::string getTypeImpl() const { return "detection_map"; } std::string getTypeImpl() const { return "detection_map"; }
real getValueImpl() const { return calcMAP() * 100; } real getValueImpl() const { return calcMAP(); }
private: private:
real overlapThreshold_; real overlapThreshold_; // overlap threshold when determining whether matched
bool evaluateDifficult_; bool evaluateDifficult_; // whether evaluate difficult ground truth
size_t backgroundId_; size_t backgroundId_; // class index of background
std::string apType_; std::string apType_; // how to calculate mAP (Integral or 11point)
MatrixPtr cpuOutput_; MatrixPtr cpuOutput_;
MatrixPtr cpuLabel_; MatrixPtr cpuLabel_;
map<size_t, size_t> numPos_; map<size_t, size_t> numPos_; // counts of true objects each classification
map<size_t, vector<pair<real, size_t>>> allTruePos_; map<size_t, vector<pair<real, size_t>>>
map<size_t, vector<pair<real, size_t>>> allFalsePos_; allTruePos_; // true positive prediction
map<size_t, vector<pair<real, size_t>>>
allFalsePos_; // false positive prediction
}; };
REGISTER_EVALUATOR(detection_map, DetectionMAPEvaluator); REGISTER_EVALUATOR(detection_map, DetectionMAPEvaluator);
......
...@@ -166,9 +166,9 @@ def detection_map_evaluator(input, ...@@ -166,9 +166,9 @@ def detection_map_evaluator(input,
ap_type="11point", ap_type="11point",
name=None): name=None):
""" """
Detection mAP Evaluator. It will print mean Average Precision for detection. Detection mAP Evaluator. It will print mean Average Precision (mAP) for detection.
The detection mAP Evaluator according to the detection_output's output count The detection mAP Evaluator based on the output of detection_output layer counts
the true positive and the false positive bbox and integral them to get the the true positive and the false positive bbox and integral them to get the
mAP. mAP.
...@@ -186,7 +186,7 @@ def detection_map_evaluator(input, ...@@ -186,7 +186,7 @@ def detection_map_evaluator(input,
:type overlap_threshold: float :type overlap_threshold: float
:param background_id: The background class index. :param background_id: The background class index.
:type background_id: int :type background_id: int
:param evaluate_difficult: Wether evaluate a difficult ground truth. :param evaluate_difficult: Whether evaluate a difficult ground truth.
:type evaluate_difficult: bool :type evaluate_difficult: bool
""" """
if not isinstance(input, list): if not isinstance(input, list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册