/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

// Center-size encoding of every (target box, prior box) pair.
//
// Launch layout: a 1-D grid with one thread per output pair. Thread `idx`
// pairs target row `idx / col` with prior column `idx % col`. Threads with
// idx >= row * col exit immediately, so the grid may be rounded up.
//
// Every box occupies `len` consecutive values. The first four are read as
// corners [x1, y1, x2, y2] (width = [2] - [0], height = [3] - [1]).
//   prior_box_data     -- `col` prior boxes.
//   target_box_data    -- `row` target boxes.
//   prior_box_var_data -- optional variances. When prior_box_var_size == 2,
//                         row `idx % col` (stride `len`) is used; otherwise
//                         the first four values are shared by all priors.
//   variance           -- attribute variances. Used only when
//                         prior_box_var_data is null and var_size == 4.
//   output             -- row * col slots of `len` values. Only the first
//                         four values of each slot are written.
// When `normalized` is false, a 1 is added to every width and height.
template <typename T>
__global__ void EncodeCenterSizeKernel(
    const T* prior_box_data, const T* prior_box_var_data,
    const T* target_box_data, const int row, const int col, const int len,
    const bool normalized, const T prior_box_var_size, const float* variance,
    const int var_size, T* output) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= row * col) return;

  const int prior_base = (idx % col) * len;
  const T* prior = prior_box_data + prior_base;
  const T* target = target_box_data + (idx / col) * len;
  // Extent padding: 1 for un-normalized coordinates, 0 otherwise.
  const T extent_pad = static_cast<T>(!normalized);

  const T prior_w = prior[2] - prior[0] + extent_pad;
  const T prior_h = prior[3] - prior[1] + extent_pad;
  // NOTE(review): the prior center is offset by the padded half-extent while
  // the target center below is the plain corner midpoint. This asymmetry is
  // preserved deliberately (presumably it matches the CPU kernel -- confirm).
  const T prior_cx = prior[0] + prior_w / 2;
  const T prior_cy = prior[1] + prior_h / 2;

  const T target_cx = (target[2] + target[0]) / 2;
  const T target_cy = (target[3] + target[1]) / 2;
  const T target_w = target[2] - target[0] + extent_pad;
  const T target_h = target[3] - target[1] + extent_pad;

  T code[4] = {(target_cx - prior_cx) / prior_w,
               (target_cy - prior_cy) / prior_h,
               log(fabs(target_w / prior_w)),
               log(fabs(target_h / prior_h))};

  if (prior_box_var_data) {
    const T* var =
        prior_box_var_data + (prior_box_var_size == 2 ? prior_base : 0);
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      code[k] /= var[k];
    }
  } else if (var_size == 4) {
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      code[k] /= static_cast<T>(variance[k]);
    }
  }

  T* out = output + idx * len;
#pragma unroll
  for (int k = 0; k < 4; ++k) {
    out[k] = code[k];
  }
}

// Center-size decoding: turns encoded offsets back into corner boxes.
//
// Launch layout: a 1-D grid with one thread per output box. Threads with
// idx >= row * col exit immediately, so the grid may be rounded up.
//
// Every box occupies `len` consecutive values. The first four are treated
// as [x1, y1, x2, y2] for prior and output boxes and as
// [dx, dy, dw, dh] for the encoded target at slot `idx`.
//   axis               -- 0: prior box indexed by `idx % col`;
//                         otherwise: prior box indexed by `idx / col`.
//   prior_box_var_data -- optional variances. When prior_box_var_size == 2,
//                         the row at the same offset as the prior box is
//                         used; otherwise the first four values are shared.
//   variance           -- attribute variances. Used only when
//                         prior_box_var_data is null and var_size == 4.
//   output             -- row * col slots of `len` values. Only the first
//                         four values of each slot are written.
// When `normalized` is false, a 1 is added to prior widths and heights and
// subtracted from the decoded x2 / y2.
template <typename T>
__global__ void DecodeCenterSizeKernel(
    const T* prior_box_data, const T* prior_box_var_data,
    const T* target_box_data, const int row, const int col, const int len,
    const bool normalized, const T prior_box_var_size, const float* variance,
    const int var_size, const int axis, T* output) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= row * col) return;

  const int prior_base = (axis == 0) ? (idx % col) * len : (idx / col) * len;
  const T* prior = prior_box_data + prior_base;
  const T* delta = target_box_data + idx * len;
  // Extent padding: 1 for un-normalized coordinates, 0 otherwise.
  const T extent_pad = static_cast<T>(!normalized);

  const T prior_w = prior[2] - prior[0] + extent_pad;
  const T prior_h = prior[3] - prior[1] + extent_pad;
  const T prior_cx = prior[0] + prior_w / 2;
  const T prior_cy = prior[1] + prior_h / 2;

  // Gather the variance that scales each delta, if any source provides one.
  T scale[4];
  bool scaled = true;
  if (prior_box_var_data) {
    const T* var =
        prior_box_var_data + (prior_box_var_size == 2 ? prior_base : 0);
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      scale[k] = var[k];
    }
  } else if (var_size == 4) {
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      scale[k] = static_cast<T>(variance[k]);
    }
  } else {
    scaled = false;
  }

  T box_w, box_h, box_cx, box_cy;
  if (scaled) {
    box_w = exp(scale[2] * delta[2]) * prior_w;
    box_h = exp(scale[3] * delta[3]) * prior_h;
    box_cx = scale[0] * delta[0] * prior_w + prior_cx;
    box_cy = scale[1] * delta[1] * prior_h + prior_cy;
  } else {
    box_w = exp(delta[2]) * prior_w;
    box_h = exp(delta[3]) * prior_h;
    box_cx = delta[0] * prior_w + prior_cx;
    box_cy = delta[1] * prior_h + prior_cy;
  }

  const T half_w = box_w / 2;
  const T half_h = box_h / 2;
  T* out = output + idx * len;
  out[0] = box_cx - half_w;
  out[1] = box_cy - half_h;
  out[2] = box_cx + half_w - extent_pad;
  out[3] = box_cy + half_h - extent_pad;
}

// GPU implementation of the box_coder operator.
//
// Inputs:     "PriorBox", optional "PriorBoxVar", "TargetBox" (LoD level <= 1).
// Attributes: "variance" (empty or 4 values; mutually exclusive with
//             PriorBoxVar), "code_type", "box_normalized", "axis".
// Output:     "OutputBox" with shape [row, col, len], where
//             row = TargetBox.dims()[0], len = PriorBox.dims()[1], and
//             col = PriorBox.dims()[0] when encoding or TargetBox.dims()[1]
//             when decoding.
// One CUDA thread computes one output box on the context's stream.
template <typename DeviceContext, typename T>
class BoxCoderCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* prior_box = context.Input<framework::Tensor>("PriorBox");
    auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
    auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
    auto* output_box = context.Output<framework::Tensor>("OutputBox");
    std::vector<float> variance = context.Attr<std::vector<float>>("variance");
    const T* prior_box_data = prior_box->data<T>();
    const T* target_box_data = target_box->data<T>();
    const T* prior_box_var_data = nullptr;
    // Rank of PriorBoxVar (0 when absent). The kernels use one variance row
    // per prior box for rank 2, and share the first four values otherwise.
    int prior_box_var_size = 0;
    if (prior_box_var) {
      PADDLE_ENFORCE(variance.empty(),
                     "Input 'PriorBoxVar' and attribute 'variance' should not "
                     "be used at the same time.");
      prior_box_var_data = prior_box_var->data<T>();
      prior_box_var_size = prior_box_var->dims().size();
    }
    if (!variance.empty()) {
      PADDLE_ENFORCE(static_cast<int>(variance.size()) == 4,
                     "Size of attribute 'variance' should be 4");
    }
    if (target_box->lod().size()) {
      PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
                        "Only support 1 level of LoD.");
    }

    auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
    bool normalized = context.Attr<bool>("box_normalized");
    int axis = context.Attr<int>("axis");

    auto row = target_box->dims()[0];
    auto col = prior_box->dims()[0];
    if (code_type == BoxCodeType::kDecodeCenterSize) {
      col = target_box->dims()[1];
    }
    auto len = prior_box->dims()[1];

    output_box->mutable_data<T>({row, col, len}, context.GetPlace());
    // An empty batch has nothing to compute. Launching with a zero-sized grid
    // would also be an invalid CUDA launch configuration.
    if (row * col == 0) return;
    T* output = output_box->data<T>();

    // Number of attribute variances: 0 or 4, as enforced above.
    const int var_size = static_cast<int>(variance.size());
    // Copy the attribute variances into device memory for the kernels.
    // NOTE(review): dev_variance is freed when Compute returns, while the
    // kernel below runs asynchronously on device_ctx.stream(). This presumably
    // relies on cudaFree implicitly synchronizing the device -- confirm.
    thrust::device_vector<float> dev_variance(variance.begin(), variance.end());
    const float* dev_var_data = thrust::raw_pointer_cast(dev_variance.data());

    // NOTE(review): the kernels index with 32-bit ints, so this assumes
    // row * col * len fits in an int.
    const int block = 512;
    const int grid = (row * col + block - 1) / block;
    auto& device_ctx = context.cuda_device_context();

    if (code_type == BoxCodeType::kEncodeCenterSize) {
      EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
          prior_box_data, prior_box_var_data, target_box_data, row, col, len,
          normalized, prior_box_var_size, dev_var_data, var_size, output);
    } else if (code_type == BoxCodeType::kDecodeCenterSize) {
      DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
          prior_box_data, prior_box_var_data, target_box_data, row, col, len,
          normalized, prior_box_var_size, dev_var_data, var_size, axis, output);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register the CUDA box_coder kernel for float and double element types.
REGISTER_OP_CUDA_KERNEL(
    box_coder,
    ops::BoxCoderCUDAKernel<paddle::platform::CUDADeviceContext, float>,
    ops::BoxCoderCUDAKernel<paddle::platform::CUDADeviceContext, double>);