// box_coder_op.cu
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

// Encodes every (target box, prior box) pair as center-size offsets.
//
// Launch layout: 1-D grid of 1-D blocks providing at least row * col threads.
// Flat thread index `idx` pairs target box `idx / col` with prior box
// `idx % col`; threads with idx >= row * col return without writing.
//
//   prior_box_data / prior_box_var_data : `col` records of `len` values. The
//     first four values of a prior are used as [xmin, ymin, xmax, ymax]
//     (width = [2] - [0], height = [3] - [1]) and its four variances.
//   target_box_data : `row` records of `len` values, same corner layout.
//   output          : row * col records of `len` values; offsets 0..3 receive
//     the center deltas and log size ratios, each divided by its variance.
// When `normalized` is false, 1 is added to every width and height.
template <typename T>
__global__ void EncodeCenterSizeKernel(const T* prior_box_data,
                                       const T* prior_box_var_data,
                                       const T* target_box_data, const int row,
                                       const int col, const int len,
                                       const bool normalized, T* output) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= row * col) return;

  // Extra extent added to widths/heights of un-normalized boxes.
  const T extent_offset = static_cast<T>(normalized ? 0 : 1);

  const T* prior = prior_box_data + (idx % col) * len;
  const T* prior_var = prior_box_var_data + (idx % col) * len;
  const T* target = target_box_data + (idx / col) * len;
  T* encoded = output + idx * len;

  const T prior_w = prior[2] - prior[0] + extent_offset;
  const T prior_h = prior[3] - prior[1] + extent_offset;
  const T prior_cx = (prior[2] + prior[0]) / 2;
  const T prior_cy = (prior[3] + prior[1]) / 2;

  const T target_cx = (target[2] + target[0]) / 2;
  const T target_cy = (target[3] + target[1]) / 2;
  const T target_w = target[2] - target[0] + extent_offset;
  const T target_h = target[3] - target[1] + extent_offset;

  encoded[0] = (target_cx - prior_cx) / prior_w / prior_var[0];
  encoded[1] = (target_cy - prior_cy) / prior_h / prior_var[1];
  encoded[2] = log(fabs(target_w / prior_w)) / prior_var[2];
  encoded[3] = log(fabs(target_h / prior_h)) / prior_var[3];
}

// Decodes center-size offsets back into corner-form boxes.
//
// Launch layout: 1-D grid of 1-D blocks providing at least row * col threads.
// Thread `idx` decodes record `idx` of target_box_data against prior box
// `idx % col`; threads with idx >= row * col return without writing.
//
//   prior_box_data / prior_box_var_data : `col` records of `len` values. The
//     first four values of a prior are used as [xmin, ymin, xmax, ymax] and
//     its four variances.
//   target_box_data : row * col records of `len` values; offsets 0..3 hold
//     the deltas (center x, center y, log width, log height).
//   output          : row * col records; offsets 0..3 receive the decoded
//     [xmin, ymin, xmax, ymax].
// When `normalized` is false, 1 is added to prior widths/heights and then
// subtracted from the decoded max corner.
template <typename T>
__global__ void DecodeCenterSizeKernel(const T* prior_box_data,
                                       const T* prior_box_var_data,
                                       const T* target_box_data, const int row,
                                       const int col, const int len,
                                       const bool normalized, T* output) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= row * col) return;

  // Extra extent added to widths/heights of un-normalized boxes.
  const T extent_offset = static_cast<T>(normalized ? 0 : 1);

  const T* prior = prior_box_data + (idx % col) * len;
  const T* prior_var = prior_box_var_data + (idx % col) * len;
  const T* delta = target_box_data + idx * len;
  T* decoded = output + idx * len;

  const T prior_w = prior[2] - prior[0] + extent_offset;
  const T prior_h = prior[3] - prior[1] + extent_offset;
  const T prior_cx = (prior[2] + prior[0]) / 2;
  const T prior_cy = (prior[3] + prior[1]) / 2;

  const T box_w = exp(prior_var[2] * delta[2]) * prior_w;
  const T box_h = exp(prior_var[3] * delta[3]) * prior_h;
  const T box_cx = prior_var[0] * delta[0] * prior_w + prior_cx;
  const T box_cy = prior_var[1] * delta[1] * prior_h + prior_cy;

  decoded[0] = box_cx - box_w / 2;
  decoded[1] = box_cy - box_h / 2;
  decoded[2] = box_cx + box_w / 2 - extent_offset;
  decoded[3] = box_cy + box_h / 2 - extent_offset;
}

// GPU kernel for the box_coder operator: dispatches to the element-wise
// encode or decode kernel above depending on the "code_type" attribute.
template <typename T>
class BoxCoderCUDAKernel : public framework::OpKernel<T> {
 public:
  // Reads inputs "PriorBox", "PriorBoxVar", "TargetBox" and attributes
  // "code_type" / "box_normalized", and writes "OutputBox" with shape
  // {row, col, len}, where row = TargetBox.dims()[0],
  // col = PriorBox.dims()[0] and len = PriorBox.dims()[1].
  // One CUDA thread handles one (target row, prior col) pair; the kernel is
  // launched asynchronously on the device context's stream (no sync here).
  //
  // Fails via PADDLE_ENFORCE* when not running on a GPU place or when
  // TargetBox carries more than one LoD level. Empty inputs (row * col == 0)
  // produce an empty OutputBox without launching a kernel.
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* prior_box = context.Input<framework::Tensor>("PriorBox");
    auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
    auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
    auto* output_box = context.Output<framework::Tensor>("OutputBox");

    if (target_box->lod().size()) {
      PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
                        "Only support 1 level of LoD.");
    }
    auto row = target_box->dims()[0];
    auto col = prior_box->dims()[0];
    auto len = prior_box->dims()[1];

    // mutable_data both allocates OutputBox and returns its device pointer,
    // so no separate data<T>() call is needed.
    T* output =
        output_box->mutable_data<T>({row, col, len}, context.GetPlace());

    // Read attributes before any early exit so an invalid code_type is still
    // reported for empty inputs.
    auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
    bool normalized = context.Attr<bool>("box_normalized");

    // One thread per (target, prior) pair. With no pairs there is nothing to
    // write, and launching a zero-block grid is an invalid CUDA launch
    // configuration, so stop here.
    const auto num_pairs = row * col;
    if (num_pairs == 0) return;

    const int block = 512;
    // Ceil-divide in the (64-bit) dim type before narrowing to int.
    const int grid = static_cast<int>((num_pairs + block - 1) / block);
    auto& device_ctx = context.cuda_device_context();

    const T* prior_box_data = prior_box->data<T>();
    const T* prior_box_var_data = prior_box_var->data<T>();
    const T* target_box_data = target_box->data<T>();

    if (code_type == BoxCodeType::kEncodeCenterSize) {
      EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
          prior_box_data, prior_box_var_data, target_box_data, row, col, len,
          normalized, output);
    } else if (code_type == BoxCodeType::kDecodeCenterSize) {
      DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
          prior_box_data, prior_box_var_data, target_box_data, row, col, len,
          normalized, output);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register the GPU implementation of box_coder for float and double tensors.
REGISTER_OP_CUDA_KERNEL(box_coder,
                        ops::BoxCoderCUDAKernel<float>,
                        ops::BoxCoderCUDAKernel<double>);