From a5103770240ceebb5e98f2d99c97f9a6038818b9 Mon Sep 17 00:00:00 2001 From: wenshilei Date: Sun, 10 Sep 2017 02:37:34 +0800 Subject: [PATCH] Fix ssd bugs. --- paddle/gserver/layers/DetectionOutputLayer.cpp | 1 + paddle/gserver/layers/DetectionUtil.cpp | 4 +++- paddle/gserver/layers/DetectionUtil.h | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/paddle/gserver/layers/DetectionOutputLayer.cpp b/paddle/gserver/layers/DetectionOutputLayer.cpp index 8ab838e1913..a1036ea866b 100644 --- a/paddle/gserver/layers/DetectionOutputLayer.cpp +++ b/paddle/gserver/layers/DetectionOutputLayer.cpp @@ -139,6 +139,7 @@ void DetectionOutputLayer::forward(PassType passType) { allDecodedBBoxes, &allIndices); + numKept = numKept > 0 ? numKept : 1; resetOutput(numKept, 7); MatrixPtr outV = getOutputValue(); getDetectionOutput(confBuffer_->getData(), diff --git a/paddle/gserver/layers/DetectionUtil.cpp b/paddle/gserver/layers/DetectionUtil.cpp index 3e61adc66e6..92c61930353 100644 --- a/paddle/gserver/layers/DetectionUtil.cpp +++ b/paddle/gserver/layers/DetectionUtil.cpp @@ -469,7 +469,7 @@ size_t getDetectionIndices( const size_t numClasses, const size_t backgroundId, const size_t batchSize, - const size_t confThreshold, + const real confThreshold, const size_t nmsTopK, const real nmsThreshold, const size_t keepTopK, @@ -536,6 +536,8 @@ void getDetectionOutput(const real* confData, MatrixPtr outBuffer; Matrix::resizeOrCreate(outBuffer, numKept, 7, false, false); real* bufferData = outBuffer->getData(); + for (size_t i = 0; i < 7; i++) + bufferData[i] = -1; size_t count = 0; for (size_t n = 0; n < batchSize; ++n) { for (map>::const_iterator it = allIndices[n].begin(); diff --git a/paddle/gserver/layers/DetectionUtil.h b/paddle/gserver/layers/DetectionUtil.h index fe4f9f075e4..641ed873b4c 100644 --- a/paddle/gserver/layers/DetectionUtil.h +++ b/paddle/gserver/layers/DetectionUtil.h @@ -275,7 +275,7 @@ size_t getDetectionIndices( const size_t numClasses, const size_t backgroundId, const size_t batchSize, - const size_t confThreshold, + const real confThreshold, const size_t nmsTopK, const real nmsThreshold, const size_t keepTopK, -- GitLab