/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/box_coder_op.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {

// Encodes every (target box, prior box) pair as center-size offsets.
//
// Launch layout: 1-D. Thread `idx` handles target box `idx / col` and prior
// box `idx % col`, so the grid must supply at least `row * col` threads;
// surplus threads exit through the bounds check. No shared memory is used.
//
// Every box (and every variance entry) is 4 consecutive values. From the
// arithmetic below, elements 0/2 are the x extents and 1/3 the y extents
// (corner form x1, y1, x2, y2). Each pair writes 4 outputs starting at
// `output + idx * 4`:
//   [0] (target_cx - prior_cx) / prior_w / var[0]
//   [1] (target_cy - prior_cy) / prior_h / var[1]
//   [2] log(|target_w / prior_w|) / var[2]
//   [3] log(|target_h / prior_h|) / var[3]
template <typename T>
__global__ void EncodeCenterSizeKernel(const T* prior_box_data,
                                       const T* prior_box_var_data,
                                       const T* target_box_data, int row,
                                       int col, T* output) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= row * col) return;  // tail threads of the last block

  // Base pointers of the 4-value records this thread touches.
  const T* prior = prior_box_data + (idx % col) * 4;
  const T* var = prior_box_var_data + (idx % col) * 4;
  const T* target = target_box_data + (idx / col) * 4;
  T* out = output + idx * 4;

  const T prior_w = prior[2] - prior[0];
  const T prior_h = prior[3] - prior[1];
  const T prior_cx = (prior[2] + prior[0]) / 2;
  const T prior_cy = (prior[3] + prior[1]) / 2;

  const T target_w = target[2] - target[0];
  const T target_h = target[3] - target[1];
  const T target_cx = (target[2] + target[0]) / 2;
  const T target_cy = (target[3] + target[1]) / 2;

  out[0] = (target_cx - prior_cx) / prior_w / var[0];
  out[1] = (target_cy - prior_cy) / prior_h / var[1];
  out[2] = log(fabs(target_w / prior_w)) / var[2];
  out[3] = log(fabs(target_h / prior_h)) / var[3];
}

// Decodes every (offset record, prior box) pair back to corner-form boxes;
// this is the inverse of the center-size encoding.
//
// Launch layout: 1-D. Thread `idx` handles offset record `idx / col` and
// prior box `idx % col`, so the grid must supply at least `row * col`
// threads; surplus threads exit through the bounds check. No shared memory.
//
// Each record is 4 consecutive values. For the prior box, elements 0/2 are
// the x extents and 1/3 the y extents. The decoded box is written as
// (cx - w/2, cy - h/2, cx + w/2, cy + h/2) starting at `output + idx * 4`,
// where
//   w  = exp(var[2] * t[2]) * prior_w
//   h  = exp(var[3] * t[3]) * prior_h
//   cx = var[0] * t[0] * prior_w + prior_cx
//   cy = var[1] * t[1] * prior_h + prior_cy
template <typename T>
__global__ void DecodeCenterSizeKernel(const T* prior_box_data,
                                       const T* prior_box_var_data,
                                       const T* target_box_data, int row,
                                       int col, T* output) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= row * col) return;  // tail threads of the last block

  // Base pointers of the 4-value records this thread touches.
  const T* prior = prior_box_data + (idx % col) * 4;
  const T* var = prior_box_var_data + (idx % col) * 4;
  const T* delta = target_box_data + (idx / col) * 4;
  T* out = output + idx * 4;

  const T prior_w = prior[2] - prior[0];
  const T prior_h = prior[3] - prior[1];
  const T prior_cx = (prior[2] + prior[0]) / 2;
  const T prior_cy = (prior[3] + prior[1]) / 2;

  const T decoded_w = exp(var[2] * delta[2]) * prior_w;
  const T decoded_h = exp(var[3] * delta[3]) * prior_h;
  const T decoded_cx = var[0] * delta[0] * prior_w + prior_cx;
  const T decoded_cy = var[1] * delta[1] * prior_h + prior_cy;

  out[0] = decoded_cx - decoded_w / 2;
  out[1] = decoded_cy - decoded_h / 2;
  out[2] = decoded_cx + decoded_w / 2;
  out[3] = decoded_cy + decoded_h / 2;
}

template <typename T>
class BoxCoderCUDAKernel : public framework::OpKernel<T> {
 public:
  // GPU implementation of the box_coder op.
  //
  // Inputs:
  //   - PriorBox and PriorBoxVar are read as 4 values per prior box.
  //   - TargetBox is read as 4 values per target box or offset record.
  // Output:
  //   - OutputBox has shape {row, col, 4}, with row = TargetBox.dims()[0]
  //     and col = PriorBox.dims()[0].
  // The "code_type" attribute selects EncodeCenterSizeKernel or
  // DecodeCenterSizeKernel. The kernel runs one thread per (target, prior)
  // pair on the context's CUDA stream. The launch is asynchronous, and launch
  // errors are not checked here.
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* prior_box = context.Input<framework::Tensor>("PriorBox");
    auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
    auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
    auto* output_box = context.Output<Tensor>("OutputBox");

    // TargetBox may carry LoD, but only a single level is accepted.
    if (target_box->lod().size()) {
      PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
                        "Only support 1 level of LoD.");
    }
    auto row = target_box->dims()[0];
    auto col = prior_box->dims()[0];

    const T* prior_box_data = prior_box->data<T>();
    const T* prior_box_var_data = prior_box_var->data<T>();
    const T* target_box_data = target_box->data<T>();

    // mutable_data allocates the {row, col, 4} output and returns its pointer.
    T* output = output_box->mutable_data<T>({row, col, 4}, context.GetPlace());

    // Resolve the attribute before any early exit, so whatever validation
    // GetBoxCodeType performs still runs for empty inputs.
    auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));

    // With no (target, prior) pairs the grid would have 0 blocks, which CUDA
    // rejects as an invalid launch configuration. There is nothing to compute.
    auto num_pairs = row * col;
    if (num_pairs == 0) return;

    int block = 512;
    // Ceil-divide in the 64-bit type of the dims, then narrow for the launch.
    int grid = static_cast<int>((num_pairs + block - 1) / block);
    auto& device_ctx = context.cuda_device_context();

    if (code_type == BoxCodeType::kEncodeCenterSize) {
      EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
          prior_box_data, prior_box_var_data, target_box_data,
          static_cast<int>(row), static_cast<int>(col), output);
    } else if (code_type == BoxCodeType::kDecodeCenterSize) {
      DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
          prior_box_data, prior_box_var_data, target_box_data,
          static_cast<int>(row), static_cast<int>(col), output);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register the GPU kernel of the "box_coder" op for float and double inputs.
REGISTER_OP_CUDA_KERNEL(
    box_coder,
    ops::BoxCoderCUDAKernel<float>,
    ops::BoxCoderCUDAKernel<double>);