// DetectionOutputLayer.cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "DetectionOutputLayer.h"

namespace paddle {

// Register DetectionOutputLayer with the layer registry under the type name
// "detection_output".
REGISTER_LAYER(detection_output, DetectionOutputLayer);

/**
 * Initializes the layer and caches its detection-output hyper-parameters.
 *
 * All settings (class count, number of input branches, NMS / confidence
 * thresholds, top-k limits and background class id) are read from the
 * detection_output_conf attached to the first input's config.
 *
 * @param layerMap      Map of all layers, forwarded to Layer::init.
 * @param parameterMap  Map of all parameters, forwarded to Layer::init.
 * @return false if the base-class initialization fails, true otherwise.
 */
bool DetectionOutputLayer::init(const LayerMap& layerMap,
                                const ParameterMap& parameterMap) {
  // Do not silently report success when the base Layer failed to initialize.
  if (!Layer::init(layerMap, parameterMap)) return false;

  const auto& layerConf = config_.inputs(0).detection_output_conf();
  numClasses_ = layerConf.num_classes();
  inputNum_ = layerConf.input_num();
  nmsThreshold_ = layerConf.nms_threshold();
  confidenceThreshold_ = layerConf.confidence_threshold();
  nmsTopK_ = layerConf.nms_top_k();
  keepTopK_ = layerConf.keep_top_k();
  backgroundId_ = layerConf.background_id();
  return true;
}

/**
 * Forward pass: decodes location predictions against the prior boxes,
 * selects detections, and writes them to the output matrix.
 *
 * Steps:
 *  1. Gather the location and confidence inputs of all `inputNum_` branches
 *     into two contiguous buffers, permuting each input from NCHW to NHWC.
 *  2. The code below reads the buffers through raw getData() pointers, so on
 *     GPU the loc/conf/prior data is first copied into host (useGpu=false)
 *     matrices.
 *  3. Apply softmax to the confidence buffer, whose rows hold `numClasses_`
 *     scores each.
 *  4. Decode every prior box (with its variance) and location prediction
 *     into a NormalizedBBox for each sample in the batch.
 *  5. Pick the kept detections with getDetectionIndices and write them with
 *     getDetectionOutput into a (numKept x 7) output matrix.
 *
 * If no detection is kept, the output value is resized to 0x0 so that the
 * previous batch's detections are not reported again.
 *
 * @param passType  Forwarded to Layer::forward.
 */
void DetectionOutputLayer::forward(PassType passType) {
  Layer::forward(passType);
  const size_t batchSize = getInputValue(*getLocInputLayer(0))->getHeight();

  // Total element counts over all input branches, used to size the buffers.
  locSizeSum_ = 0;
  confSizeSum_ = 0;
  for (size_t n = 0; n < inputNum_; ++n) {
    const MatrixPtr inLoc = getInputValue(*getLocInputLayer(n));
    const MatrixPtr inConf = getInputValue(*getConfInputLayer(n));
    locSizeSum_ += inLoc->getElementCnt();
    confSizeSum_ += inConf->getElementCnt();
  }

  Matrix::resizeOrCreate(locTmpBuffer_, 1, locSizeSum_, false, useGpu_);
  Matrix::resizeOrCreate(
      confTmpBuffer_, confSizeSum_ / numClasses_, numClasses_, false, useGpu_);

  size_t locOffset = 0;
  size_t confOffset = 0;
  const auto& layerConf = config_.inputs(0).detection_output_conf();
  for (size_t n = 0; n < inputNum_; ++n) {
    const MatrixPtr inLoc = getInputValue(*getLocInputLayer(n));
    const MatrixPtr inConf = getInputValue(*getConfInputLayer(n));

    // Fall back to the configured feature-map size when the input does not
    // carry frame dimensions.
    size_t height = getInput(*getLocInputLayer(n)).getFrameHeight();
    if (!height) height = layerConf.height();
    size_t width = getInput(*getLocInputLayer(n)).getFrameWidth();
    if (!width) width = layerConf.width();

    locOffset += appendWithPermute(*inLoc,
                                   height,
                                   width,
                                   locSizeSum_,
                                   locOffset,
                                   batchSize,
                                   *locTmpBuffer_,
                                   kNCHWToNHWC);
    confOffset += appendWithPermute(*inConf,
                                    height,
                                    width,
                                    confSizeSum_,
                                    confOffset,
                                    batchSize,
                                    *confTmpBuffer_,
                                    kNCHWToNHWC);
  }
  CHECK_EQ(locOffset, locSizeSum_ / batchSize);
  CHECK_EQ(confOffset, confSizeSum_ / batchSize);

  MatrixPtr priorValue;
  if (useGpu_) {
    Matrix::resizeOrCreate(locCpuBuffer_, 1, locSizeSum_, false, false);
    Matrix::resizeOrCreate(
        confCpuBuffer_, confSizeSum_ / numClasses_, numClasses_, false, false);
    MatrixPtr priorTmpValue = getInputValue(*getPriorBoxLayer());
    Matrix::resizeOrCreate(
        priorCpuValue_, 1, priorTmpValue->getElementCnt(), false, false);

    locCpuBuffer_->copyFrom(*locTmpBuffer_);
    confCpuBuffer_->copyFrom(*confTmpBuffer_);
    priorCpuValue_->copyFrom(*priorTmpValue);

    locBuffer_ = locCpuBuffer_;
    confBuffer_ = confCpuBuffer_;
    priorValue = priorCpuValue_;
  } else {
    priorValue = getInputValue(*getPriorBoxLayer());
    locBuffer_ = locTmpBuffer_;
    confBuffer_ = confTmpBuffer_;
  }
  confBuffer_->softmax(*confBuffer_);

  // Each prior occupies 8 values in the prior-box input. Both the box and its
  // variance are read starting at the same offset (presumably 4 coordinates
  // followed by 4 variances; see getBBox*FromPriorData).
  const size_t numPriors = priorValue->getElementCnt() / 8;

  // Prior boxes and variances are identical for every sample in the batch,
  // so decode them once here rather than once per sample.
  std::vector<NormalizedBBox> priorBBoxes;
  std::vector<std::vector<real>> priorBBoxVars;
  priorBBoxes.reserve(numPriors);
  priorBBoxVars.reserve(numPriors);
  for (size_t i = 0; i < numPriors; ++i) {
    const size_t priorOffset = i * 8;
    std::vector<NormalizedBBox> priorBBoxVec;
    getBBoxFromPriorData(priorValue->getData() + priorOffset, 1, priorBBoxVec);
    std::vector<std::vector<real>> priorBBoxVar;
    getBBoxVarFromPriorData(
        priorValue->getData() + priorOffset, 1, priorBBoxVar);
    priorBBoxes.push_back(priorBBoxVec[0]);
    priorBBoxVars.push_back(priorBBoxVar[0]);
  }

  // Decode the 4 location predictions of every (sample, prior) pair. The
  // location buffer is laid out as [sample][prior][4].
  const real* locData = locBuffer_->getData();
  std::vector<std::vector<NormalizedBBox>> allDecodedBBoxes(batchSize);
  for (size_t n = 0; n < batchSize; ++n) {
    std::vector<NormalizedBBox>& decodedBBoxes = allDecodedBBoxes[n];
    decodedBBoxes.reserve(numPriors);
    for (size_t i = 0; i < numPriors; ++i) {
      const size_t locPredOffset = (n * numPriors + i) * 4;
      const std::vector<real> locPredData(locData + locPredOffset,
                                          locData + locPredOffset + 4);
      decodedBBoxes.push_back(
          decodeBBoxWithVar(priorBBoxes[i], priorBBoxVars[i], locPredData));
    }
  }

  std::vector<std::map<size_t, std::vector<size_t>>> allIndices;
  const size_t numKept = getDetectionIndices(confBuffer_->getData(),
                                             numPriors,
                                             numClasses_,
                                             backgroundId_,
                                             batchSize,
                                             confidenceThreshold_,
                                             nmsTopK_,
                                             nmsThreshold_,
                                             keepTopK_,
                                             allDecodedBBoxes,
                                             &allIndices);

  if (numKept == 0) {
    // Shrink the existing output in place. Reassigning a local MatrixPtr copy
    // would leave the layer's output (and the previous batch's detections)
    // untouched.
    MatrixPtr outV = getOutputValue();
    if (outV) outV->resize(0, 0);
    return;
  }

  resetOutput(numKept, 7);
  MatrixPtr outV = getOutputValue();
  getDetectionOutput(confBuffer_->getData(),
                     numKept,
                     numPriors,
                     numClasses_,
                     batchSize,
                     allIndices,
                     allDecodedBBoxes,
                     *outV);
}

}  // namespace paddle