/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

B
baiyf 已提交
12
#include "paddle/fluid/operators/detection/box_coder_op.h"
D
dzhwinter 已提交
13
#include "paddle/fluid/platform/cuda_primitives.h"
G
gaoyuan 已提交
14 15 16 17 18 19 20

namespace paddle {
namespace operators {

/**
 * Encodes target boxes relative to prior boxes in center-size form.
 *
 * Work mapping: one thread per (target box, prior box) pair. Thread `idx`
 * handles target row `idx / col` and prior column `idx % col`, and writes
 * four values at `output[idx * len + 0..3]`. Threads with `idx >= row * col`
 * return immediately, so any 1-D launch providing at least `row * col`
 * threads is valid. No shared memory is used.
 *
 * @param prior_box_data      Prior boxes, read as records of `len` values
 *                            indexed by `col_idx`; elements 0/2 are the x
 *                            extents and 1/3 the y extents.
 * @param prior_box_var_data  Optional variances (may be null). When non-null
 *                            each of the four outputs is divided by the
 *                            matching variance value.
 * @param target_box_data     Target boxes, read as records of `len` values
 *                            indexed by `row_idx`, same element layout.
 * @param row                 Number of target boxes.
 * @param col                 Number of prior boxes.
 * @param len                 Stride (in elements) between consecutive boxes.
 * @param normalized          When false, 1 is added to every width/height.
 * @param prior_box_var_size  Rank of the variance tensor. When it equals 2 the
 *                            variance record at `col_idx * len` is used;
 *                            otherwise the first four variance values are
 *                            shared by every pair.
 * @param output              Destination, `row * col` records of `len` values.
 */
template <typename T>
__global__ void EncodeCenterSizeKernel(const T* prior_box_data,
                                       const T* prior_box_var_data,
                                       const T* target_box_data, const int row,
                                       const int col, const int len,
                                       const bool normalized,
                                       const int prior_box_var_size,
                                       T* output) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= row * col) {
    return;
  }

  const int row_idx = idx / col;
  const int col_idx = idx % col;
  const T* prior = prior_box_data + col_idx * len;
  const T* target = target_box_data + row_idx * len;

  // Non-normalized boxes get +1 on each extent.
  const T extent_pad = static_cast<T>(normalized == false);

  const T prior_w = prior[2] - prior[0] + extent_pad;
  const T prior_h = prior[3] - prior[1] + extent_pad;
  const T prior_cx = prior[0] + prior_w / 2;
  const T prior_cy = prior[1] + prior_h / 2;

  const T target_cx = (target[2] + target[0]) / 2;
  const T target_cy = (target[3] + target[1]) / 2;
  const T target_w = target[2] - target[0] + extent_pad;
  const T target_h = target[3] - target[1] + extent_pad;

  // Accumulate in registers and store once instead of doing a
  // read-modify-write on global memory for the variance division.
  T dx = (target_cx - prior_cx) / prior_w;
  T dy = (target_cy - prior_cy) / prior_h;
  T dw = log(fabs(target_w / prior_w));
  T dh = log(fabs(target_h / prior_h));

  if (prior_box_var_data) {
    // Rank-2 variance holds one record per prior box; any other rank means a
    // single record shared by all priors.
    const int var_offset = (prior_box_var_size == 2) ? col_idx * len : 0;
    const T* var = prior_box_var_data + var_offset;
    dx /= var[0];
    dy /= var[1];
    dw /= var[2];
    dh /= var[3];
  }

  T* out = output + idx * len;
  out[0] = dx;
  out[1] = dy;
  out[2] = dw;
  out[3] = dh;
}

/**
 * Decodes center-size box deltas back into corner-form boxes.
 *
 * Work mapping: one thread per output box. Thread `idx` reads the deltas at
 * `target_box_data[idx * len + 0..3]` and writes the decoded corners to
 * `output[idx * len + 0..3]`. Threads with `idx >= row * col` return
 * immediately, so any 1-D launch providing at least `row * col` threads is
 * valid. No shared memory is used.
 *
 * @param prior_box_data      Prior boxes as records of `len` values; elements
 *                            0/2 are the x extents and 1/3 the y extents.
 * @param prior_box_var_data  Optional variances (may be null). When non-null
 *                            each delta is multiplied by the matching variance
 *                            value before decoding.
 * @param target_box_data     Encoded deltas, `row * col` records of `len`.
 * @param row                 First dimension of the delta grid.
 * @param col                 Second dimension of the delta grid.
 * @param len                 Stride (in elements) between consecutive boxes.
 * @param normalized          When false, 1 is added to prior widths/heights
 *                            and subtracted from the decoded max corners.
 * @param prior_box_var_size  Rank of the variance tensor. When it equals 2 the
 *                            variance record is selected like the prior box;
 *                            otherwise the first four variance values are
 *                            shared by every box.
 * @param axis                Selects the prior box for a thread: 0 uses
 *                            `idx % col`, any other value uses `idx / col`.
 * @param output              Destination, `row * col` records of `len` values.
 */
template <typename T>
__global__ void DecodeCenterSizeKernel(const T* prior_box_data,
                                       const T* prior_box_var_data,
                                       const T* target_box_data, const int row,
                                       const int col, const int len,
                                       const bool normalized,
                                       const int prior_box_var_size,
                                       const int axis, T* output) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= row * col) {
    return;
  }

  const int col_idx = idx % col;
  const int row_idx = idx / col;
  // The same record offset addresses both the prior box and, for rank-2
  // variances, its variance record.
  const int prior_offset = (axis == 0) ? col_idx * len : row_idx * len;
  const T* prior = prior_box_data + prior_offset;

  // Non-normalized boxes get +1 on each extent.
  const T extent_pad = static_cast<T>(normalized == false);

  const T prior_w = prior[2] - prior[0] + extent_pad;
  const T prior_h = prior[3] - prior[1] + extent_pad;
  const T prior_cx = prior[0] + prior_w / 2;
  const T prior_cy = prior[1] + prior_h / 2;

  const T* deltas = target_box_data + idx * len;
  T tx = deltas[0];
  T ty = deltas[1];
  T tw = deltas[2];
  T th = deltas[3];

  if (prior_box_var_data) {
    const int var_offset = (prior_box_var_size == 2) ? prior_offset : 0;
    const T* var = prior_box_var_data + var_offset;
    // Scale the deltas first so both paths share one decode formula below.
    tx = var[0] * tx;
    ty = var[1] * ty;
    tw = var[2] * tw;
    th = var[3] * th;
  }

  const T box_w = exp(tw) * prior_w;
  const T box_h = exp(th) * prior_h;
  const T box_cx = tx * prior_w + prior_cx;
  const T box_cy = ty * prior_h + prior_cy;

  T* out = output + idx * len;
  out[0] = box_cx - box_w / 2;
  out[1] = box_cy - box_h / 2;
  out[2] = box_cx + box_w / 2 - extent_pad;
  out[3] = box_cy + box_h / 2 - extent_pad;
}

132
/**
 * GPU implementation of the box coder operator.
 *
 * Reads the "PriorBox", optional "PriorBoxVar" and "TargetBox" inputs plus the
 * "code_type", "box_normalized" and "axis" attributes, allocates "OutputBox"
 * with shape {row, col, len}, and launches the encode or decode kernel on the
 * context's CUDA stream with 512 threads per block.
 */
template <typename DeviceContext, typename T>
class BoxCoderCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* prior_box = context.Input<framework::Tensor>("PriorBox");
    auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
    auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
    auto* output_box = context.Output<framework::Tensor>("OutputBox");

    const T* prior_box_data = prior_box->data<T>();
    const T* target_box_data = target_box->data<T>();

    // PriorBoxVar is optional: a null pointer tells the kernels to skip the
    // variance scaling, and the rank selects per-prior vs shared variances.
    const T* prior_box_var_data = nullptr;
    int prior_box_var_size = 0;
    if (prior_box_var) {
      prior_box_var_data = prior_box_var->data<T>();
      prior_box_var_size = prior_box_var->dims().size();
    }

    if (target_box->lod().size()) {
      PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
                        "Only support 1 level of LoD.");
    }
    auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
    bool normalized = context.Attr<bool>("box_normalized");
    int axis = context.Attr<int>("axis");

    // Encoding pairs every target box with every prior box; decoding takes
    // both grid dimensions from the target tensor.
    auto row = target_box->dims()[0];
    auto col = prior_box->dims()[0];
    if (code_type == BoxCodeType::kDecodeCenterSize) {
      col = target_box->dims()[1];
    }
    auto len = prior_box->dims()[1];

    output_box->mutable_data<T>({row, col, len}, context.GetPlace());
    T* output = output_box->data<T>();

    // An empty problem would yield a grid of zero blocks, which is an invalid
    // launch configuration; the (empty) output is already allocated.
    if (row * col == 0) {
      return;
    }

    int block = 512;
    int grid = (row * col + block - 1) / block;  // ceil-div
    auto& device_ctx = context.cuda_device_context();

    if (code_type == BoxCodeType::kEncodeCenterSize) {
      EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
          prior_box_data, prior_box_var_data, target_box_data, row, col, len,
          normalized, prior_box_var_size, output);
    } else if (code_type == BoxCodeType::kDecodeCenterSize) {
      DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
          prior_box_data, prior_box_var_data, target_box_data, row, col, len,
          normalized, prior_box_var_size, axis, output);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
189 190 191 192
// Register BoxCoderCUDAKernel as the CUDA kernel for the `box_coder` op,
// instantiated for float and double.
REGISTER_OP_CUDA_KERNEL(
    box_coder,
    ops::BoxCoderCUDAKernel<paddle::platform::CUDADeviceContext, float>,
    ops::BoxCoderCUDAKernel<paddle::platform::CUDADeviceContext, double>);