diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 35c41e0dc93da0a367b8e98b4e4a4882bcea0822..233509fd9f94e07f9dab35b01356a207898902a4 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -46,7 +46,7 @@ if(WITH_XPU) elseif(WITH_MLU) detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_mlu.cc) - detection_library(prior_box_op SRCS prior_box_op.cc) + detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op_mlu.cc) detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op_mlu.cc) elseif(WITH_ASCEND_CL) detection_library(iou_similarity_op SRCS iou_similarity_op.cc diff --git a/paddle/fluid/operators/detection/prior_box_op_mlu.cc b/paddle/fluid/operators/detection/prior_box_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..04402f6ae200a18988737c3a172a9270751b44e0 --- /dev/null +++ b/paddle/fluid/operators/detection/prior_box_op_mlu.cc @@ -0,0 +1,104 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/detection/prior_box_op.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { + +template +class PriorBoxMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* image = ctx.Input("Image"); + auto* boxes = ctx.Output("Boxes"); + auto* variances = ctx.Output("Variances"); + float step_w = ctx.Attr("step_w"); + float step_h = ctx.Attr("step_h"); + float offset = ctx.Attr("offset"); + bool clip = ctx.Attr("clip"); + bool min_max_aspect_ratios_order = + ctx.Attr("min_max_aspect_ratios_order"); + + int im_width = image->dims()[3]; + int im_height = image->dims()[2]; + int width = input->dims()[3]; + int height = input->dims()[2]; + + auto aspect_ratios = ctx.Attr>("aspect_ratios"); + bool flip = ctx.Attr("flip"); + std::vector new_aspect_ratios; + ExpandAspectRatios(aspect_ratios, flip, &new_aspect_ratios); + auto& dev_ctx = ctx.template device_context(); + phi::DenseTensor ratios; + paddle::framework::TensorFromVector(new_aspect_ratios, dev_ctx, &ratios); + MLUOpTensorDesc new_aspect_ratios_desc(ratios); + + auto min_sizes = ctx.Attr>("min_sizes"); + phi::DenseTensor min; + paddle::framework::TensorFromVector(min_sizes, dev_ctx, &min); + MLUOpTensorDesc min_sizes_desc(min); + + auto max_sizes = ctx.Attr>("max_sizes"); + phi::DenseTensor max; + paddle::framework::TensorFromVector(max_sizes, dev_ctx, &max); + MLUOpTensorDesc max_sizes_desc(max); + + auto variances_attr = ctx.Attr>("variances"); + phi::DenseTensor var_tensor; + paddle::framework::TensorFromVector(variances_attr, dev_ctx, &var_tensor); + MLUOpTensorDesc variances_attr_desc(var_tensor); + + auto place = ctx.GetPlace(); + + boxes->mutable_data(place); + variances->mutable_data(place); + + MLUOpTensorDesc var_desc(*variances); + MLUOpTensorDesc output_desc(*boxes); + MLUOP::OpPriorBox(ctx, + min_sizes_desc.get(), + GetBasePtr(&min), + new_aspect_ratios_desc.get(), + GetBasePtr(&ratios), + variances_attr_desc.get(), + GetBasePtr(&var_tensor), + max_sizes_desc.get(), + GetBasePtr(&max), + height, + width, + im_height, + im_width, + step_h, + step_w, + offset, + clip, + min_max_aspect_ratios_order, + output_desc.get(), + GetBasePtr(boxes), + var_desc.get(), + GetBasePtr(variances)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(prior_box, ops::PriorBoxMLUKernel); diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index 03e5a49e28ac3d12c912a118a082524110f67e1d..04e3063dd70878b2146bbe809c9504efb7d013c7 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -5458,5 +5458,54 @@ MLURNNDesc::~MLURNNDesc() { scores)); } +/* static */ void MLUOP::OpPriorBox( + const ExecutionContext& ctx, + const mluOpTensorDescriptor_t min_sizes_desc, + const void* min_sizes, + const mluOpTensorDescriptor_t aspect_ratios_desc, + const void* aspect_ratios, + const mluOpTensorDescriptor_t variances_desc, + const void* variances, + const mluOpTensorDescriptor_t max_sizes_desc, + const void* max_sizes, + const int height, + const int width, + const int im_height, + const int im_width, + const float step_h, + const float step_w, + const float offset, + const bool clip, + const bool min_max_aspect_ratios_order, + const mluOpTensorDescriptor_t output_desc, + void* output, + const mluOpTensorDescriptor_t var_desc, + void* var) { + mluOpHandle_t handle = GetMLUOpHandleFromCTX(ctx); + + PADDLE_ENFORCE_MLU_SUCCESS(mluOpPriorBox(handle, + min_sizes_desc, + min_sizes, + aspect_ratios_desc, + aspect_ratios, + variances_desc, + variances, + max_sizes_desc, + max_sizes, + height, + width, + im_height, + im_width, + step_h, + step_w, + offset, + clip, + min_max_aspect_ratios_order, + output_desc, + output, + var_desc, + var)); +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index 354c5fe3f3d8e5871722d1aa7456f8e40f6fc08f..5ceee76a0270bf161544b07de968da7e87e61585 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -2312,6 +2312,29 @@ class MLUOP { void* boxes, const mluOpTensorDescriptor_t scores_desc, void* scores); + + static void OpPriorBox(const ExecutionContext& ctx, + const mluOpTensorDescriptor_t min_sizes_desc, + const void* min_sizes, + const mluOpTensorDescriptor_t aspect_ratios_desc, + const void* aspect_ratios, + const mluOpTensorDescriptor_t variances_desc, + const void* variances, + const mluOpTensorDescriptor_t max_sizes_desc, + const void* max_sizes, + const int height, + const int width, + const int im_height, + const int im_width, + const float step_h, + const float step_w, + const float offset, + const bool clip, + const bool min_max_aspect_ratios_order, + const mluOpTensorDescriptor_t output_desc, + void* output, + const mluOpTensorDescriptor_t var_desc, + void* var); }; const std::map, std::vector>> TransPermMap = { diff --git a/python/paddle/fluid/tests/unittests/mlu/test_prior_box_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_prior_box_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..2e2b96c9e77346ff5a860fc5a9f081953420f973 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_prior_box_op_mlu.py @@ -0,0 +1,209 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import sys + +sys.path.append('..') +import numpy as np +from op_test import OpTest +import paddle.fluid as fluid +import paddle +import math + +paddle.enable_static() + + +class TestMLUPriorBox(OpTest): + + def setUp(self): + self.op_type = "prior_box" + self.set_mlu() + self.init_dtype() + self.set_data() + + def test_check_output(self): + self.check_output_with_place(self.place) + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.MLUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + def set_data(self): + self.init_test_params() + self.init_test_input() + self.init_test_output() + self.inputs = {'Input': self.input, 'Image': self.image} + + self.attrs = { + 'min_sizes': self.min_sizes, + 'aspect_ratios': self.aspect_ratios, + 'variances': self.variances, + 'flip': self.flip, + 'clip': self.clip, + 'min_max_aspect_ratios_order': self.min_max_aspect_ratios_order, + 'step_w': self.step_w, + 'step_h': self.step_h, + 'offset': self.offset + } + if len(self.max_sizes) > 0: + self.attrs['max_sizes'] = self.max_sizes + + self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var} + + def set_max_sizes(self): + max_sizes = [5, 10] + self.max_sizes = np.array(max_sizes).astype('float32').tolist() + + def set_min_max_aspect_ratios_order(self): + self.min_max_aspect_ratios_order = True + + def init_test_params(self): + self.layer_w = 32 + self.layer_h = 32 + + self.image_w = 40 + self.image_h = 40 + + self.step_w = float(self.image_w) / float(self.layer_w) + self.step_h = float(self.image_h) / float(self.layer_h) + + self.input_channels = 2 + self.image_channels = 3 + self.batch_size = 10 + + self.min_sizes = [2, 4] + self.min_sizes = np.array(self.min_sizes).astype('float32').tolist() + self.set_max_sizes() + self.aspect_ratios = [2.0, 3.0] + self.flip = True + self.set_min_max_aspect_ratios_order() + self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0] + self.aspect_ratios = np.array(self.aspect_ratios, + dtype=np.float64).flatten() + self.variances = [0.1, 0.1, 0.2, 0.2] + self.variances = np.array(self.variances, dtype=np.float64).flatten() + + self.clip = True + self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes) + if len(self.max_sizes) > 0: + self.num_priors += len(self.max_sizes) + self.offset = 0.5 + + def init_test_input(self): + self.image = np.random.random( + (self.batch_size, self.image_channels, self.image_w, + self.image_h)).astype('float32') + + self.input = np.random.random( + (self.batch_size, self.input_channels, self.layer_w, + self.layer_h)).astype('float32') + + def init_test_output(self): + out_dim = (self.layer_h, self.layer_w, self.num_priors, 4) + out_boxes = np.zeros(out_dim).astype('float32') + out_var = np.zeros(out_dim).astype('float32') + + idx = 0 + for h in range(self.layer_h): + for w in range(self.layer_w): + c_x = (w + self.offset) * self.step_w + c_y = (h + self.offset) * self.step_h + idx = 0 + for s in range(len(self.min_sizes)): + min_size = self.min_sizes[s] + if not self.min_max_aspect_ratios_order: + # rest of priors + for r in range(len(self.real_aspect_ratios)): + ar = self.real_aspect_ratios[r] + c_w = min_size * math.sqrt(ar) / 2 + c_h = (min_size / math.sqrt(ar)) / 2 + out_boxes[h, w, + idx, :] = [(c_x - c_w) / self.image_w, + (c_y - c_h) / self.image_h, + (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h] + idx += 1 + + if len(self.max_sizes) > 0: + max_size = self.max_sizes[s] + # second prior: aspect_ratio = 1, + c_w = c_h = math.sqrt(min_size * max_size) / 2 + out_boxes[h, w, + idx, :] = [(c_x - c_w) / self.image_w, + (c_y - c_h) / self.image_h, + (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h] + idx += 1 + else: + c_w = c_h = min_size / 2. + out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w, + (c_y - c_h) / self.image_h, + (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h] + idx += 1 + if len(self.max_sizes) > 0: + max_size = self.max_sizes[s] + # second prior: aspect_ratio = 1, + c_w = c_h = math.sqrt(min_size * max_size) / 2 + out_boxes[h, w, + idx, :] = [(c_x - c_w) / self.image_w, + (c_y - c_h) / self.image_h, + (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h] + idx += 1 + + # rest of priors + for r in range(len(self.real_aspect_ratios)): + ar = self.real_aspect_ratios[r] + if abs(ar - 1.) < 1e-6: + continue + c_w = min_size * math.sqrt(ar) / 2 + c_h = (min_size / math.sqrt(ar)) / 2 + out_boxes[h, w, + idx, :] = [(c_x - c_w) / self.image_w, + (c_y - c_h) / self.image_h, + (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h] + idx += 1 + + # clip the prior's coordidate such that it is within[0, 1] + if self.clip: + out_boxes = np.clip(out_boxes, 0.0, 1.0) + # set the variance. + out_var = np.tile(self.variances, + (self.layer_h, self.layer_w, self.num_priors, 1)) + self.out_boxes = out_boxes.astype('float32') + self.out_var = out_var.astype('float32') + + +class TestMLUPriorBoxWithoutMaxSize(TestMLUPriorBox): + + def set_max_sizes(self): + self.max_sizes = [] + + +class TestMLUPriorBoxWithoutSpecifiedOutOrder(TestMLUPriorBox): + + def set_min_max_aspect_ratios_order(self): + self.min_max_aspect_ratios_order = False + + +if __name__ == '__main__': + unittest.main()