diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc
index a6f2e03f14fd02f14acf21f000db251216f8f380..03733e34ec670b3467c535eb887ea5995a630122 100644
--- a/paddle/fluid/operators/detection/prior_box_op.cc
+++ b/paddle/fluid/operators/detection/prior_box_op.cc
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/detection/prior_box_op.h"
-
 #include <string>
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/infermeta/binary.h"
 
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
@@ -28,79 +29,6 @@ class PriorBoxOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "PriorBoxOp");
-    OP_INOUT_CHECK(ctx->HasInput("Image"), "Input", "Image", "PriorBoxOp");
-
-    auto image_dims = ctx->GetInputDim("Image");
-    auto input_dims = ctx->GetInputDim("Input");
-
-    PADDLE_ENFORCE_EQ(
-        image_dims.size(),
-        4,
-        platform::errors::InvalidArgument(
-            "The Input(Image) of Op(PriorBoxOp) should be a 4-D Tensor "
-            "and data format is NCHW. But received Image's dimensions = %d, "
-            "shape = [%s].",
-            image_dims.size(),
-            image_dims));
-    PADDLE_ENFORCE_EQ(
-        input_dims.size(),
-        4,
-        platform::errors::InvalidArgument(
-            "The Input(Input) of Op(PriorBoxOp) should be a 4-D Tensor "
-            "and data format is NCHW. But received Input's dimensions = %d, "
-            "shape = [%s].",
-            input_dims.size(),
-            input_dims));
-
-    auto min_sizes = ctx->Attrs().Get<std::vector<float>>("min_sizes");
-    auto max_sizes = ctx->Attrs().Get<std::vector<float>>("max_sizes");
-    auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
-    auto aspect_ratios = ctx->Attrs().Get<std::vector<float>>("aspect_ratios");
-    bool flip = ctx->Attrs().Get<bool>("flip");
-
-    std::vector<float> aspect_ratios_vec;
-    ExpandAspectRatios(aspect_ratios, flip, &aspect_ratios_vec);
-
-    size_t num_priors = aspect_ratios_vec.size() * min_sizes.size();
-    if (max_sizes.size() > 0) {
-      PADDLE_ENFORCE_EQ(
-          max_sizes.size(),
-          min_sizes.size(),
-          platform::errors::InvalidArgument(
-              "The length of min_size and "
-              "max_size must be equal. But received: min_size's length is %d, "
-              "max_size's length is %d.",
-              min_sizes.size(),
-              max_sizes.size()));
-      num_priors += max_sizes.size();
-      for (size_t i = 0; i < max_sizes.size(); ++i) {
-        PADDLE_ENFORCE_GT(
-            max_sizes[i],
-            min_sizes[i],
-            platform::errors::InvalidArgument(
-                "max_size[%d] must be greater "
-                "than min_size[%d]. But received: max_size[%d] is %f, "
-                "min_size[%d] is %f.",
-                i,
-                i,
-                i,
-                max_sizes[i],
-                i,
-                min_sizes[i]));
-      }
-    }
-
-    std::vector<int64_t> dim_vec(4);
-    dim_vec[0] = input_dims[2];
-    dim_vec[1] = input_dims[3];
-    dim_vec[2] = num_priors;
-    dim_vec[3] = 4;
-    ctx->SetOutputDim("Boxes", phi::make_ddim(dim_vec));
-    ctx->SetOutputDim("Variances", phi::make_ddim(dim_vec));
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -274,17 +202,18 @@ https://arxiv.org/abs/1512.02325.
 }  // namespace operators
 }  // namespace paddle
 
+DECLARE_INFER_SHAPE_FUNCTOR(prior_box,
+                            PriorBoxInferShapeFunctor,
+                            PD_INFER_META(phi::PriorBoxInferMeta));
+
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(
     prior_box,
     ops::PriorBoxOp,
     ops::PriorBoxOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-
-REGISTER_OP_CPU_KERNEL(prior_box,
-                       ops::PriorBoxOpKernel<float>,
-                       ops::PriorBoxOpKernel<double>);
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    PriorBoxInferShapeFunctor);
 
 REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box,
                                     MKLDNN,
diff --git a/paddle/fluid/operators/detection/prior_box_op.cu b/paddle/fluid/operators/detection/prior_box_op.cu
index 5d24322f12c30b13e24821c75db7b3abc284553f..1808806714774f7444c687ef08d9ee7f0e15cc2e 100644
--- a/paddle/fluid/operators/detection/prior_box_op.cu
+++ b/paddle/fluid/operators/detection/prior_box_op.cu
@@ -194,8 +194,3 @@ class PriorBoxOpCUDAKernel : public framework::OpKernel<T> {
 
 }  // namespace operators
 }  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(prior_box,
-                        ops::PriorBoxOpCUDAKernel<float>,
-                        ops::PriorBoxOpCUDAKernel<double>);
diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py
index 9d5b630bcd498b0ff781570a0adc3c04ac09da62..88903763d805d79f409302176fc1e3a893ce6f81 100644
--- a/paddle/phi/api/yaml/generator/api_base.py
+++ b/paddle/phi/api/yaml/generator/api_base.py
@@ -141,7 +141,7 @@ class BaseAPI(object):
             'DataLayout': 'DataLayout',
             'DataType': 'DataType',
             'int64_t[]': 'const std::vector<int64_t>&',
-            'int[]': 'const std::vector<int>&'
+            'int[]': 'const std::vector<int>&',
         }
         optional_types_trans = {
             'Tensor': 'const paddle::optional<Tensor>&',
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index ff80469eedaf945e3327e2325ec9614a7b12f8b4..ee1639c888b7e8a44e87d7203c5951abdeebe934 100644
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -1791,6 +1791,14 @@
     func : prelu
   backward : prelu_grad
 
+- api : prior_box
+  args : (Tensor input, Tensor image, float[] min_sizes, float[] aspect_ratios, float[] variances, float[] max_sizes = {}, bool flip=true, bool clip=true, float step_w=0.0, float step_h=0.0, float offset=0.5, bool min_max_aspect_ratios_order=false)
+  output : Tensor(out), Tensor(var)
+  infer_meta :
+    func : PriorBoxInferMeta
+  kernel :
+    func : prior_box
+
 - api : psroi_pool
   args : (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, int output_channels, float spatial_scale)
   output : Tensor
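Editor's note: the shape rule that the removed `PriorBoxOp::InferShape` enforced, and that `phi::PriorBoxInferMeta` re-implements in the next file, can be summarized in a small self-contained Python sketch. The helper name and example numbers below are illustrative only, not part of the patch.

```python
def prior_box_output_shape(input_shape, num_expanded_ratios, num_min_sizes,
                           num_max_sizes):
    """Both Boxes and Variances get the shape [H, W, num_priors, 4]."""
    num_priors = num_expanded_ratios * num_min_sizes + num_max_sizes
    h, w = input_shape[2], input_shape[3]  # feature map is NCHW
    return [h, w, num_priors, 4]

# Example: a 6x9 feature map, min_sizes=[2.], max_sizes=[5.], and
# aspect_ratios=[2.] with flip=True expand to the ratio list [1., 2., 0.5],
# so num_priors = 3 * 1 + 1 = 4 and both outputs are [6, 9, 4, 4].
print(prior_box_output_shape([2, 32, 6, 9], 3, 1, 1))
```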
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 4d72e1b60d6ee10ab57ab8a64140ab411c63340b..5c536cb3546e6c80db52e719dab24d93d55beb57 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -1809,6 +1809,110 @@ void PReluInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
+                               bool flip,
+                               std::vector<float>* output_aspect_ratior) {
+  constexpr float epsilon = 1e-6;
+  output_aspect_ratior->clear();
+  output_aspect_ratior->push_back(1.0f);
+  for (size_t i = 0; i < input_aspect_ratior.size(); ++i) {
+    float ar = input_aspect_ratior[i];
+    bool already_exist = false;
+    for (size_t j = 0; j < output_aspect_ratior->size(); ++j) {
+      if (fabs(ar - output_aspect_ratior->at(j)) < epsilon) {
+        already_exist = true;
+        break;
+      }
+    }
+    if (!already_exist) {
+      output_aspect_ratior->push_back(ar);
+      if (flip) {
+        output_aspect_ratior->push_back(1.0f / ar);
+      }
+    }
+  }
+}
+
+void PriorBoxInferMeta(const MetaTensor& input,
+                       const MetaTensor& image,
+                       const std::vector<float>& min_sizes,
+                       const std::vector<float>& aspect_ratios,
+                       const std::vector<float>& variances,
+                       const std::vector<float>& max_sizes,
+                       bool flip,
+                       bool clip,
+                       float step_w,
+                       float step_h,
+                       float offset,
+                       bool min_max_aspect_ratios_order,
+                       MetaTensor* out,
+                       MetaTensor* var) {
+  auto image_dims = image.dims();
+  auto input_dims = input.dims();
+
+  PADDLE_ENFORCE_EQ(
+      image_dims.size(),
+      4,
+      phi::errors::InvalidArgument(
+          "The Input(Image) of Op(PriorBoxOp) should be a 4-D Tensor "
+          "and data format is NCHW. But received Image's dimensions = %d, "
+          "shape = [%s].",
+          image_dims.size(),
+          image_dims));
+  PADDLE_ENFORCE_EQ(
+      input_dims.size(),
+      4,
+      phi::errors::InvalidArgument(
+          "The Input(Input) of Op(PriorBoxOp) should be a 4-D Tensor "
+          "and data format is NCHW. But received Input's dimensions = %d, "
+          "shape = [%s].",
+          input_dims.size(),
+          input_dims));
+
+  std::vector<float> aspect_ratios_vec;
+  ExpandAspectRatios(aspect_ratios, flip, &aspect_ratios_vec);
+
+  size_t num_priors = aspect_ratios_vec.size() * min_sizes.size();
+  if (max_sizes.size() > 0) {
+    PADDLE_ENFORCE_EQ(
+        max_sizes.size(),
+        min_sizes.size(),
+        phi::errors::InvalidArgument(
+            "The length of min_size and "
+            "max_size must be equal. But received: min_size's length is %d, "
+            "max_size's length is %d.",
+            min_sizes.size(),
+            max_sizes.size()));
+    num_priors += max_sizes.size();
+    for (size_t i = 0; i < max_sizes.size(); ++i) {
+      PADDLE_ENFORCE_GT(
+          max_sizes[i],
+          min_sizes[i],
+          phi::errors::InvalidArgument(
+              "max_size[%d] must be greater "
+              "than min_size[%d]. But received: max_size[%d] is %f, "
+              "min_size[%d] is %f.",
+              i,
+              i,
+              i,
+              max_sizes[i],
+              i,
+              min_sizes[i]));
+    }
+  }
+
+  std::vector<int64_t> dim_vec(4);
+  dim_vec[0] = input_dims[2];
+  dim_vec[1] = input_dims[3];
+  dim_vec[2] = num_priors;
+  dim_vec[3] = 4;
+
+  out->set_dtype(input.dtype());
+  var->set_dtype(input.dtype());
+  out->set_dims(phi::make_ddim(dim_vec));
+  var->set_dims(phi::make_ddim(dim_vec));
+}
+
 void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
                            const MetaTensor& value,
                            bool out_int32,
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 53d6c12e8862fb3a76e0184fba72ec8b545a2e09..aaadce9e1eec9933e3c904aa153688332ea0372b 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -256,6 +256,21 @@ void PReluInferMeta(const MetaTensor& x,
                     MetaTensor* out,
                     MetaConfig config = MetaConfig());
 
+void PriorBoxInferMeta(const MetaTensor& input,
+                       const MetaTensor& image,
+                       const std::vector<float>& min_sizes,
+                       const std::vector<float>& aspect_ratios,
+                       const std::vector<float>& variances,
+                       const std::vector<float>& max_sizes,
+                       bool flip,
+                       bool clip,
+                       float step_w,
+                       float step_h,
+                       float offset,
+                       bool min_max_aspect_ratios_order,
+                       MetaTensor* out,
+                       MetaTensor* var);
+
 void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
                            const MetaTensor& value,
                            bool out_int32,
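Editor's note: `ExpandAspectRatios` (added above in `binary.cc`, and again in `prior_box_kernel.h` further down) de-duplicates the ratio list and, when `flip` is set, appends the reciprocal of each kept ratio. A minimal Python mirror of that logic, for reference only:

```python
def expand_aspect_ratios(aspect_ratios, flip, epsilon=1e-6):
    """Python mirror of phi::ExpandAspectRatios: start from 1.0, skip
    near-duplicates, and add 1/ar for each kept ratio when flip is set."""
    expanded = [1.0]
    for ar in aspect_ratios:
        if any(abs(ar - e) < epsilon for e in expanded):
            continue
        expanded.append(ar)
        if flip:
            expanded.append(1.0 / ar)
    return expanded

assert expand_aspect_ratios([1.0, 2.0], flip=True) == [1.0, 2.0, 0.5]
```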
diff --git a/paddle/phi/kernels/cpu/prior_box_kernel.cc b/paddle/phi/kernels/cpu/prior_box_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..018b18006d3a7039066aea3ecb3f413a90f9226b
--- /dev/null
+++ b/paddle/phi/kernels/cpu/prior_box_kernel.cc
@@ -0,0 +1,173 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/prior_box_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PriorBoxKernel(const Context& ctx,
+                    const DenseTensor& input,
+                    const DenseTensor& image,
+                    const std::vector<float>& min_sizes,
+                    const std::vector<float>& aspect_ratios,
+                    const std::vector<float>& variances,
+                    const std::vector<float>& max_sizes,
+                    bool flip,
+                    bool clip,
+                    float step_w,
+                    float step_h,
+                    float offset,
+                    bool min_max_aspect_ratios_order,
+                    DenseTensor* out,
+                    DenseTensor* var) {
+  std::vector<float> new_aspect_ratios;
+  ExpandAspectRatios(aspect_ratios, flip, &new_aspect_ratios);
+
+  T new_step_w = static_cast<T>(step_w);
+  T new_step_h = static_cast<T>(step_h);
+  T new_offset = static_cast<T>(offset);
+
+  auto img_width = image.dims()[3];
+  auto img_height = image.dims()[2];
+
+  auto feature_width = input.dims()[3];
+  auto feature_height = input.dims()[2];
+
+  T step_width, step_height;
+  if (new_step_w == 0 || new_step_h == 0) {
+    step_width = static_cast<T>(img_width) / feature_width;
+    step_height = static_cast<T>(img_height) / feature_height;
+  } else {
+    step_width = new_step_w;
+    step_height = new_step_h;
+  }
+
+  int num_priors = new_aspect_ratios.size() * min_sizes.size();
+  if (max_sizes.size() > 0) {
+    num_priors += max_sizes.size();
+  }
+
+  ctx.template Alloc<T>(out);
+  ctx.template Alloc<T>(var);
+
+  T* b_t = out->data<T>();
+  for (int h = 0; h < feature_height; ++h) {
+    for (int w = 0; w < feature_width; ++w) {
+      T center_x = (w + new_offset) * step_width;
+      T center_y = (h + new_offset) * step_height;
+      T box_width, box_height;
+      for (size_t s = 0; s < min_sizes.size(); ++s) {
+        auto min_size = min_sizes[s];
+        if (min_max_aspect_ratios_order) {
+          box_width = box_height = min_size / 2.;
+          b_t[0] = (center_x - box_width) / img_width;
+          b_t[1] = (center_y - box_height) / img_height;
+          b_t[2] = (center_x + box_width) / img_width;
+          b_t[3] = (center_y + box_height) / img_height;
+          b_t += 4;
+          if (max_sizes.size() > 0) {
+            auto max_size = max_sizes[s];
+            // square prior with size sqrt(minSize * maxSize)
+            box_width = box_height = sqrt(min_size * max_size) / 2.;
+            b_t[0] = (center_x - box_width) / img_width;
+            b_t[1] = (center_y - box_height) / img_height;
+            b_t[2] = (center_x + box_width) / img_width;
+            b_t[3] = (center_y + box_height) / img_height;
+            b_t += 4;
+          }
+          // priors with different aspect ratios
+          for (size_t r = 0; r < new_aspect_ratios.size(); ++r) {
+            float ar = new_aspect_ratios[r];
+            if (fabs(ar - 1.) < 1e-6) {
+              continue;
+            }
+            box_width = min_size * sqrt(ar) / 2.;
+            box_height = min_size / sqrt(ar) / 2.;
+            b_t[0] = (center_x - box_width) / img_width;
+            b_t[1] = (center_y - box_height) / img_height;
+            b_t[2] = (center_x + box_width) / img_width;
+            b_t[3] = (center_y + box_height) / img_height;
+            b_t += 4;
+          }
+        } else {
+          // priors with different aspect ratios
+          for (size_t r = 0; r < new_aspect_ratios.size(); ++r) {
+            float ar = new_aspect_ratios[r];
+            box_width = min_size * sqrt(ar) / 2.;
+            box_height = min_size / sqrt(ar) / 2.;
+            b_t[0] = (center_x - box_width) / img_width;
+            b_t[1] = (center_y - box_height) / img_height;
+            b_t[2] = (center_x + box_width) / img_width;
+            b_t[3] = (center_y + box_height) / img_height;
+            b_t += 4;
+          }
+          if (max_sizes.size() > 0) {
+            auto max_size = max_sizes[s];
+            // square prior with size sqrt(minSize * maxSize)
+            box_width = box_height = sqrt(min_size * max_size) / 2.;
+            b_t[0] = (center_x - box_width) / img_width;
+            b_t[1] = (center_y - box_height) / img_height;
+            b_t[2] = (center_x + box_width) / img_width;
+            b_t[3] = (center_y + box_height) / img_height;
+            b_t += 4;
+          }
+        }
+      }
+    }
+  }
+
+  if (clip) {
+    T* dt = out->data<T>();
+    std::transform(dt, dt + out->numel(), dt, [](T v) -> T {
+      return std::min<T>(std::max<T>(v, 0.), 1.);
+    });
+  }
+
+  DenseTensor var_t;
+  var_t.Resize(phi::make_ddim({1, static_cast<int>(variances.size())}));
+  ctx.template Alloc<T>(&var_t);
+  auto var_et = EigenTensor<T, 2>::From(var_t);
+
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (size_t i = 0; i < variances.size(); ++i) {
+    var_et(0, i) = variances[i];
+  }
+
+  int box_num = feature_height * feature_width * num_priors;
+  auto var_dim = var->dims();
+  var->Resize({box_num, static_cast<int>(variances.size())});
+
+  auto e_vars = EigenMatrix<T>::From(*var);
+
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for collapse(2)
+#endif
+  for (int i = 0; i < box_num; ++i) {
+    for (size_t j = 0; j < variances.size(); ++j) {
+      e_vars(i, j) = variances[j];
+    }
+  }
+  var->Resize(var_dim);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    prior_box, CPU, ALL_LAYOUT, phi::PriorBoxKernel, float, double) {}
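Editor's note: the `min_max_aspect_ratios_order` branch in the CPU kernel above only changes the order in which the per-location priors are written, not their values or count. A small Python sketch of the emission order (half-extents only), assuming a single min/max pair per location; the helper name is illustrative:

```python
import math

def prior_half_extents(min_size, max_size, ratios, min_max_order):
    """Order in which the CPU kernel emits (half_w, half_h) per location."""
    ar_boxes = [(min_size * math.sqrt(ar) / 2.0, min_size / math.sqrt(ar) / 2.0)
                for ar in ratios
                if not (min_max_order and abs(ar - 1.0) < 1e-6)]
    square_min = (min_size / 2.0, min_size / 2.0)
    square_max = None
    if max_size is not None:
        s = math.sqrt(min_size * max_size) / 2.0
        square_max = (s, s)
    if min_max_order:
        # min-size square, then sqrt(min*max) square, then the non-1.0 ratios
        return [square_min] + ([square_max] if square_max else []) + ar_boxes
    # all expanded ratios (1.0 included) first, then the sqrt(min*max) square
    return ar_boxes + ([square_max] if square_max else [])
```

Either ordering yields the same number of priors per location, which is why the shape logic in `PriorBoxInferMeta` does not depend on the flag.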
diff --git a/paddle/phi/kernels/gpu/prior_box_kernel.cu b/paddle/phi/kernels/gpu/prior_box_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..317f2a3231a642f5c2f1de64519c51ecc6b4c87a
--- /dev/null
+++ b/paddle/phi/kernels/gpu/prior_box_kernel.cu
@@ -0,0 +1,201 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/prior_box_kernel.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T>
+__device__ inline T clip(T in) {
+  return min(max(in, 0.), 1.);
+}
+
+template <typename T>
+__global__ void GenPriorBox(T* out,
+                            const T* aspect_ratios,
+                            const int height,
+                            const int width,
+                            const int im_height,
+                            const int im_width,
+                            const int as_num,
+                            const T offset,
+                            const T step_width,
+                            const T step_height,
+                            const T* min_sizes,
+                            const T* max_sizes,
+                            const int min_num,
+                            bool is_clip,
+                            bool min_max_aspect_ratios_order) {
+  int num_priors = max_sizes ? as_num * min_num + min_num : as_num * min_num;
+  int box_num = height * width * num_priors;
+  CUDA_KERNEL_LOOP(i, box_num) {
+    int h = i / (num_priors * width);
+    int w = (i / num_priors) % width;
+    int p = i % num_priors;
+    int m = max_sizes ? p / (as_num + 1) : p / as_num;
+    T cx = (w + offset) * step_width;
+    T cy = (h + offset) * step_height;
+    T bw, bh;
+    T min_size = min_sizes[m];
+    if (max_sizes) {
+      int s = p % (as_num + 1);
+      if (!min_max_aspect_ratios_order) {
+        if (s < as_num) {
+          T ar = aspect_ratios[s];
+          bw = min_size * sqrt(ar) / 2.;
+          bh = min_size / sqrt(ar) / 2.;
+        } else {
+          T max_size = max_sizes[m];
+          bw = sqrt(min_size * max_size) / 2.;
+          bh = bw;
+        }
+      } else {
+        if (s == 0) {
+          bw = bh = min_size / 2.;
+        } else if (s == 1) {
+          T max_size = max_sizes[m];
+          bw = sqrt(min_size * max_size) / 2.;
+          bh = bw;
+        } else {
+          T ar = aspect_ratios[s - 1];
+          bw = min_size * sqrt(ar) / 2.;
+          bh = min_size / sqrt(ar) / 2.;
+        }
+      }
+    } else {
+      int s = p % as_num;
+      T ar = aspect_ratios[s];
+      bw = min_size * sqrt(ar) / 2.;
+      bh = min_size / sqrt(ar) / 2.;
+    }
+    T xmin = (cx - bw) / im_width;
+    T ymin = (cy - bh) / im_height;
+    T xmax = (cx + bw) / im_width;
+    T ymax = (cy + bh) / im_height;
+    out[i * 4] = is_clip ? clip<T>(xmin) : xmin;
+    out[i * 4 + 1] = is_clip ? clip<T>(ymin) : ymin;
+    out[i * 4 + 2] = is_clip ? clip<T>(xmax) : xmax;
+    out[i * 4 + 3] = is_clip ? clip<T>(ymax) : ymax;
+  }
+}
+
+template <typename T>
+__global__ void SetVariance(T* out,
+                            const T* var,
+                            const int vnum,
+                            const int num) {
+  CUDA_KERNEL_LOOP(i, num) { out[i] = var[i % vnum]; }
+}
+
+template <typename T, typename Context>
+void PriorBoxKernel(const Context& ctx,
+                    const DenseTensor& input,
+                    const DenseTensor& image,
+                    const std::vector<float>& min_sizes,
+                    const std::vector<float>& aspect_ratios,
+                    const std::vector<float>& variances,
+                    const std::vector<float>& max_sizes,
+                    bool flip,
+                    bool clip,
+                    float step_w,
+                    float step_h,
+                    float offset,
+                    bool min_max_aspect_ratios_order,
+                    DenseTensor* out,
+                    DenseTensor* var) {
+  std::vector<float> new_aspect_ratios;
+  ExpandAspectRatios(aspect_ratios, flip, &new_aspect_ratios);
+
+  T new_step_w = static_cast<T>(step_w);
+  T new_step_h = static_cast<T>(step_h);
+  T new_offset = static_cast<T>(offset);
+
+  auto im_width = image.dims()[3];
+  auto im_height = image.dims()[2];
+
+  auto width = input.dims()[3];
+  auto height = input.dims()[2];
+
+  T step_width, step_height;
+  if (new_step_w == 0 || new_step_h == 0) {
+    step_width = static_cast<T>(im_width) / width;
+    step_height = static_cast<T>(im_height) / height;
+  } else {
+    step_width = new_step_w;
+    step_height = new_step_h;
+  }
+
+  int num_priors = new_aspect_ratios.size() * min_sizes.size();
+  if (max_sizes.size() > 0) {
+    num_priors += max_sizes.size();
+  }
+  int min_num = static_cast<int>(min_sizes.size());
+  int box_num = width * height * num_priors;
+
+  int block = 512;
+  int grid = (box_num + block - 1) / block;
+
+  auto stream = ctx.stream();
+
+  ctx.template Alloc<T>(out);
+  ctx.template Alloc<T>(var);
+
+  DenseTensor r;
+  paddle::framework::TensorFromVector(new_aspect_ratios, ctx, &r);
+
+  DenseTensor min;
+  paddle::framework::TensorFromVector(min_sizes, ctx, &min);
+
+  T* max_data = nullptr;
+  DenseTensor max;
+  if (max_sizes.size() > 0) {
+    paddle::framework::TensorFromVector(max_sizes, ctx, &max);
+    max_data = max.data<T>();
+  }
+
+  GenPriorBox<T><<<grid, block, 0, stream>>>(out->data<T>(),
+                                             r.data<T>(),
+                                             height,
+                                             width,
+                                             im_height,
+                                             im_width,
+                                             new_aspect_ratios.size(),
+                                             new_offset,
+                                             step_width,
+                                             step_height,
+                                             min.data<T>(),
+                                             max_data,
+                                             min_num,
+                                             clip,
+                                             min_max_aspect_ratios_order);
+
+  DenseTensor v;
+  paddle::framework::TensorFromVector(variances, ctx, &v);
+  grid = (box_num * 4 + block - 1) / block;
+  SetVariance<T><<<grid, block, 0, stream>>>(
+      var->data<T>(), v.data<T>(), variances.size(), box_num * 4);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    prior_box, GPU, ALL_LAYOUT, phi::PriorBoxKernel, float, double) {}
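Editor's note: `GenPriorBox` above flattens the (h, w, prior) loops into one thread index and launches ceil(box_num / 512) blocks of 512 threads; `SetVariance` does the same over box_num * 4 variance elements. The index decomposition performed per thread, as a quick Python reference (helper names are illustrative):

```python
def decode_prior_index(i, width, num_priors):
    """Mirror of the index math inside GenPriorBox's CUDA_KERNEL_LOOP."""
    h = i // (num_priors * width)
    w = (i // num_priors) % width
    p = i % num_priors
    return h, w, p

def launch_config(num_elements, block=512):
    """Grid sizing used for both GenPriorBox and SetVariance."""
    return (num_elements + block - 1) // block, block

# e.g. a 6x9 feature map with 4 priors per location -> 216 boxes, 1 block
print(decode_prior_index(215, width=9, num_priors=4), launch_config(216))
```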
diff --git a/paddle/phi/kernels/prior_box_kernel.h b/paddle/phi/kernels/prior_box_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..7a25b7d8e6d4635efca70145cb5cead4139e424a
--- /dev/null
+++ b/paddle/phi/kernels/prior_box_kernel.h
@@ -0,0 +1,62 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PriorBoxKernel(const Context& ctx,
+                    const DenseTensor& input,
+                    const DenseTensor& image,
+                    const std::vector<float>& min_sizes,
+                    const std::vector<float>& aspect_ratios,
+                    const std::vector<float>& variances,
+                    const std::vector<float>& max_sizes,
+                    bool flip,
+                    bool clip,
+                    float step_w,
+                    float step_h,
+                    float offset,
+                    bool min_max_aspect_ratios_order,
+                    DenseTensor* out,
+                    DenseTensor* var);
+
+inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
+                               bool flip,
+                               std::vector<float>* output_aspect_ratior) {
+  constexpr float epsilon = 1e-6;
+  output_aspect_ratior->clear();
+  output_aspect_ratior->push_back(1.0f);
+  for (size_t i = 0; i < input_aspect_ratior.size(); ++i) {
+    float ar = input_aspect_ratior[i];
+    bool already_exist = false;
+    for (size_t j = 0; j < output_aspect_ratior->size(); ++j) {
+      if (fabs(ar - output_aspect_ratior->at(j)) < epsilon) {
+        already_exist = true;
+        break;
+      }
+    }
+    if (!already_exist) {
+      output_aspect_ratior->push_back(ar);
+      if (flip) {
+        output_aspect_ratior->push_back(1.0f / ar);
+      }
+    }
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/prior_box_sig.cc b/paddle/phi/ops/compat/prior_box_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5d4cd5164305f84867abd3edff7c0351dd180f6a
--- /dev/null
+++ b/paddle/phi/ops/compat/prior_box_sig.cc
@@ -0,0 +1,37 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature PriorBoxOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("prior_box",
+                         {"Input", "Image"},
+                         {"min_sizes",
+                          "aspect_ratios",
+                          "variances",
+                          "max_sizes",
+                          "flip",
+                          "clip",
+                          "step_w",
+                          "step_h",
+                          "offset",
+                          "min_max_aspect_ratios_order"},
+                         {"Boxes", "Variances"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(prior_box, phi::PriorBoxOpArgumentMapping);
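Editor's note: the argument mapping above is what ties the legacy fluid `prior_box` op to the new phi kernel; expressed as plain data (copied from `prior_box_sig.cc`, shown here only for easy scanning, the constant name is illustrative):

```python
PRIOR_BOX_SIGNATURE = {
    "kernel": "prior_box",
    "inputs": ["Input", "Image"],
    "attrs": [
        "min_sizes", "aspect_ratios", "variances", "max_sizes",
        "flip", "clip", "step_w", "step_h", "offset",
        "min_max_aspect_ratios_order",
    ],
    "outputs": ["Boxes", "Variances"],
}
```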
diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py
index f8691cc156f940e4dd81cf8058f7043896e88a48..3540f69c049739f1fe3c2c48abdac6e8db88a713 100644
--- a/python/paddle/fluid/layers/detection.py
+++ b/python/paddle/fluid/layers/detection.py
@@ -22,7 +22,7 @@ import paddle
 from .layer_function_generator import generate_layer_fn
 from .layer_function_generator import autodoc, templatedoc
 from ..layer_helper import LayerHelper
-from ..framework import Variable, _non_static_mode, static_only
+from ..framework import Variable, _non_static_mode, static_only, in_dygraph_mode
 from .. import core
 from .loss import softmax_with_cross_entropy
 from . import tensor
@@ -1794,18 +1794,20 @@ def ssd_loss(location,
     return loss
 
 
-def prior_box(input,
-              image,
-              min_sizes,
-              max_sizes=None,
-              aspect_ratios=[1.],
-              variance=[0.1, 0.1, 0.2, 0.2],
-              flip=False,
-              clip=False,
-              steps=[0.0, 0.0],
-              offset=0.5,
-              name=None,
-              min_max_aspect_ratios_order=False):
+def prior_box(
+    input,
+    image,
+    min_sizes,
+    max_sizes=None,
+    aspect_ratios=[1.],
+    variance=[0.1, 0.1, 0.2, 0.2],
+    flip=False,
+    clip=False,
+    steps=[0.0, 0.0],
+    offset=0.5,
+    name=None,
+    min_max_aspect_ratios_order=False,
+):
     """
 
     This op generates prior boxes for SSD(Single Shot MultiBox Detector) algorithm.
@@ -1905,6 +1907,15 @@ def prior_box(input,
       #    [6L, 9L, 1L, 4L]
 
     """
+
+    if in_dygraph_mode():
+        step_w, step_h = steps
+        if max_sizes == None:
+            max_sizes = []
+        return _C_ops.final_state_prior_box(input, image, min_sizes,
+                                            aspect_ratios, variance, max_sizes,
+                                            flip, clip, step_w, step_h, offset,
+                                            min_max_aspect_ratios_order)
     helper = LayerHelper("prior_box", **locals())
     dtype = helper.input_dtype()
     check_variable_and_dtype(input, 'input',
diff --git a/python/paddle/fluid/tests/unittests/test_prior_box_op.py b/python/paddle/fluid/tests/unittests/test_prior_box_op.py
index b0aaaec246f6763715b796a6c01133fd7feb339c..0b57e8d00f761ea7d207afd9b60dc250b927ea8d 100644
--- a/python/paddle/fluid/tests/unittests/test_prior_box_op.py
+++ b/python/paddle/fluid/tests/unittests/test_prior_box_op.py
@@ -19,6 +19,35 @@ import numpy as np
 import sys
 import math
 from op_test import OpTest
+import paddle
+
+
+def python_prior_box(input,
+                     image,
+                     min_sizes,
+                     aspect_ratios=[1.],
+                     variances=[0.1, 0.1, 0.2, 0.2],
+                     max_sizes=None,
+                     flip=False,
+                     clip=False,
+                     step_w=0,
+                     step_h=0,
+                     offset=0.5,
+                     min_max_aspect_ratios_order=False,
+                     name=None):
+    return paddle.fluid.layers.detection.prior_box(
+        input,
+        image,
+        min_sizes=min_sizes,
+        max_sizes=max_sizes,
+        aspect_ratios=aspect_ratios,
+        variance=variances,
+        flip=flip,
+        clip=clip,
+        steps=[step_w, step_h],
+        offset=offset,
+        name=name,
+        min_max_aspect_ratios_order=min_max_aspect_ratios_order)
 
 
 class TestPriorBoxOp(OpTest):
@@ -35,10 +64,10 @@ class TestPriorBoxOp(OpTest):
             'variances': self.variances,
             'flip': self.flip,
             'clip': self.clip,
-            'min_max_aspect_ratios_order': self.min_max_aspect_ratios_order,
             'step_w': self.step_w,
             'step_h': self.step_h,
-            'offset': self.offset
+            'offset': self.offset,
+            'min_max_aspect_ratios_order': self.min_max_aspect_ratios_order,
         }
         if len(self.max_sizes) > 0:
             self.attrs['max_sizes'] = self.max_sizes
@@ -46,10 +75,11 @@ class TestPriorBoxOp(OpTest):
         self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def setUp(self):
         self.op_type = "prior_box"
+        self.python_api = python_prior_box
         self.set_data()
 
     def set_max_sizes(self):
@@ -191,4 +221,5 @@ class TestPriorBoxOpWithSpecifiedOutOrder(TestPriorBoxOp):
 
 
 if __name__ == '__main__':
+    paddle.enable_static()
    unittest.main()
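Editor's note: end to end, the Python entry point is unchanged by this patch; only the dygraph branch now dispatches to the phi kernel. A minimal static-graph usage sketch, assuming the tensor names and shapes below (they are adapted from the layer's docstring and are illustrative only):

```python
import paddle
import paddle.fluid as fluid

paddle.enable_static()

# NCHW feature map and image; the leading batch dimension is symbolic.
input = fluid.data(name="conv_feat", shape=[None, 32, 6, 9], dtype="float32")
image = fluid.data(name="images", shape=[None, 3, 300, 300], dtype="float32")

boxes, variances = fluid.layers.prior_box(
    input=input, image=image, min_sizes=[100.0], clip=True, flip=True)

# With the default aspect_ratios=[1.] the expanded ratio list stays [1.],
# so num_priors = 1 and both outputs have shape [6, 9, 1, 4],
# matching the [6L, 9L, 1L, 4L] noted in the docstring above.
print(boxes.shape, variances.shape)
```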