From 98f8fa4cb3d7b9c14c1897f2716249a54c44f8be Mon Sep 17 00:00:00 2001 From: lyq <30404405+affectionlu@users.noreply.github.com> Date: Tue, 26 Jul 2022 15:45:07 +0800 Subject: [PATCH] [Phi] Migrate box coder to phi. (#44550) --- .../fluid/operators/detection/CMakeLists.txt | 5 +- .../fluid/operators/detection/box_coder_op.cc | 137 +------- .../fluid/operators/detection/box_coder_op.cu | 248 --------------- .../fluid/operators/detection/box_coder_op.h | 295 ------------------ .../operators/detection/box_coder_op_npu.cc | 10 +- paddle/phi/api/yaml/legacy_api.yaml | 10 + paddle/phi/infermeta/ternary.cc | 110 +++++++ paddle/phi/infermeta/ternary.h | 10 + paddle/phi/kernels/box_coder_kernel.h | 34 ++ paddle/phi/kernels/cpu/box_coder.cc | 281 +++++++++++++++++ paddle/phi/kernels/gpu/box_coder.cu | 246 +++++++++++++++ paddle/phi/kernels/impl/box_coder.h | 43 +++ paddle/phi/ops/compat/box_coder_sig.cc | 28 ++ python/paddle/fluid/layers/detection.py | 17 + .../tests/unittests/test_box_coder_op.py | 53 +++- .../unittests/test_squared_l2_norm_op.py | 3 +- 16 files changed, 851 insertions(+), 679 deletions(-) delete mode 100644 paddle/fluid/operators/detection/box_coder_op.cu delete mode 100644 paddle/fluid/operators/detection/box_coder_op.h create mode 100644 paddle/phi/kernels/box_coder_kernel.h create mode 100644 paddle/phi/kernels/cpu/box_coder.cc create mode 100644 paddle/phi/kernels/gpu/box_coder.cu create mode 100644 paddle/phi/kernels/impl/box_coder.h create mode 100644 paddle/phi/ops/compat/box_coder_sig.cc diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index c05c39e88d..000a1a4a52 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -29,12 +29,11 @@ function(detection_library TARGET_NAME) endfunction() if(WITH_ASCEND_CL) - detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu - box_coder_op_npu.cc) + detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op_npu.cc) detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu density_prior_box_op_npu.cc) else() - detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu) + detection_library(box_coder_op SRCS box_coder_op.cc) detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu) endif() diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index 64aa863156..53a9d04fb5 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -9,135 +9,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/detection/box_coder_op.h" - #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/ternary.h" + namespace paddle { namespace operators { class BoxCoderOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("PriorBox"), - true, - platform::errors::NotFound( - "Input(PriorBox) of BoxCoder operator is not found.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("TargetBox"), - true, - platform::errors::NotFound( - "Input(TargetBox) of BoxCoder operator is not found.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("OutputBox"), - true, - platform::errors::NotFound( - "Output(OutputBox) of BoxCoder operator is not found.")); - - auto prior_box_dims = ctx->GetInputDim("PriorBox"); - auto target_box_dims = ctx->GetInputDim("TargetBox"); - - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ(prior_box_dims.size(), - 2, - platform::errors::InvalidArgument( - "The rank of Input PriorBox in BoxCoder operator " - "must be 2. But received rank = %d", - prior_box_dims.size())); - PADDLE_ENFORCE_EQ(prior_box_dims[1], - 4, - platform::errors::InvalidArgument( - "The second dimension of PriorBox in BoxCoder " - "operator must be 4. But received dimension = %d", - prior_box_dims[1])); - if (ctx->HasInput("PriorBoxVar")) { - auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar"); - PADDLE_ENFORCE_EQ( - prior_box_var_dims.size(), - 2, - platform::errors::InvalidArgument( - "The rank of Input(PriorBoxVar) in BoxCoder operator" - " should be 2. But received rank = %d", - prior_box_var_dims.size())); - PADDLE_ENFORCE_EQ( - prior_box_dims, - prior_box_var_dims, - platform::errors::InvalidArgument( - "The dimension of Input(PriorBoxVar) should be equal to" - "the dimension of Input(PriorBox) in BoxCoder operator " - "when the rank is 2.")); - } - } - - auto code_type = GetBoxCodeType(ctx->Attrs().Get("code_type")); - int axis = ctx->Attrs().Get("axis"); - if (code_type == BoxCodeType::kEncodeCenterSize) { - PADDLE_ENFORCE_EQ(target_box_dims.size(), - 2, - platform::errors::InvalidArgument( - "The rank of Input TargetBox in BoxCoder operator " - "must be 2. But received rank is %d", - target_box_dims.size())); - PADDLE_ENFORCE_EQ(target_box_dims[1], - 4, - platform::errors::InvalidArgument( - "The second dimension of TargetBox in BoxCoder " - "operator is 4. But received dimension is %d", - target_box_dims[1])); - ctx->SetOutputDim( - "OutputBox", - phi::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); - } else if (code_type == BoxCodeType::kDecodeCenterSize) { - PADDLE_ENFORCE_EQ(target_box_dims.size(), - 3, - platform::errors::InvalidArgument( - "The rank of Input TargetBox in BoxCoder " - "operator must be 3. But received rank is %d", - target_box_dims.size())); - PADDLE_ENFORCE_EQ(axis == 0 || axis == 1, - true, - platform::errors::InvalidArgument( - "axis in BoxCoder operator must be 0 or 1." - "But received axis = %d", - axis)); - if (ctx->IsRuntime()) { - if (axis == 0) { - PADDLE_ENFORCE_EQ( - target_box_dims[1], - prior_box_dims[0], - platform::errors::InvalidArgument( - "When axis is 0, The second " - "dimension of TargetBox in BoxCoder " - "should be equal to the first dimension of PriorBox.")); - } else if (axis == 1) { - PADDLE_ENFORCE_EQ( - target_box_dims[0], - prior_box_dims[0], - platform::errors::InvalidArgument( - "When axis is 1, The first " - "dimension of TargetBox in BoxCoder " - "should be equal to the first dimension of PriorBox.")); - } - PADDLE_ENFORCE_EQ(target_box_dims[2], - prior_box_dims[1], - platform::errors::InvalidArgument( - "The third dimension of TargetBox" - " in BoxCoder should be equal to the " - "second dimension of PriorBox.")); - } - ctx->ShareDim("TargetBox", /*->*/ "OutputBox"); - } - - if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) { - ctx->ShareLoD("PriorBox", /*->*/ "OutputBox"); - } else { - ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); - } - } }; class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { @@ -245,12 +129,15 @@ box will broadcast to target box along the assigned axis. } // namespace paddle namespace ops = paddle::operators; + +DECLARE_INFER_SHAPE_FUNCTOR(box_coder, + BoxCoderInferShapeFunctor, + PD_INFER_META(phi::BoxCoderInferMeta)); + REGISTER_OPERATOR( box_coder, ops::BoxCoderOp, ops::BoxCoderOpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(box_coder, - ops::BoxCoderKernel, - ops::BoxCoderKernel); + paddle::framework::EmptyGradOpMaker, + BoxCoderInferShapeFunctor); diff --git a/paddle/fluid/operators/detection/box_coder_op.cu b/paddle/fluid/operators/detection/box_coder_op.cu deleted file mode 100644 index 9c305f8515..0000000000 --- a/paddle/fluid/operators/detection/box_coder_op.cu +++ /dev/null @@ -1,248 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include - -#include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/operators/detection/box_coder_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { - -template -__global__ void EncodeCenterSizeKernel(const T* prior_box_data, - const T* prior_box_var_data, - const T* target_box_data, - const int row, - const int col, - const int len, - const bool normalized, - const T prior_box_var_size, - const float* variance, - const int var_size, - T* output) { - const int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < row * col) { - const int row_idx = idx / col; - const int col_idx = idx % col; - T prior_box_width = prior_box_data[col_idx * len + 2] - - prior_box_data[col_idx * len] + (normalized == false); - T prior_box_height = prior_box_data[col_idx * len + 3] - - prior_box_data[col_idx * len + 1] + - (normalized == false); - T prior_box_center_x = prior_box_data[col_idx * len] + prior_box_width / 2; - T prior_box_center_y = - prior_box_data[col_idx * len + 1] + prior_box_height / 2; - - T target_box_center_x = - (target_box_data[row_idx * len + 2] + target_box_data[row_idx * len]) / - 2; - T target_box_center_y = (target_box_data[row_idx * len + 3] + - target_box_data[row_idx * len + 1]) / - 2; - T target_box_width = target_box_data[row_idx * len + 2] - - target_box_data[row_idx * len] + (normalized == false); - T target_box_height = target_box_data[row_idx * len + 3] - - target_box_data[row_idx * len + 1] + - (normalized == false); - - output[idx * len] = - (target_box_center_x - prior_box_center_x) / prior_box_width; - output[idx * len + 1] = - (target_box_center_y - prior_box_center_y) / prior_box_height; - output[idx * len + 2] = log(fabs(target_box_width / prior_box_width)); - output[idx * len + 3] = log(fabs(target_box_height / prior_box_height)); - if (prior_box_var_data) { - int prior_var_offset = col_idx * len; - output[idx * len] /= prior_box_var_data[prior_var_offset]; - output[idx * len + 1] /= prior_box_var_data[prior_var_offset + 1]; - output[idx * len + 2] /= prior_box_var_data[prior_var_offset + 2]; - output[idx * len + 3] /= prior_box_var_data[prior_var_offset + 3]; - } else if (var_size == 4) { - for (int k = 0; k < 4; ++k) { - output[idx * len + k] /= static_cast(variance[k]); - } - } - } -} - -template -__global__ void DecodeCenterSizeKernel(const T* prior_box_data, - const T* prior_box_var_data, - const T* target_box_data, - const int row, - const int col, - const int len, - const bool normalized, - const T prior_box_var_size, - const float* variance, - const int var_size, - const int axis, - T* output) { - const int idx = threadIdx.x + blockIdx.x * blockDim.x; - int prior_box_offset = 0; - if (idx < row * col) { - const int col_idx = idx % col; - const int row_idx = idx / col; - prior_box_offset = axis == 0 ? col_idx * len : row_idx * len; - T prior_box_width = prior_box_data[prior_box_offset + 2] - - prior_box_data[prior_box_offset] + - (normalized == false); - T prior_box_height = prior_box_data[prior_box_offset + 3] - - prior_box_data[prior_box_offset + 1] + - (normalized == false); - T prior_box_center_x = - prior_box_data[prior_box_offset] + prior_box_width / 2; - T prior_box_center_y = - prior_box_data[prior_box_offset + 1] + prior_box_height / 2; - T target_box_width, target_box_height; - T target_box_center_x, target_box_center_y; - T box_var_x = T(1), box_var_y = T(1); - T box_var_w = T(1), box_var_h = T(1); - if (prior_box_var_data) { - int prior_var_offset = axis == 0 ? col_idx * len : row_idx * len; - box_var_x = prior_box_var_data[prior_var_offset]; - box_var_y = prior_box_var_data[prior_var_offset + 1]; - box_var_w = prior_box_var_data[prior_var_offset + 2]; - box_var_h = prior_box_var_data[prior_var_offset + 3]; - } else if (var_size == 4) { - box_var_x = static_cast(variance[0]); - box_var_y = static_cast(variance[1]); - box_var_w = static_cast(variance[2]); - box_var_h = static_cast(variance[3]); - } - target_box_width = - exp(box_var_w * target_box_data[idx * len + 2]) * prior_box_width; - target_box_height = - exp(box_var_h * target_box_data[idx * len + 3]) * prior_box_height; - target_box_center_x = - box_var_x * target_box_data[idx * len] * prior_box_width + - prior_box_center_x; - target_box_center_y = - box_var_y * target_box_data[idx * len + 1] * prior_box_height + - prior_box_center_y; - - output[idx * len] = target_box_center_x - target_box_width / 2; - output[idx * len + 1] = target_box_center_y - target_box_height / 2; - output[idx * len + 2] = - target_box_center_x + target_box_width / 2 - (normalized == false); - output[idx * len + 3] = - target_box_center_y + target_box_height / 2 - (normalized == false); - } -} - -template -class BoxCoderCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* prior_box = context.Input("PriorBox"); - auto* prior_box_var = context.Input("PriorBoxVar"); - auto* target_box = context.Input("TargetBox"); - auto* output_box = context.Output("OutputBox"); - std::vector variance = context.Attr>("variance"); - const T* prior_box_data = prior_box->data(); - const T* target_box_data = target_box->data(); - const T* prior_box_var_data = nullptr; - auto prior_box_var_size = 0; - if (prior_box_var) { - PADDLE_ENFORCE_EQ(variance.empty(), - true, - platform::errors::InvalidArgument( - "Input 'PriorBoxVar' and attribute 'variance'" - " of BoxCoder operator should not be used at the " - "same time.")); - prior_box_var_data = prior_box_var->data(); - prior_box_var_size = prior_box_var->dims().size(); - } - if (!(variance.empty())) { - PADDLE_ENFORCE_EQ(static_cast(variance.size()), - 4, - platform::errors::InvalidArgument( - "Size of attribute 'variance' in BoxCoder operator" - " should be 4. But received size is %d", - variance.size())); - } - - if (target_box->lod().size()) { - PADDLE_ENFORCE_EQ(target_box->lod().size(), - 1, - platform::errors::InvalidArgument( - "Input 'TargetBox' of BoxCoder operator only" - " supports LoD with one level.")); - } - const int var_size = static_cast(variance.size()); - - auto code_type = GetBoxCodeType(context.Attr("code_type")); - bool normalized = context.Attr("box_normalized"); - int axis = context.Attr("axis"); - - auto row = target_box->dims()[0]; - auto col = prior_box->dims()[0]; - if (code_type == BoxCodeType::kDecodeCenterSize) { - col = target_box->dims()[1]; - } - auto len = prior_box->dims()[1]; - int block = 512; - int grid = (row * col + block - 1) / block; - auto& device_ctx = context.cuda_device_context(); - - int bytes = var_size * sizeof(float); - auto dev_var = memory::Alloc(device_ctx, bytes); - float* dev_var_data = reinterpret_cast(dev_var->ptr()); - auto cplace = platform::CPUPlace(); - const auto gplace = context.GetPlace(); - memory::Copy( - gplace, dev_var_data, cplace, &variance[0], bytes, device_ctx.stream()); - - output_box->mutable_data({row, col, len}, context.GetPlace()); - T* output = output_box->data(); - - if (code_type == BoxCodeType::kEncodeCenterSize) { - EncodeCenterSizeKernel - <<>>(prior_box_data, - prior_box_var_data, - target_box_data, - row, - col, - len, - normalized, - prior_box_var_size, - dev_var_data, - var_size, - output); - } else if (code_type == BoxCodeType::kDecodeCenterSize) { - DecodeCenterSizeKernel - <<>>(prior_box_data, - prior_box_var_data, - target_box_data, - row, - col, - len, - normalized, - prior_box_var_size, - dev_var_data, - var_size, - axis, - output); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - box_coder, - ops::BoxCoderCUDAKernel, - ops::BoxCoderCUDAKernel); diff --git a/paddle/fluid/operators/detection/box_coder_op.h b/paddle/fluid/operators/detection/box_coder_op.h deleted file mode 100644 index 8e0267cf8b..0000000000 --- a/paddle/fluid/operators/detection/box_coder_op.h +++ /dev/null @@ -1,295 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -enum class BoxCodeType { kEncodeCenterSize = 0, kDecodeCenterSize = 1 }; - -inline BoxCodeType GetBoxCodeType(const std::string &type) { - PADDLE_ENFORCE_EQ( - (type == "encode_center_size") || (type == "decode_center_size"), - true, - platform::errors::InvalidArgument( - "The 'code_type' attribute in BoxCoder" - " must be 'encode_center_size' or 'decode_center_size'. " - "But received 'code_type' is %s", - type)); - if (type == "encode_center_size") { - return BoxCodeType::kEncodeCenterSize; - } else { - return BoxCodeType::kDecodeCenterSize; - } -} - -template -class BoxCoderKernel : public framework::OpKernel { - public: - void EncodeCenterSize(const framework::Tensor *target_box, - const framework::Tensor *prior_box, - const framework::Tensor *prior_box_var, - const bool normalized, - const std::vector variance, - T *output) const { - int64_t row = target_box->dims()[0]; - int64_t col = prior_box->dims()[0]; - int64_t len = prior_box->dims()[1]; - -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for collapse(2) -#endif - for (int64_t i = 0; i < row; ++i) { - for (int64_t j = 0; j < col; ++j) { - auto *target_box_data = target_box->data(); - auto *prior_box_data = prior_box->data(); - size_t offset = i * col * len + j * len; - T prior_box_width = prior_box_data[j * len + 2] - - prior_box_data[j * len] + (normalized == false); - T prior_box_height = prior_box_data[j * len + 3] - - prior_box_data[j * len + 1] + - (normalized == false); - T prior_box_center_x = prior_box_data[j * len] + prior_box_width / 2; - T prior_box_center_y = - prior_box_data[j * len + 1] + prior_box_height / 2; - - T target_box_center_x = - (target_box_data[i * len + 2] + target_box_data[i * len]) / 2; - T target_box_center_y = - (target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2; - T target_box_width = target_box_data[i * len + 2] - - target_box_data[i * len] + (normalized == false); - T target_box_height = target_box_data[i * len + 3] - - target_box_data[i * len + 1] + - (normalized == false); - - output[offset] = - (target_box_center_x - prior_box_center_x) / prior_box_width; - output[offset + 1] = - (target_box_center_y - prior_box_center_y) / prior_box_height; - output[offset + 2] = - std::log(std::fabs(target_box_width / prior_box_width)); - output[offset + 3] = - std::log(std::fabs(target_box_height / prior_box_height)); - } - } - - if (prior_box_var) { - const T *prior_box_var_data = prior_box_var->data(); -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for collapse(3) -#endif - for (int64_t i = 0; i < row; ++i) { - for (int64_t j = 0; j < col; ++j) { - for (int k = 0; k < 4; ++k) { - size_t offset = i * col * len + j * len; - int prior_var_offset = j * len; - output[offset + k] /= prior_box_var_data[prior_var_offset + k]; - } - } - } - } else if (!(variance.empty())) { -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for collapse(3) -#endif - for (int64_t i = 0; i < row; ++i) { - for (int64_t j = 0; j < col; ++j) { - for (int k = 0; k < 4; ++k) { - size_t offset = i * col * len + j * len; - output[offset + k] /= static_cast(variance[k]); - } - } - } - } - } - - template - void DecodeCenterSize(const framework::Tensor *target_box, - const framework::Tensor *prior_box, - const framework::Tensor *prior_box_var, - const bool normalized, - std::vector variance, - T *output) const { - int64_t row = target_box->dims()[0]; - int64_t col = target_box->dims()[1]; - int64_t len = target_box->dims()[2]; - -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for collapse(2) -#endif - for (int64_t i = 0; i < row; ++i) { - for (int64_t j = 0; j < col; ++j) { - auto *target_box_data = target_box->data(); - auto *prior_box_data = prior_box->data(); - - T var_data[4] = {1., 1., 1., 1.}; - T *var_ptr = var_data; - size_t offset = i * col * len + j * len; - int prior_box_offset = axis == 0 ? j * len : i * len; - - T prior_box_width = prior_box_data[prior_box_offset + 2] - - prior_box_data[prior_box_offset] + - (normalized == false); - T prior_box_height = prior_box_data[prior_box_offset + 3] - - prior_box_data[prior_box_offset + 1] + - (normalized == false); - T prior_box_center_x = - prior_box_data[prior_box_offset] + prior_box_width / 2; - T prior_box_center_y = - prior_box_data[prior_box_offset + 1] + prior_box_height / 2; - - T target_box_center_x = 0, target_box_center_y = 0; - T target_box_width = 0, target_box_height = 0; - int prior_var_offset = axis == 0 ? j * len : i * len; - if (var_size == 2) { - std::memcpy(var_ptr, - prior_box_var->data() + prior_var_offset, - 4 * sizeof(T)); - } else if (var_size == 1) { - var_ptr = reinterpret_cast(variance.data()); - } - T box_var_x = *var_ptr; - T box_var_y = *(var_ptr + 1); - T box_var_w = *(var_ptr + 2); - T box_var_h = *(var_ptr + 3); - - target_box_center_x = - box_var_x * target_box_data[offset] * prior_box_width + - prior_box_center_x; - target_box_center_y = - box_var_y * target_box_data[offset + 1] * prior_box_height + - prior_box_center_y; - target_box_width = - std::exp(box_var_w * target_box_data[offset + 2]) * prior_box_width; - target_box_height = std::exp(box_var_h * target_box_data[offset + 3]) * - prior_box_height; - - output[offset] = target_box_center_x - target_box_width / 2; - output[offset + 1] = target_box_center_y - target_box_height / 2; - output[offset + 2] = - target_box_center_x + target_box_width / 2 - (normalized == false); - output[offset + 3] = - target_box_center_y + target_box_height / 2 - (normalized == false); - } - } - } - - void Compute(const framework::ExecutionContext &context) const override { - auto *prior_box = context.Input("PriorBox"); - auto *prior_box_var = context.Input("PriorBoxVar"); - auto *target_box = context.Input("TargetBox"); - auto *output_box = context.Output("OutputBox"); - std::vector variance = context.Attr>("variance"); - const int axis = context.Attr("axis"); - if (target_box->lod().size()) { - PADDLE_ENFORCE_EQ(target_box->lod().size(), - 1UL, - platform::errors::InvalidArgument( - "Input(TargetBox) of BoxCoder operator " - "supports LoD with only one level. But received " - "level = %d", - target_box->lod().size())); - } - if (prior_box_var) { - PADDLE_ENFORCE_EQ(variance.empty(), - true, - platform::errors::InvalidArgument( - "Input 'PriorBoxVar' and attribute 'variance' " - "of BoxCoder operator should not be used at the " - "same time.")); - } - if (!(variance.empty())) { - PADDLE_ENFORCE_EQ(static_cast(variance.size()), - 4, - platform::errors::InvalidArgument( - "Size of attribute 'variance' of BoxCoder " - "operator should be 4. But received " - "size = %d", - variance.size())); - } - auto code_type = GetBoxCodeType(context.Attr("code_type")); - bool normalized = context.Attr("box_normalized"); - - auto row = target_box->dims()[0]; - auto col = prior_box->dims()[0]; - if (code_type == BoxCodeType::kDecodeCenterSize) { - col = target_box->dims()[1]; - } - auto len = prior_box->dims()[1]; - - output_box->mutable_data({row, col, len}, context.GetPlace()); - - T *output = output_box->data(); - if (code_type == BoxCodeType::kEncodeCenterSize) { - EncodeCenterSize( - target_box, prior_box, prior_box_var, normalized, variance, output); - } else if (code_type == BoxCodeType::kDecodeCenterSize) { - if (prior_box_var) { - if (axis == 0) { - DecodeCenterSize<0, 2>(target_box, - prior_box, - prior_box_var, - normalized, - variance, - output); - } else { - DecodeCenterSize<1, 2>(target_box, - prior_box, - prior_box_var, - normalized, - variance, - output); - } - } else if (!(variance.empty())) { - if (axis == 0) { - DecodeCenterSize<0, 1>(target_box, - prior_box, - prior_box_var, - normalized, - variance, - output); - } else { - DecodeCenterSize<1, 1>(target_box, - prior_box, - prior_box_var, - normalized, - variance, - output); - } - } else { - if (axis == 0) { - DecodeCenterSize<0, 0>(target_box, - prior_box, - prior_box_var, - normalized, - variance, - output); - } else { - DecodeCenterSize<1, 0>(target_box, - prior_box, - prior_box_var, - normalized, - variance, - output); - } - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/detection/box_coder_op_npu.cc b/paddle/fluid/operators/detection/box_coder_op_npu.cc index e70bf3a510..8181f10f2b 100644 --- a/paddle/fluid/operators/detection/box_coder_op_npu.cc +++ b/paddle/fluid/operators/detection/box_coder_op_npu.cc @@ -9,8 +9,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/detection/box_coder_op.h" +#include +#include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" +#include "paddle/phi/kernels/impl/box_coder.h" namespace paddle { namespace operators { @@ -407,10 +410,11 @@ class BoxCoderNPUKernel : public framework::OpKernel { " supports LoD with one level.")); } - auto code_type = GetBoxCodeType(ctx.Attr("code_type")); + auto code_type = + phi::funcs::GetBoxCodeType(ctx.Attr("code_type")); bool normalized = ctx.Attr("box_normalized"); - if (code_type == BoxCodeType::kEncodeCenterSize) { + if (code_type == phi::funcs::BoxCodeType::kEncodeCenterSize) { BoxCoderEnc(ctx, target_box, prior_box, diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index ea52ffd960..00a68bb0c4 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -326,6 +326,16 @@ kernel : func : bitwise_xor +# box_coder +- api : box_coder + args : (Tensor prior_box, Tensor prior_box_var, Tensor target_box, str code_type, bool box_normalized, int axis, float[] variance) + output : Tensor(output_box) + infer_meta : + func : BoxCoderInferMeta + kernel : + func : box_coder + optional : prior_box_var + # brelu - api : brelu args : (Tensor x, float t_min, float t_max) diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 9f65de0f0a..b83febb24b 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/phi/common/layout.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/kernels/funcs/common_shape.h" +#include "paddle/phi/kernels/impl/box_coder.h" namespace phi { @@ -145,6 +146,115 @@ void AddmmInferMeta(const MetaTensor& input, out->set_dtype(input.dtype()); } +void BoxCoderInferMeta(const MetaTensor& prior_box, + const MetaTensor& prior_box_var, + const MetaTensor& target_box, + const std::string& code_type, + bool box_normalized, + int axis, + const std::vector& variance, + MetaTensor* output_box, + MetaConfig config) { + auto prior_box_dims = prior_box.dims(); + auto target_box_dims = target_box.dims(); + + if (config.is_runtime) { + PADDLE_ENFORCE_EQ(prior_box_dims.size(), + 2, + phi::errors::InvalidArgument( + "The rank of Input PriorBox in BoxCoder operator " + "must be 2. But received rank = %d", + prior_box_dims.size())); + PADDLE_ENFORCE_EQ(prior_box_dims[1], + 4, + phi::errors::InvalidArgument( + "The second dimension of PriorBox in BoxCoder " + "operator must be 4. But received dimension = %d", + prior_box_dims[1])); + if (prior_box_var) { + auto prior_box_var_dims = prior_box_var.dims(); + PADDLE_ENFORCE_EQ( + prior_box_var_dims.size(), + 2, + phi::errors::InvalidArgument( + "The rank of Input(PriorBoxVar) in BoxCoder operator" + " should be 2. But received rank = %d", + prior_box_var_dims.size())); + PADDLE_ENFORCE_EQ( + prior_box_dims, + prior_box_var_dims, + phi::errors::InvalidArgument( + "The dimension of Input(PriorBoxVar) should be equal to" + "the dimension of Input(PriorBox) in BoxCoder operator " + "when the rank is 2.")); + } + } + + auto box_code_type = phi::funcs::GetBoxCodeType(code_type); + if (box_code_type == phi::funcs::BoxCodeType::kEncodeCenterSize) { + PADDLE_ENFORCE_EQ(target_box_dims.size(), + 2, + phi::errors::InvalidArgument( + "The rank of Input TargetBox in BoxCoder operator " + "must be 2. But received rank is %d", + target_box_dims.size())); + PADDLE_ENFORCE_EQ(target_box_dims[1], + 4, + phi::errors::InvalidArgument( + "The second dimension of TargetBox in BoxCoder " + "operator is 4. But received dimension is %d", + target_box_dims[1])); + output_box->set_dims({target_box_dims[0], prior_box_dims[0], 4}); + } else if (box_code_type == phi::funcs::BoxCodeType::kDecodeCenterSize) { + PADDLE_ENFORCE_EQ(target_box_dims.size(), + 3, + phi::errors::InvalidArgument( + "The rank of Input TargetBox in BoxCoder " + "operator must be 3. But received rank is %d", + target_box_dims.size())); + PADDLE_ENFORCE_EQ( + axis == 0 || axis == 1, + true, + phi::errors::InvalidArgument("axis in BoxCoder operator must be 0 or 1." + "But received axis = %d", + axis)); + if (config.is_runtime) { + if (axis == 0) { + PADDLE_ENFORCE_EQ( + target_box_dims[1], + prior_box_dims[0], + phi::errors::InvalidArgument( + "When axis is 0, The second " + "dimension of TargetBox in BoxCoder " + "should be equal to the first dimension of PriorBox.")); + } else if (axis == 1) { + PADDLE_ENFORCE_EQ( + target_box_dims[0], + prior_box_dims[0], + phi::errors::InvalidArgument( + "When axis is 1, The first " + "dimension of TargetBox in BoxCoder " + "should be equal to the first dimension of PriorBox.")); + } + PADDLE_ENFORCE_EQ( + target_box_dims[2], + prior_box_dims[1], + phi::errors::InvalidArgument("The third dimension of TargetBox" + " in BoxCoder should be equal to the " + "second dimension of PriorBox.")); + } + output_box->share_dims(target_box); + } + + if (box_code_type == phi::funcs::BoxCodeType::kDecodeCenterSize && + axis == 1) { + output_box->share_lod(prior_box); + } else { + output_box->share_lod(target_box); + } + output_box->set_dtype(target_box.dtype()); +} + void ArangeInferMeta(const MetaTensor& start, const MetaTensor& end, const MetaTensor& step, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 40461d299f..329c6e13a5 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -52,6 +52,16 @@ void ArangeInferMeta(const MetaTensor& start, const MetaTensor& step, MetaTensor* out); +void BoxCoderInferMeta(const MetaTensor& prior_box, + const MetaTensor& prior_box_var, + const MetaTensor& target_box, + const std::string& code_type, + bool box_normalized, + int axis, + const std::vector& variance, + MetaTensor* output_box, + MetaConfig config = MetaConfig()); + void InstanceNormInferMeta(const MetaTensor& x, const MetaTensor& scale, const MetaTensor& bias, diff --git a/paddle/phi/kernels/box_coder_kernel.h b/paddle/phi/kernels/box_coder_kernel.h new file mode 100644 index 0000000000..9e5c52bd5b --- /dev/null +++ b/paddle/phi/kernels/box_coder_kernel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void BoxCoderKernel(const Context& dev_ctx, + const DenseTensor& prior_box, + const paddle::optional& prior_box_var, + const DenseTensor& target_box, + const std::string& code_type, + bool box_normalized, + int axis, + const std::vector& variance, + DenseTensor* output_box); +} // namespace phi diff --git a/paddle/phi/kernels/cpu/box_coder.cc b/paddle/phi/kernels/cpu/box_coder.cc new file mode 100644 index 0000000000..81380e97d9 --- /dev/null +++ b/paddle/phi/kernels/cpu/box_coder.cc @@ -0,0 +1,281 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/box_coder_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/impl/box_coder.h" + +namespace phi { + +template +void EncodeCenterSize(const DenseTensor *target_box, + const DenseTensor *prior_box, + const DenseTensor *prior_box_var, + const bool normalized, + const std::vector variance, + T *output) { + int64_t row = target_box->dims()[0]; + int64_t col = prior_box->dims()[0]; + int64_t len = prior_box->dims()[1]; + +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for collapse(2) +#endif + for (int64_t i = 0; i < row; ++i) { + for (int64_t j = 0; j < col; ++j) { + auto *target_box_data = target_box->data(); + auto *prior_box_data = prior_box->data(); + size_t offset = i * col * len + j * len; + T prior_box_width = prior_box_data[j * len + 2] - + prior_box_data[j * len] + (normalized == false); + T prior_box_height = prior_box_data[j * len + 3] - + prior_box_data[j * len + 1] + (normalized == false); + T prior_box_center_x = prior_box_data[j * len] + prior_box_width / 2; + T prior_box_center_y = prior_box_data[j * len + 1] + prior_box_height / 2; + + T target_box_center_x = + (target_box_data[i * len + 2] + target_box_data[i * len]) / 2; + T target_box_center_y = + (target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2; + T target_box_width = target_box_data[i * len + 2] - + target_box_data[i * len] + (normalized == false); + T target_box_height = target_box_data[i * len + 3] - + target_box_data[i * len + 1] + + (normalized == false); + + output[offset] = + (target_box_center_x - prior_box_center_x) / prior_box_width; + output[offset + 1] = + (target_box_center_y - prior_box_center_y) / prior_box_height; + output[offset + 2] = + std::log(std::fabs(target_box_width / prior_box_width)); + output[offset + 3] = + std::log(std::fabs(target_box_height / prior_box_height)); + } + } + + if (prior_box_var) { + const T *prior_box_var_data = prior_box_var->data(); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for collapse(3) +#endif + for (int64_t i = 0; i < row; ++i) { + for (int64_t j = 0; j < col; ++j) { + for (int k = 0; k < 4; ++k) { + size_t offset = i * col * len + j * len; + int prior_var_offset = j * len; + output[offset + k] /= prior_box_var_data[prior_var_offset + k]; + } + } + } + } else if (!(variance.empty())) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for collapse(3) +#endif + for (int64_t i = 0; i < row; ++i) { + for (int64_t j = 0; j < col; ++j) { + for (int k = 0; k < 4; ++k) { + size_t offset = i * col * len + j * len; + output[offset + k] /= static_cast(variance[k]); + } + } + } + } +} + +template +void DecodeCenterSize(const DenseTensor *target_box, + const DenseTensor *prior_box, + const DenseTensor *prior_box_var, + const bool normalized, + std::vector variance, + T *output) { + int64_t row = target_box->dims()[0]; + int64_t col = target_box->dims()[1]; + int64_t len = target_box->dims()[2]; + +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for collapse(2) +#endif + for (int64_t i = 0; i < row; ++i) { + for (int64_t j = 0; j < col; ++j) { + auto *target_box_data = target_box->data(); + auto *prior_box_data = prior_box->data(); + + T var_data[4] = {1., 1., 1., 1.}; + T *var_ptr = var_data; + size_t offset = i * col * len + j * len; + int prior_box_offset = axis == 0 ? j * len : i * len; + + T prior_box_width = prior_box_data[prior_box_offset + 2] - + prior_box_data[prior_box_offset] + + (normalized == false); + T prior_box_height = prior_box_data[prior_box_offset + 3] - + prior_box_data[prior_box_offset + 1] + + (normalized == false); + T prior_box_center_x = + prior_box_data[prior_box_offset] + prior_box_width / 2; + T prior_box_center_y = + prior_box_data[prior_box_offset + 1] + prior_box_height / 2; + + T target_box_center_x = 0, target_box_center_y = 0; + T target_box_width = 0, target_box_height = 0; + int prior_var_offset = axis == 0 ? j * len : i * len; + if (var_size == 2) { + std::memcpy(var_ptr, + prior_box_var->data() + prior_var_offset, + 4 * sizeof(T)); + } else if (var_size == 1) { + var_ptr = reinterpret_cast(variance.data()); + } + T box_var_x = *var_ptr; + T box_var_y = *(var_ptr + 1); + T box_var_w = *(var_ptr + 2); + T box_var_h = *(var_ptr + 3); + + target_box_center_x = + box_var_x * target_box_data[offset] * prior_box_width + + prior_box_center_x; + target_box_center_y = + box_var_y * target_box_data[offset + 1] * prior_box_height + + prior_box_center_y; + target_box_width = + std::exp(box_var_w * target_box_data[offset + 2]) * prior_box_width; + target_box_height = + std::exp(box_var_h * target_box_data[offset + 3]) * prior_box_height; + + output[offset] = target_box_center_x - target_box_width / 2; + output[offset + 1] = target_box_center_y - target_box_height / 2; + output[offset + 2] = + target_box_center_x + target_box_width / 2 - (normalized == false); + output[offset + 3] = + target_box_center_y + target_box_height / 2 - (normalized == false); + } + } +} + +template +void BoxCoderKernel(const Context &dev_ctx, + const DenseTensor &prior_box, + const paddle::optional &prior_box_var, + const DenseTensor &target_box, + const std::string &code_type_str, + bool normalized, + int axis, + const std::vector &variance, + DenseTensor *output_box) { + if (target_box.lod().size()) { + PADDLE_ENFORCE_EQ(target_box.lod().size(), + 1UL, + phi::errors::InvalidArgument( + "Input(TargetBox) of BoxCoder operator " + "supports LoD with only one level. But received " + "level = %d", + target_box.lod().size())); + } + if (prior_box_var) { + PADDLE_ENFORCE_EQ(variance.empty(), + true, + phi::errors::InvalidArgument( + "Input 'PriorBoxVar' and attribute 'variance' " + "of BoxCoder operator should not be used at the " + "same time.")); + } + if (!(variance.empty())) { + PADDLE_ENFORCE_EQ( + static_cast(variance.size()), + 4, + phi::errors::InvalidArgument("Size of attribute 'variance' of BoxCoder " + "operator should be 4. But received " + "size = %d", + variance.size())); + } + + auto code_type = phi::funcs::GetBoxCodeType(code_type_str); + auto row = target_box.dims()[0]; + auto col = prior_box.dims()[0]; + if (code_type == phi::funcs::BoxCodeType::kDecodeCenterSize) { + col = target_box.dims()[1]; + } + auto len = prior_box.dims()[1]; + output_box->Resize({row, col, len}); + dev_ctx.template Alloc(output_box); + T *output = output_box->data(); + if (code_type == phi::funcs::BoxCodeType::kEncodeCenterSize) { + EncodeCenterSize(&target_box, + &prior_box, + prior_box_var.get_ptr(), + normalized, + variance, + output); + } else if (code_type == phi::funcs::BoxCodeType::kDecodeCenterSize) { + if (prior_box_var) { + if (axis == 0) { + DecodeCenterSize(&target_box, + &prior_box, + prior_box_var.get_ptr(), + normalized, + variance, + output); + } else { + DecodeCenterSize(&target_box, + &prior_box, + prior_box_var.get_ptr(), + normalized, + variance, + output); + } + } else if (!(variance.empty())) { + if (axis == 0) { + DecodeCenterSize(&target_box, + &prior_box, + prior_box_var.get_ptr(), + normalized, + variance, + output); + } else { + DecodeCenterSize(&target_box, + &prior_box, + prior_box_var.get_ptr(), + normalized, + variance, + output); + } + } else { + if (axis == 0) { + DecodeCenterSize(&target_box, + &prior_box, + prior_box_var.get_ptr(), + normalized, + variance, + output); + } else { + DecodeCenterSize(&target_box, + &prior_box, + prior_box_var.get_ptr(), + normalized, + variance, + output); + } + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + box_coder, CPU, ALL_LAYOUT, phi::BoxCoderKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/box_coder.cu b/paddle/phi/kernels/gpu/box_coder.cu new file mode 100644 index 0000000000..e72c5f9cee --- /dev/null +++ b/paddle/phi/kernels/gpu/box_coder.cu @@ -0,0 +1,246 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/box_coder_kernel.h" + +#include +#include + +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/box_coder.h" + +namespace phi { + +template +__global__ void EncodeCenterSizeKernel(const T *prior_box_data, + const T *prior_box_var_data, + const T *target_box_data, + const int row, + const int col, + const int len, + const bool normalized, + const T prior_box_var_size, + const float *variance, + const int var_size, + T *output) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < row * col) { + const int row_idx = idx / col; + const int col_idx = idx % col; + T prior_box_width = prior_box_data[col_idx * len + 2] - + prior_box_data[col_idx * len] + (normalized == false); + T prior_box_height = prior_box_data[col_idx * len + 3] - + prior_box_data[col_idx * len + 1] + + (normalized == false); + T prior_box_center_x = prior_box_data[col_idx * len] + prior_box_width / 2; + T prior_box_center_y = + prior_box_data[col_idx * len + 1] + prior_box_height / 2; + + T target_box_center_x = + (target_box_data[row_idx * len + 2] + target_box_data[row_idx * len]) / + 2; + T target_box_center_y = (target_box_data[row_idx * len + 3] + + target_box_data[row_idx * len + 1]) / + 2; + T target_box_width = target_box_data[row_idx * len + 2] - + target_box_data[row_idx * len] + (normalized == false); + T target_box_height = target_box_data[row_idx * len + 3] - + target_box_data[row_idx * len + 1] + + (normalized == false); + + output[idx * len] = + (target_box_center_x - prior_box_center_x) / prior_box_width; + output[idx * len + 1] = + (target_box_center_y - prior_box_center_y) / prior_box_height; + output[idx * len + 2] = log(fabs(target_box_width / prior_box_width)); + output[idx * len + 3] = log(fabs(target_box_height / prior_box_height)); + if (prior_box_var_data) { + int prior_var_offset = col_idx * len; + output[idx * len] /= prior_box_var_data[prior_var_offset]; + output[idx * len + 1] /= prior_box_var_data[prior_var_offset + 1]; + output[idx * len + 2] /= prior_box_var_data[prior_var_offset + 2]; + output[idx * len + 3] /= prior_box_var_data[prior_var_offset + 3]; + } else if (var_size == 4) { + for (int k = 0; k < 4; ++k) { + output[idx * len + k] /= static_cast(variance[k]); + } + } + } +} + +template +__global__ void DecodeCenterSizeKernel(const T *prior_box_data, + const T *prior_box_var_data, + const T *target_box_data, + const int row, + const int col, + const int len, + const bool normalized, + const T prior_box_var_size, + const float *variance, + const int var_size, + const int axis, + T *output) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + int prior_box_offset = 0; + if (idx < row * col) { + const int col_idx = idx % col; + const int row_idx = idx / col; + prior_box_offset = axis == 0 ? col_idx * len : row_idx * len; + T prior_box_width = prior_box_data[prior_box_offset + 2] - + prior_box_data[prior_box_offset] + + (normalized == false); + T prior_box_height = prior_box_data[prior_box_offset + 3] - + prior_box_data[prior_box_offset + 1] + + (normalized == false); + T prior_box_center_x = + prior_box_data[prior_box_offset] + prior_box_width / 2; + T prior_box_center_y = + prior_box_data[prior_box_offset + 1] + prior_box_height / 2; + T target_box_width, target_box_height; + T target_box_center_x, target_box_center_y; + T box_var_x = T(1), box_var_y = T(1); + T box_var_w = T(1), box_var_h = T(1); + if (prior_box_var_data) { + int prior_var_offset = axis == 0 ? col_idx * len : row_idx * len; + box_var_x = prior_box_var_data[prior_var_offset]; + box_var_y = prior_box_var_data[prior_var_offset + 1]; + box_var_w = prior_box_var_data[prior_var_offset + 2]; + box_var_h = prior_box_var_data[prior_var_offset + 3]; + } else if (var_size == 4) { + box_var_x = static_cast(variance[0]); + box_var_y = static_cast(variance[1]); + box_var_w = static_cast(variance[2]); + box_var_h = static_cast(variance[3]); + } + target_box_width = + exp(box_var_w * target_box_data[idx * len + 2]) * prior_box_width; + target_box_height = + exp(box_var_h * target_box_data[idx * len + 3]) * prior_box_height; + target_box_center_x = + box_var_x * target_box_data[idx * len] * prior_box_width + + prior_box_center_x; + target_box_center_y = + box_var_y * target_box_data[idx * len + 1] * prior_box_height + + prior_box_center_y; + + output[idx * len] = target_box_center_x - target_box_width / 2; + output[idx * len + 1] = target_box_center_y - target_box_height / 2; + output[idx * len + 2] = + target_box_center_x + target_box_width / 2 - (normalized == false); + output[idx * len + 3] = + target_box_center_y + target_box_height / 2 - (normalized == false); + } +} + +template +void BoxCoderKernel(const Context &dev_ctx, + const DenseTensor &prior_box, + const paddle::optional &prior_box_var, + const DenseTensor &target_box, + const std::string &code_type_str, + bool normalized, + int axis, + const std::vector &variance, + DenseTensor *output_box) { + const T *prior_box_data = prior_box.template data(); + const T *target_box_data = target_box.template data(); + const T *prior_box_var_data = nullptr; + auto prior_box_var_size = 0; + if (prior_box_var) { + PADDLE_ENFORCE_EQ(variance.empty(), + true, + phi::errors::InvalidArgument( + "Input 'PriorBoxVar' and attribute 'variance'" + " of BoxCoder operator should not be used at the " + "same time.")); + prior_box_var_data = prior_box_var->data(); + prior_box_var_size = prior_box_var->dims().size(); + } + if (!(variance.empty())) { + PADDLE_ENFORCE_EQ(static_cast(variance.size()), + 4, + phi::errors::InvalidArgument( + "Size of attribute 'variance' in BoxCoder operator" + " should be 4. But received size is %d", + variance.size())); + } + + if (target_box.lod().size()) { + PADDLE_ENFORCE_EQ(target_box.lod().size(), + 1, + phi::errors::InvalidArgument( + "Input 'TargetBox' of BoxCoder operator only" + " supports LoD with one level.")); + } + const int var_size = static_cast(variance.size()); + auto code_type = phi::funcs::GetBoxCodeType(code_type_str); + auto row = target_box.dims()[0]; + auto col = prior_box.dims()[0]; + if (code_type == phi::funcs::BoxCodeType::kDecodeCenterSize) { + col = target_box.dims()[1]; + } + auto len = prior_box.dims()[1]; + int block = 512; + int grid = (row * col + block - 1) / block; + + int bytes = var_size * sizeof(float); + auto dev_var = paddle::memory::Alloc(dev_ctx, bytes); + float *dev_var_data = reinterpret_cast(dev_var->ptr()); + auto cplace = phi::CPUPlace(); + const auto gplace = dev_ctx.GetPlace(); + paddle::memory::Copy( + gplace, dev_var_data, cplace, &variance[0], bytes, dev_ctx.stream()); + + output_box->Resize({row, col, len}); + dev_ctx.template Alloc(output_box); + T *output = output_box->data(); + + if (code_type == phi::funcs::BoxCodeType::kEncodeCenterSize) { + EncodeCenterSizeKernel + <<>>(prior_box_data, + prior_box_var_data, + target_box_data, + row, + col, + len, + normalized, + prior_box_var_size, + dev_var_data, + var_size, + output); + } else if (code_type == phi::funcs::BoxCodeType::kDecodeCenterSize) { + DecodeCenterSizeKernel + <<>>(prior_box_data, + prior_box_var_data, + target_box_data, + row, + col, + len, + normalized, + prior_box_var_size, + dev_var_data, + var_size, + axis, + output); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + box_coder, GPU, ALL_LAYOUT, phi::BoxCoderKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/box_coder.h b/paddle/phi/kernels/impl/box_coder.h new file mode 100644 index 0000000000..739293ef54 --- /dev/null +++ b/paddle/phi/kernels/impl/box_coder.h @@ -0,0 +1,43 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/phi/core/enforce.h" + +namespace phi { +namespace funcs { + +enum class BoxCodeType { kEncodeCenterSize = 0, kDecodeCenterSize = 1 }; + +inline BoxCodeType GetBoxCodeType(const std::string &type) { + PADDLE_ENFORCE_EQ( + (type == "encode_center_size") || (type == "decode_center_size"), + true, + phi::errors::InvalidArgument( + "The 'code_type' attribute in BoxCoder" + " must be 'encode_center_size' or 'decode_center_size'. " + "But received 'code_type' is %s", + type)); + if (type == "encode_center_size") { + return BoxCodeType::kEncodeCenterSize; + } else { + return BoxCodeType::kDecodeCenterSize; + } +} + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/ops/compat/box_coder_sig.cc b/paddle/phi/ops/compat/box_coder_sig.cc new file mode 100644 index 0000000000..5b674f3dcd --- /dev/null +++ b/paddle/phi/ops/compat/box_coder_sig.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature BoxCoderOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("box_coder", + {"PriorBox", "PriorBoxVar", "TargetBox"}, + {"code_type", "box_normalized", "axis", "variance"}, + {"OutputBox"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(box_coder, phi::BoxCoderOpArgumentMapping); diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index ddcc1db84b..f8691cc156 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -37,6 +37,7 @@ from functools import reduce from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype from paddle.utils import deprecated from paddle import _C_ops +from ..framework import in_dygraph_mode __all__ = [ 'prior_box', @@ -948,6 +949,22 @@ def box_coder(prior_box, 'box_coder') check_variable_and_dtype(target_box, 'target_box', ['float32', 'float64'], 'box_coder') + if in_dygraph_mode(): + if isinstance(prior_box_var, Variable): + box_coder_op = _C_ops.final_state_box_coder(prior_box, + prior_box_var, + target_box, code_type, + box_normalized, axis, + []) + elif isinstance(prior_box_var, list): + box_coder_op = _C_ops.final_state_box_coder(prior_box, None, + target_box, code_type, + box_normalized, axis, + prior_box_var) + else: + raise TypeError( + "Input variance of box_coder must be Variable or lisz") + return box_coder_op helper = LayerHelper("box_coder", **locals()) output_box = helper.create_variable_for_type_inference( diff --git a/python/paddle/fluid/tests/unittests/test_box_coder_op.py b/python/paddle/fluid/tests/unittests/test_box_coder_op.py index ee064963b2..4d18d0a2a1 100644 --- a/python/paddle/fluid/tests/unittests/test_box_coder_op.py +++ b/python/paddle/fluid/tests/unittests/test_box_coder_op.py @@ -19,6 +19,8 @@ import numpy as np import sys import math from op_test import OpTest +import paddle +import paddle.fluid.core as core def box_decoder(t_box, p_box, pb_v, output_box, norm, axis=0): @@ -105,10 +107,11 @@ def batch_box_coder(p_box, pb_v, t_box, lod, code_type, norm, axis=0): class TestBoxCoderOp(OpTest): def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def setUp(self): self.op_type = "box_coder" + self.python_api = paddle.fluid.layers.box_coder lod = [[1, 1, 1, 1, 1]] prior_box = np.random.random((81, 4)).astype('float32') prior_box_var = np.random.random((81, 4)).astype('float32') @@ -132,9 +135,10 @@ class TestBoxCoderOp(OpTest): class TestBoxCoderOpWithoutBoxVar(OpTest): def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def setUp(self): + self.python_api = paddle.fluid.layers.box_coder self.op_type = "box_coder" lod = [[0, 1, 2, 3, 4, 5]] prior_box = np.random.random((81, 4)).astype('float32') @@ -147,6 +151,7 @@ class TestBoxCoderOpWithoutBoxVar(OpTest): self.inputs = { 'PriorBox': prior_box, + 'PriorBoxVar': prior_box_var, 'TargetBox': target_box, } self.attrs = { @@ -159,9 +164,10 @@ class TestBoxCoderOpWithoutBoxVar(OpTest): class TestBoxCoderOpWithLoD(OpTest): def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def setUp(self): + self.python_api = paddle.fluid.layers.box_coder self.op_type = "box_coder" lod = [[10, 20, 20]] prior_box = np.random.random((20, 4)).astype('float32') @@ -184,9 +190,10 @@ class TestBoxCoderOpWithLoD(OpTest): class TestBoxCoderOpWithAxis(OpTest): def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def setUp(self): + self.python_api = paddle.fluid.layers.box_coder self.op_type = "box_coder" lod = [[1, 1, 1, 1, 1]] prior_box = np.random.random((30, 4)).astype('float32') @@ -241,5 +248,43 @@ class TestBoxCoderOpWithVariance(OpTest): self.outputs = {'OutputBox': output_box} +class TestBoxCoderOpWithVarianceDygraphAPI(unittest.TestCase): + + def setUp(self): + self.lod = [[1, 1, 1, 1, 1]] + self.prior_box = np.random.random((30, 4)).astype('float32') + self.prior_box_var = np.random.random((4)).astype('float32') + self.target_box = np.random.random((30, 81, 4)).astype('float32') + self.code_type = "DecodeCenterSize" + self.box_normalized = False + self.axis = 1 + self.output_ref = batch_box_coder(self.prior_box, self.prior_box_var, + self.target_box, self.lod[0], + self.code_type, self.box_normalized, + self.axis) + self.place = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def test_dygraph_api(self): + + def run(place): + paddle.disable_static(place) + output_box = paddle.fluid.layers.box_coder( + paddle.to_tensor(self.prior_box), + self.prior_box_var.tolist(), + paddle.to_tensor(self.target_box), + "decode_center_size", + self.box_normalized, + axis=self.axis) + self.assertEqual( + np.allclose(np.sum(self.output_ref), + np.sum(output_box.numpy())), True) + paddle.enable_static() + + for place in self.place: + run(place) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py b/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py index 452b0ac542..1c28393f33 100644 --- a/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py @@ -50,7 +50,8 @@ class TestL2LossOp(OpTest): def test_check_grad(self): self.check_grad(['X'], 'Out', - max_relative_error=self.max_relative_error) + max_relative_error=self.max_relative_error, + check_eager=True) class TestL2LossDeterministic(unittest.TestCase): -- GitLab