// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/kernels/box_coder_kernel.h" #include #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/impl/box_coder.h" namespace phi { template void EncodeCenterSize(const DenseTensor *target_box, const DenseTensor *prior_box, const DenseTensor *prior_box_var, const bool normalized, const std::vector variance, T *output) { int64_t row = target_box->dims()[0]; int64_t col = prior_box->dims()[0]; int64_t len = prior_box->dims()[1]; #ifdef PADDLE_WITH_MKLML #pragma omp parallel for collapse(2) #endif for (int64_t i = 0; i < row; ++i) { for (int64_t j = 0; j < col; ++j) { auto *target_box_data = target_box->data(); auto *prior_box_data = prior_box->data(); size_t offset = i * col * len + j * len; T prior_box_width = prior_box_data[j * len + 2] - prior_box_data[j * len] + (normalized == false); T prior_box_height = prior_box_data[j * len + 3] - prior_box_data[j * len + 1] + (normalized == false); T prior_box_center_x = prior_box_data[j * len] + prior_box_width / 2; T prior_box_center_y = prior_box_data[j * len + 1] + prior_box_height / 2; T target_box_center_x = (target_box_data[i * len + 2] + target_box_data[i * len]) / 2; T target_box_center_y = (target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2; T target_box_width = target_box_data[i * len + 2] - target_box_data[i * len] + (normalized == false); T target_box_height = target_box_data[i * len + 3] - target_box_data[i * len + 1] + (normalized == false); output[offset] = (target_box_center_x - prior_box_center_x) / prior_box_width; output[offset + 1] = (target_box_center_y - prior_box_center_y) / prior_box_height; output[offset + 2] = std::log(std::fabs(target_box_width / prior_box_width)); output[offset + 3] = std::log(std::fabs(target_box_height / prior_box_height)); } } if (prior_box_var) { const T *prior_box_var_data = prior_box_var->data(); #ifdef PADDLE_WITH_MKLML #pragma omp parallel for collapse(3) #endif for (int64_t i = 0; i < row; ++i) { for (int64_t j = 0; j < col; ++j) { for (int k = 0; k < 4; ++k) { size_t offset = i * col * len + j * len; int prior_var_offset = j * len; output[offset + k] /= prior_box_var_data[prior_var_offset + k]; } } } } else if (!(variance.empty())) { #ifdef PADDLE_WITH_MKLML #pragma omp parallel for collapse(3) #endif for (int64_t i = 0; i < row; ++i) { for (int64_t j = 0; j < col; ++j) { for (int k = 0; k < 4; ++k) { size_t offset = i * col * len + j * len; output[offset + k] /= static_cast(variance[k]); } } } } } template void DecodeCenterSize(const DenseTensor *target_box, const DenseTensor *prior_box, const DenseTensor *prior_box_var, const bool normalized, std::vector variance, T *output) { int64_t row = target_box->dims()[0]; int64_t col = target_box->dims()[1]; int64_t len = target_box->dims()[2]; #ifdef PADDLE_WITH_MKLML #pragma omp parallel for collapse(2) #endif for (int64_t i = 0; i < row; ++i) { for (int64_t j = 0; j < col; ++j) { auto *target_box_data = target_box->data(); auto *prior_box_data = prior_box->data(); std::array var_data{1., 1., 1., 1.}; T *var_ptr = var_data.data(); size_t offset = i * col * len + j * len; int prior_box_offset = axis == 0 ? j * len : i * len; T prior_box_width = prior_box_data[prior_box_offset + 2] - prior_box_data[prior_box_offset] + (normalized == false); T prior_box_height = prior_box_data[prior_box_offset + 3] - prior_box_data[prior_box_offset + 1] + (normalized == false); T prior_box_center_x = prior_box_data[prior_box_offset] + prior_box_width / 2; T prior_box_center_y = prior_box_data[prior_box_offset + 1] + prior_box_height / 2; T target_box_center_x = 0, target_box_center_y = 0; T target_box_width = 0, target_box_height = 0; int prior_var_offset = axis == 0 ? j * len : i * len; if (var_size == 2) { std::memcpy(var_ptr, prior_box_var->data() + prior_var_offset, 4 * sizeof(T)); } else if (var_size == 1) { var_ptr = reinterpret_cast(variance.data()); } T box_var_x = *var_ptr; T box_var_y = *(var_ptr + 1); T box_var_w = *(var_ptr + 2); T box_var_h = *(var_ptr + 3); target_box_center_x = box_var_x * target_box_data[offset] * prior_box_width + prior_box_center_x; target_box_center_y = box_var_y * target_box_data[offset + 1] * prior_box_height + prior_box_center_y; target_box_width = std::exp(box_var_w * target_box_data[offset + 2]) * prior_box_width; target_box_height = std::exp(box_var_h * target_box_data[offset + 3]) * prior_box_height; output[offset] = target_box_center_x - target_box_width / 2; output[offset + 1] = target_box_center_y - target_box_height / 2; output[offset + 2] = target_box_center_x + target_box_width / 2 - (normalized == false); output[offset + 3] = target_box_center_y + target_box_height / 2 - (normalized == false); } } } template void BoxCoderKernel(const Context &dev_ctx, const DenseTensor &prior_box, const paddle::optional &prior_box_var, const DenseTensor &target_box, const std::string &code_type_str, bool normalized, int axis, const std::vector &variance, DenseTensor *output_box) { if (!target_box.lod().empty()) { PADDLE_ENFORCE_EQ(target_box.lod().size(), 1UL, phi::errors::InvalidArgument( "Input(TargetBox) of BoxCoder operator " "supports LoD with only one level. But received " "level = %d", target_box.lod().size())); } if (prior_box_var) { PADDLE_ENFORCE_EQ(variance.empty(), true, phi::errors::InvalidArgument( "Input 'PriorBoxVar' and attribute 'variance' " "of BoxCoder operator should not be used at the " "same time.")); } if (!(variance.empty())) { PADDLE_ENFORCE_EQ( static_cast(variance.size()), 4, phi::errors::InvalidArgument("Size of attribute 'variance' of BoxCoder " "operator should be 4. But received " "size = %d", variance.size())); } auto code_type = phi::funcs::GetBoxCodeType(code_type_str); auto row = target_box.dims()[0]; auto col = prior_box.dims()[0]; if (code_type == phi::funcs::BoxCodeType::kDecodeCenterSize) { col = target_box.dims()[1]; } auto len = prior_box.dims()[1]; output_box->Resize({row, col, len}); dev_ctx.template Alloc(output_box); T *output = output_box->data(); if (code_type == phi::funcs::BoxCodeType::kEncodeCenterSize) { EncodeCenterSize(&target_box, &prior_box, prior_box_var.get_ptr(), normalized, variance, output); } else if (code_type == phi::funcs::BoxCodeType::kDecodeCenterSize) { if (prior_box_var) { if (axis == 0) { DecodeCenterSize(&target_box, &prior_box, prior_box_var.get_ptr(), normalized, variance, output); } else { DecodeCenterSize(&target_box, &prior_box, prior_box_var.get_ptr(), normalized, variance, output); } } else if (!(variance.empty())) { if (axis == 0) { DecodeCenterSize(&target_box, &prior_box, prior_box_var.get_ptr(), normalized, variance, output); } else { DecodeCenterSize(&target_box, &prior_box, prior_box_var.get_ptr(), normalized, variance, output); } } else { if (axis == 0) { DecodeCenterSize(&target_box, &prior_box, prior_box_var.get_ptr(), normalized, variance, output); } else { DecodeCenterSize(&target_box, &prior_box, prior_box_var.get_ptr(), normalized, variance, output); } } } } } // namespace phi PD_REGISTER_KERNEL( box_coder, CPU, ALL_LAYOUT, phi::BoxCoderKernel, float, double) {}