提交 a5103770 编写于 作者: W wenshilei

Fix ssd bugs.

上级 0be34949
...@@ -139,6 +139,7 @@ void DetectionOutputLayer::forward(PassType passType) { ...@@ -139,6 +139,7 @@ void DetectionOutputLayer::forward(PassType passType) {
allDecodedBBoxes, allDecodedBBoxes,
&allIndices); &allIndices);
numKept = numKept > 0 ? numKept : 1;
resetOutput(numKept, 7); resetOutput(numKept, 7);
MatrixPtr outV = getOutputValue(); MatrixPtr outV = getOutputValue();
getDetectionOutput(confBuffer_->getData(), getDetectionOutput(confBuffer_->getData(),
......
...@@ -469,7 +469,7 @@ size_t getDetectionIndices( ...@@ -469,7 +469,7 @@ size_t getDetectionIndices(
const size_t numClasses, const size_t numClasses,
const size_t backgroundId, const size_t backgroundId,
const size_t batchSize, const size_t batchSize,
const size_t confThreshold, const real confThreshold,
const size_t nmsTopK, const size_t nmsTopK,
const real nmsThreshold, const real nmsThreshold,
const size_t keepTopK, const size_t keepTopK,
...@@ -536,6 +536,8 @@ void getDetectionOutput(const real* confData, ...@@ -536,6 +536,8 @@ void getDetectionOutput(const real* confData,
MatrixPtr outBuffer; MatrixPtr outBuffer;
Matrix::resizeOrCreate(outBuffer, numKept, 7, false, false); Matrix::resizeOrCreate(outBuffer, numKept, 7, false, false);
real* bufferData = outBuffer->getData(); real* bufferData = outBuffer->getData();
for (size_t i = 0; i < 7; i++)
bufferData[i] = -1;
size_t count = 0; size_t count = 0;
for (size_t n = 0; n < batchSize; ++n) { for (size_t n = 0; n < batchSize; ++n) {
for (map<size_t, vector<size_t>>::const_iterator it = allIndices[n].begin(); for (map<size_t, vector<size_t>>::const_iterator it = allIndices[n].begin();
......
...@@ -275,7 +275,7 @@ size_t getDetectionIndices( ...@@ -275,7 +275,7 @@ size_t getDetectionIndices(
const size_t numClasses, const size_t numClasses,
const size_t backgroundId, const size_t backgroundId,
const size_t batchSize, const size_t batchSize,
const size_t confThreshold, const real confThreshold,
const size_t nmsTopK, const size_t nmsTopK,
const real nmsThreshold, const real nmsThreshold,
const size_t keepTopK, const size_t keepTopK,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册