// collect_fpn_proposals_op.cu
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
  Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
#endif

#include <paddle/fluid/memory/allocation/allocator.h>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

// Shorthand names for the framework tensor types used by the kernel below.
typedef framework::Tensor Tensor;
typedef framework::LoDTensor LoDTensor;

// Launch configuration for the kernels in this file: threads per block and
// an upper bound on the grid size.
static constexpr int kNumCUDAThreads = 64;
static constexpr int kNumMaxinumNumBlocks = 4096;

// Number of values per RoI, i.e. the column count of the RoI tensors.
const int kBBoxSize = 4;

// Returns how many kNumCUDAThreads-wide blocks are needed to cover N
// elements. The result is capped at kNumMaxinumNumBlocks and is never less
// than 1: a zero-block grid is an invalid launch configuration. For N <= 0
// the single block has no work, assuming the launched kernel bounds-checks
// its index.
static inline int NumBlocks(const int N) {
  const int needed = (N + kNumCUDAThreads - 1) / kNumCUDAThreads;
  return std::max(1, std::min(needed, kNumMaxinumNumBlocks));
}

// Histogram of batch ids: for each i in [0, nthreads), atomically increments
// length_lod[batch_ids[i]] by one.
//
// Preconditions:
//   - length_lod is zero-filled by the caller.
//   - length_lod has room for (largest id in batch_ids + 1) ints.
//   - batch_ids holds nthreads device-accessible ints.
// Because the counts are accumulated with atomics, the result does not
// depend on which thread handles which index. NOTE(review): CUDA_KERNEL_LOOP
// is a Paddle macro, presumably a grid-stride loop (NumBlocks caps the grid
// size) — confirm against platform headers.
static __global__ void GetLengthLoD(const int nthreads, const int* batch_ids,
                                    int* length_lod) {
  CUDA_KERNEL_LOOP(i, nthreads) {
    platform::CudaAtomicAdd(length_lod + batch_ids[i], 1);
  }
}

// GPU kernel for collect_fpn_proposals.
//
// Steps:
//   1. Concatenates the RoIs (MultiLevelRois) and scores (MultiLevelScores)
//      of every input level.
//   2. Keeps the min(post_nms_topN, total) highest-scoring RoIs.
//   3. Regroups the kept RoIs by batch (image) id and writes them to FpnRois
//      with a one-level LoD of per-image offsets.
// If RoisNum is requested, it also receives the per-image counts on the
// device. When the optional MultiLevelRoIsNum input is present, it supplies
// the per-image RoI counts of each level instead of the inputs' LoD.
//
// All device work is issued on dev_ctx.stream(), so it is ordered without
// relying on legacy default-stream implicit synchronization.
template <typename DeviceContext, typename T>
class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto roi_ins = ctx.MultiInput<LoDTensor>("MultiLevelRois");
    const auto score_ins = ctx.MultiInput<LoDTensor>("MultiLevelScores");
    auto fpn_rois = ctx.Output<LoDTensor>("FpnRois");
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    const int post_nms_topN = ctx.Attr<int>("post_nms_topN");

    // Total number of RoIs over all levels (rows of the concatenation).
    int total_roi_num = 0;
    for (size_t i = 0; i < roi_ins.size(); ++i) {
      total_roi_num += roi_ins[i]->dims()[0];
    }

    // Never keep more RoIs than exist.
    int real_post_num = std::min(post_nms_topN, total_roi_num);
    fpn_rois->mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
    Tensor concat_rois;
    Tensor concat_scores;
    T* concat_rois_data = concat_rois.mutable_data<T>(
        {total_roi_num, kBBoxSize}, dev_ctx.GetPlace());
    T* concat_scores_data =
        concat_scores.mutable_data<T>({total_roi_num, 1}, dev_ctx.GetPlace());
    // Batch (image) id of every concatenated RoI, built on the host.
    Tensor roi_batch_id_list;
    roi_batch_id_list.Resize({total_roi_num});
    int* roi_batch_id_data =
        roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
    // Next write position in roi_batch_id_data.
    int index = 0;
    // Number of images in the batch, taken from the last level processed.
    // Presumably every level describes the same batch. Initialized so that
    // an empty MultiLevelRois input does not leave it indeterminate.
    int lod_size = 0;
    auto place = BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace());

    // Concatenate the inputs along axis 0 and record each RoI's batch id.
    auto multi_rois_num = ctx.MultiInput<Tensor>("MultiLevelRoIsNum");
    int roi_offset = 0;
    int score_offset = 0;
    for (size_t i = 0; i < roi_ins.size(); ++i) {
      auto roi_in = roi_ins[i];
      auto score_in = score_ins[i];
      if (multi_rois_num.size() > 0) {
        // Per-image counts come from MultiLevelRoIsNum (copied to the host).
        framework::Tensor temp;
        TensorCopySync(*multi_rois_num[i], platform::CPUPlace(), &temp);
        const int* length_in = temp.data<int>();
        lod_size = multi_rois_num[i]->numel();
        for (int n = 0; n < lod_size; ++n) {
          for (int j = 0; j < length_in[n]; ++j) {
            roi_batch_id_data[index++] = n;
          }
        }
      } else {
        // Per-image ranges come from the last LoD level (offsets).
        const auto& length_in = roi_in->lod().back();
        lod_size = length_in.size() - 1;
        for (int n = 0; n < lod_size; ++n) {
          for (size_t j = length_in[n]; j < length_in[n + 1]; ++j) {
            roi_batch_id_data[index++] = n;
          }
        }
      }

      memory::Copy(place, concat_rois_data + roi_offset, place,
                   roi_in->data<T>(), roi_in->numel() * sizeof(T),
                   dev_ctx.stream());
      memory::Copy(place, concat_scores_data + score_offset, place,
                   score_in->data<T>(), score_in->numel() * sizeof(T),
                   dev_ctx.stream());
      roi_offset += roi_in->numel();
      score_offset += score_in->numel();
    }

    // Copy the batch ids to the device for the gathers below.
    Tensor roi_batch_id_list_gpu;
    framework::TensorCopy(roi_batch_id_list, dev_ctx.GetPlace(),
                          &roi_batch_id_list_gpu);

    // Sort payload for the score sort; RangeInitFunctor{0, 1, ...}
    // presumably writes i at position i.
    Tensor index_in_t;
    int* idx_in =
        index_in_t.mutable_data<int>({total_roi_num}, dev_ctx.GetPlace());
    platform::ForRange<platform::CUDADeviceContext> for_range_total(
        dev_ctx, total_roi_num);
    for_range_total(RangeInitFunctor{0, 1, idx_in});

    Tensor keys_out_t;
    T* keys_out =
        keys_out_t.mutable_data<T>({total_roi_num}, dev_ctx.GetPlace());
    Tensor index_out_t;
    int* idx_out =
        index_out_t.mutable_data<int>({total_roi_num}, dev_ctx.GetPlace());

    // Sort all scores in descending order, carrying each RoI's index.
    // The first call (null temp storage) only queries the scratch size.
    // begin_bit/end_bit are the full key width; the stream argument keeps
    // the sort ordered with the rest of this op's work.
    size_t temp_storage_bytes = 0;
#ifdef PADDLE_WITH_HIP
    hipcub::DeviceRadixSort::SortPairsDescending<T, int>(
        nullptr, temp_storage_bytes, concat_scores.data<T>(), keys_out, idx_in,
        idx_out, total_roi_num, 0, sizeof(T) * 8, dev_ctx.stream());
#else
    cub::DeviceRadixSort::SortPairsDescending<T, int>(
        nullptr, temp_storage_bytes, concat_scores.data<T>(), keys_out, idx_in,
        idx_out, total_roi_num, 0, sizeof(T) * 8, dev_ctx.stream());
#endif
    auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
#ifdef PADDLE_WITH_HIP
    hipcub::DeviceRadixSort::SortPairsDescending<T, int>(
        d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data<T>(),
        keys_out, idx_in, idx_out, total_roi_num, 0, sizeof(T) * 8,
        dev_ctx.stream());
#else
    cub::DeviceRadixSort::SortPairsDescending<T, int>(
        d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data<T>(),
        keys_out, idx_in, idx_out, total_roi_num, 0, sizeof(T) * 8,
        dev_ctx.stream());
#endif

    // Keep the top real_post_num indices; gather their boxes and batch ids.
    index_out_t.Resize({real_post_num});
    Tensor sorted_rois;
    sorted_rois.mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
    Tensor sorted_batch_id;
    sorted_batch_id.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
    GPUGather<T>(dev_ctx, concat_rois, index_out_t, &sorted_rois);
    GPUGather<int>(dev_ctx, roi_batch_id_list_gpu, index_out_t,
                   &sorted_batch_id);

    // Sort payload for the batch-id sort.
    Tensor batch_index_t;
    int* batch_idx_in =
        batch_index_t.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
    platform::ForRange<platform::CUDADeviceContext> for_range_post(
        dev_ctx, real_post_num);
    for_range_post(RangeInitFunctor{0, 1, batch_idx_in});

    // Sort the kept RoIs by ascending batch id. out_id_data receives the
    // sorted ids. The permutation is written into index_out_t, whose top-k
    // indices were already consumed by the gathers above. CUB's radix sort
    // is stable, so RoIs of one image stay in descending-score order.
    Tensor out_id_t;
    int* out_id_data =
        out_id_t.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
    temp_storage_bytes = 0;
#ifdef PADDLE_WITH_HIP
    hipcub::DeviceRadixSort::SortPairs<int, int>(
        nullptr, temp_storage_bytes, sorted_batch_id.data<int>(), out_id_data,
        batch_idx_in, index_out_t.data<int>(), real_post_num, 0,
        sizeof(int) * 8, dev_ctx.stream());
#else
    cub::DeviceRadixSort::SortPairs<int, int>(
        nullptr, temp_storage_bytes, sorted_batch_id.data<int>(), out_id_data,
        batch_idx_in, index_out_t.data<int>(), real_post_num, 0,
        sizeof(int) * 8, dev_ctx.stream());
#endif
    d_temp_storage = memory::Alloc(place, temp_storage_bytes);
#ifdef PADDLE_WITH_HIP
    hipcub::DeviceRadixSort::SortPairs<int, int>(
        d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data<int>(),
        out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num, 0,
        sizeof(int) * 8, dev_ctx.stream());
#else
    cub::DeviceRadixSort::SortPairs<int, int>(
        d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data<int>(),
        out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num, 0,
        sizeof(int) * 8, dev_ctx.stream());
#endif

    GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois);

    // Count the kept RoIs per image on the device.
    Tensor length_lod;
    int* length_lod_data =
        length_lod.mutable_data<int>({lod_size}, dev_ctx.GetPlace());
    math::SetConstant<platform::CUDADeviceContext, int> set_zero;
    set_zero(dev_ctx, &length_lod, static_cast<int>(0));

    // Launch on dev_ctx's stream so the count runs after the zero-fill and
    // sort above and before the copy below. Skip the launch when nothing
    // was kept, since a zero-sized grid is an invalid launch configuration.
    if (real_post_num > 0) {
      int blocks = NumBlocks(real_post_num);
      int threads = kNumCUDAThreads;
      GetLengthLoD<<<blocks, threads, 0, dev_ctx.stream()>>>(
          real_post_num, out_id_data, length_lod_data);
    }
    std::vector<int> length_lod_cpu(lod_size);
    memory::Copy(platform::CPUPlace(), length_lod_cpu.data(), place,
                 length_lod_data, sizeof(int) * lod_size, dev_ctx.stream());
    dev_ctx.Wait();

    // Running sum of the per-image counts, starting at 0, gives the LoD
    // offsets of FpnRois.
    std::vector<size_t> offset(1, 0);
    for (int i = 0; i < lod_size; ++i) {
      offset.emplace_back(offset.back() + length_lod_cpu[i]);
    }

    if (ctx.HasOutput("RoisNum")) {
      auto* rois_num = ctx.Output<Tensor>("RoisNum");
      int* rois_num_data = rois_num->mutable_data<int>({lod_size}, place);
      memory::Copy(place, rois_num_data, place, length_lod_data,
                   lod_size * sizeof(int), dev_ctx.stream());
    }

    framework::LoD lod;
    lod.emplace_back(offset);
    fpn_rois->set_lod(lod);
  }
};

}  // namespace operators
}  // namespace paddle

// Register the CUDA implementation of collect_fpn_proposals for float and
// double element types.
namespace ops = paddle::operators;
using CollectFpnProposalsCUDACtx = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(
    collect_fpn_proposals,
    ops::GPUCollectFpnProposalsOpKernel<CollectFpnProposalsCUDACtx, float>,
    ops::GPUCollectFpnProposalsOpKernel<CollectFpnProposalsCUDACtx, double>);