/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cstdint>
#include <limits>

#include "paddle/fluid/operators/box_coder_op.h"
#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {

// Encodes target boxes relative to prior boxes (center-size scheme).
//
// Launch layout: 1-D grid with one thread per (row_idx, col_idx) pair, i.e. at
// least row * col threads; surplus threads exit immediately. No shared memory.
//
// Memory layout, as indexed below (every box has a stride of `len` values and
// only offsets 0..3 are touched, read as [xmin, ymin, xmax, ymax]):
//   prior_box_data     : `col` prior boxes.
//   prior_box_var_data : `col` variance vectors, same stride as the priors.
//   target_box_data    : `row` target boxes.
//   output             : row * col encoded boxes, pair (r, c) at (r * col + c).
// For each pair the four outputs are
//   (dcx / pw / v0, dcy / ph / v1, log|tw / pw| / v2, log|th / ph| / v3).
// NOTE(review): assumes len >= 4 (presumably enforced by the op's shape
// inference — confirm); output values past offset 3 are never written.
template <typename T>
__global__ void EncodeCenterSizeKernel(const T* prior_box_data,
                                       const T* prior_box_var_data,
                                       const T* target_box_data, const int row,
                                       const int col, const int len,
                                       T* output) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= row * col) return;

  const int row_idx = idx / col;
  const int col_idx = idx % col;
  // 64-bit offsets: `idx * len` can exceed INT_MAX even when `idx` fits.
  const T* prior = prior_box_data + static_cast<int64_t>(col_idx) * len;
  const T* var = prior_box_var_data + static_cast<int64_t>(col_idx) * len;
  const T* target = target_box_data + static_cast<int64_t>(row_idx) * len;
  T* out = output + static_cast<int64_t>(idx) * len;

  const T prior_w = prior[2] - prior[0];
  const T prior_h = prior[3] - prior[1];
  const T prior_cx = (prior[2] + prior[0]) / 2;
  const T prior_cy = (prior[3] + prior[1]) / 2;

  const T target_cx = (target[2] + target[0]) / 2;
  const T target_cy = (target[3] + target[1]) / 2;
  const T target_w = target[2] - target[0];
  const T target_h = target[3] - target[1];

  out[0] = (target_cx - prior_cx) / prior_w / var[0];
  out[1] = (target_cy - prior_cy) / prior_h / var[1];
  out[2] = log(fabs(target_w / prior_w)) / var[2];
  out[3] = log(fabs(target_h / prior_h)) / var[3];
}

// Decodes center-size deltas back into corner-form boxes.
//
// Launch layout: identical to EncodeCenterSizeKernel (one thread per
// (row_idx, col_idx) pair, surplus threads exit).
//
// Unlike encoding, target_box_data is indexed per pair: the deltas for pair
// (r, c) sit at (r * col + c) * len, the same layout as `output`. The prior and
// variance for the pair are those of column c. Each output box is written as
// [xmin, ymin, xmax, ymax] at offsets 0..3.
template <typename T>
__global__ void DecodeCenterSizeKernel(const T* prior_box_data,
                                       const T* prior_box_var_data,
                                       const T* target_box_data, const int row,
                                       const int col, const int len,
                                       T* output) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= row * col) return;

  const int col_idx = idx % col;
  // 64-bit offsets: `idx * len` can exceed INT_MAX even when `idx` fits.
  const T* prior = prior_box_data + static_cast<int64_t>(col_idx) * len;
  const T* var = prior_box_var_data + static_cast<int64_t>(col_idx) * len;
  const T* delta = target_box_data + static_cast<int64_t>(idx) * len;
  T* out = output + static_cast<int64_t>(idx) * len;

  const T prior_w = prior[2] - prior[0];
  const T prior_h = prior[3] - prior[1];
  const T prior_cx = (prior[2] + prior[0]) / 2;
  const T prior_cy = (prior[3] + prior[1]) / 2;

  const T box_w = exp(var[2] * delta[2]) * prior_w;
  const T box_h = exp(var[3] * delta[3]) * prior_h;
  const T box_cx = var[0] * delta[0] * prior_w + prior_cx;
  const T box_cy = var[1] * delta[1] * prior_h + prior_cy;

  out[0] = box_cx - box_w / 2;
  out[1] = box_cy - box_h / 2;
  out[2] = box_cx + box_w / 2;
  out[3] = box_cy + box_h / 2;
}

// GPU kernel for the box_coder op.
//
// Reads PriorBox (col x len), PriorBoxVar and TargetBox (row leading entries,
// at most one LoD level). Writes OutputBox with shape {row, col, len}. The
// "code_type" attribute selects encoding or decoding. Kernels are launched
// asynchronously on the op's CUDA stream.
template <typename T>
class BoxCoderCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* prior_box = context.Input<framework::Tensor>("PriorBox");
    auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
    auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
    auto* output_box = context.Output<framework::Tensor>("OutputBox");

    if (target_box->lod().size()) {
      PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
                        "Only support 1 level of LoD.");
    }
    auto row = target_box->dims()[0];
    auto col = prior_box->dims()[0];
    auto len = prior_box->dims()[1];

    // Allocate the output up front so an empty result still carries the
    // expected {row, col, len} shape.
    output_box->mutable_data<T>({row, col, len}, context.GetPlace());
    T* output = output_box->data<T>();

    const int64_t num_pairs = static_cast<int64_t>(row) * col;
    // A grid dimension of 0 is an invalid launch configuration; with no
    // (target, prior) pairs there is nothing to compute.
    if (num_pairs == 0) return;
    // The kernels index pairs with a 32-bit int thread index.
    PADDLE_ENFORCE_LE(num_pairs,
                      static_cast<int64_t>(std::numeric_limits<int>::max()),
                      "The number of (target, prior) box pairs is too large.");

    const T* prior_box_data = prior_box->data<T>();
    const T* prior_box_var_data = prior_box_var->data<T>();
    const T* target_box_data = target_box->data<T>();

    const int block = 512;
    const int grid = static_cast<int>((num_pairs + block - 1) / block);
    auto& device_ctx = context.cuda_device_context();

    auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
    if (code_type == BoxCodeType::kEncodeCenterSize) {
      EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
          prior_box_data, prior_box_var_data, target_box_data,
          static_cast<int>(row), static_cast<int>(col), static_cast<int>(len),
          output);
    } else if (code_type == BoxCodeType::kDecodeCenterSize) {
      DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
          prior_box_data, prior_box_var_data, target_box_data,
          static_cast<int>(row), static_cast<int>(col), static_cast<int>(len),
          output);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(box_coder, ops::BoxCoderCUDAKernel<float>,
                        ops::BoxCoderCUDAKernel<double>);