diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt
index 08c44a2d39ecf026af722251d29c7bf09db25dc7..a85bca364649900a07b5a625ff6f1e2d54ad5162 100644
--- a/paddle/fluid/operators/detection/CMakeLists.txt
+++ b/paddle/fluid/operators/detection/CMakeLists.txt
@@ -18,11 +18,20 @@ endfunction()
 if (WITH_ASCEND_CL)
     detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu box_coder_op_npu.cc)
     detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu density_prior_box_op_npu.cc)
-    detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu prior_box_op_npu.cc)
 else()
     detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu)
     detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu)
-    detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
+endif()
+
+if(WITH_XPU)
+    detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_xpu.cc)
+    detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op_xpu.cc)
+elseif(WITH_ASCEND_CL)
+    detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_npu.cc)
+    detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu prior_box_op_npu.cc)
+else()
+    detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op.cu)
+    detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
 endif()
 
 detection_library(bipartite_match_op SRCS bipartite_match_op.cc)
@@ -63,14 +72,6 @@ else()
     detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc)
 endif()
 
-if(WITH_XPU)
-    detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_xpu.cc)
-elseif(WITH_ASCEND_CL)
-    detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_npu.cc)
-else()
-    detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op.cu)
-endif()
-
 detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu)
 #Export local libraries to parent
 # set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)
diff --git a/paddle/fluid/operators/detection/prior_box_op_xpu.cc b/paddle/fluid/operators/detection/prior_box_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bab394689546e495a0f7892870c071f0fb7b3f06
--- /dev/null
+++ b/paddle/fluid/operators/detection/prior_box_op_xpu.cc
@@ -0,0 +1,118 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_XPU
+
+#include "paddle/fluid/operators/detection/prior_box_op.h"
+
+namespace paddle {
+namespace operators {
+
+// XPU kernel registered for the prior_box op. It fills "Boxes" through
+// xpu::gen_prior_box and fills "Variances" by copying the `variances`
+// attribute once per generated box.
+// NOTE(review): angle-bracket template arguments look stripped in this view
+// (after `template`, `ctx.Attr`, `static_cast`, ...); confirm before applying.
+template
+class PriorBoxOpXPUKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input("Input");
+    auto* image = ctx.Input("Image");
+    auto* boxes = ctx.Output("Boxes");
+    auto* vars = ctx.Output("Variances");
+
+    auto min_sizes = ctx.Attr>("min_sizes");
+    auto max_sizes = ctx.Attr>("max_sizes");
+    auto input_aspect_ratio = ctx.Attr>("aspect_ratios");
+    auto variances = ctx.Attr>("variances");
+    auto flip = ctx.Attr("flip");
+    auto clip = ctx.Attr("clip");
+    auto min_max_aspect_ratios_order =
+        ctx.Attr("min_max_aspect_ratios_order");
+
+    std::vector aspect_ratios;
+    ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios);
+
+    K step_w = static_cast(ctx.Attr("step_w"));
+    K step_h = static_cast(ctx.Attr("step_h"));
+    K offset = static_cast(ctx.Attr("offset"));
+
+    auto img_width = image->dims()[3];
+    auto img_height = image->dims()[2];
+
+    auto feature_width = input->dims()[3];
+    auto feature_height = input->dims()[2];
+
+    // If either step attribute is zero, derive both steps from image/feature.
+    K step_width, step_height;
+    if (step_w == 0 || step_h == 0) {
+      step_width = static_cast(img_width) / feature_width;
+      step_height = static_cast(img_height) / feature_height;
+    } else {
+      step_width = step_w;
+      step_height = step_h;
+    }
+
+    // One prior per (aspect ratio, min size) pair, plus one per max size.
+    int num_priors = aspect_ratios.size() * min_sizes.size();
+    if (max_sizes.size() > 0) {
+      num_priors += max_sizes.size();
+    }
+
+    boxes->mutable_data(ctx.GetPlace());
+    vars->mutable_data(ctx.GetPlace());
+
+    const auto& dev_ctx =
+        ctx.template device_context();
+    auto boxes_data = boxes->data();
+    auto vars_data = vars->data();
+    xpu::VectorParam aspect_ratios_param{
+        aspect_ratios.data(), static_cast(aspect_ratios.size()), nullptr};
+    xpu::VectorParam min_sizes_param{
+        min_sizes.data(), static_cast(min_sizes.size()), nullptr};
+    xpu::VectorParam max_sizes_param{
+        max_sizes.data(), static_cast(max_sizes.size()), nullptr};
+
+    int ret = xpu::gen_prior_box(
+        dev_ctx.x_context(), boxes_data, aspect_ratios_param, min_sizes_param,
+        max_sizes_param, feature_height, feature_width, img_height, img_width,
+        offset, step_height, step_width, clip, min_max_aspect_ratios_order);
+    PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU gen_prior_box kernel return wrong value[%d %s]",
+                          ret, XPUAPIErrorMsg[ret]));
+
+    // Write the host-side `variances` values into the slot of every box.
+    // NOTE(review): this issues box_num separate small host-to-device copies;
+    // one copy from a pre-tiled host buffer may be cheaper - confirm.
+    int box_num = feature_height * feature_width * num_priors;
+    int vlen = variances.size();
+    for (int i = 0; i < box_num; ++i) {
+      ret = xpu_memcpy(vars_data + i * vlen, variances.data(), vlen * sizeof(K),
+                       XPUMemcpyKind::XPU_HOST_TO_DEVICE);
+      PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, platform::errors::External(
+                                              "XPU xpu_memcpy return wrong "
+                                              "value[%d %s] in prior_box.",
+                                              ret, XPUAPIErrorMsg[ret]));
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_XPU_KERNEL(prior_box, ops::PriorBoxOpXPUKernel);
+
+#endif
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 78fc53cfc8535e70cbc978884dca2806514b7490..636b27e051122acd566005d03d5012325f70368b 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -289,6 +289,7 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"prior_box", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
 
     // AddMore
   };
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_prior_box_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_prior_box_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..44137f4718743ccfe5290b0a53d7dd41312653a8
--- /dev/null
+++
b/python/paddle/fluid/tests/unittests/xpu/test_prior_box_op_xpu.py
@@ -0,0 +1,216 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+sys.path.append("..")
+import math
+import paddle
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+
+paddle.enable_static()
+
+
+class TestPriorBoxOp(XPUOpTest):
+    """Compare the XPU `prior_box` op with a NumPy reference.
+
+    Subclasses override `set_max_sizes` and
+    `set_min_max_aspect_ratios_order` to vary the configuration.
+    """
+
+    def set_data(self):
+        """Populate `inputs`, `attrs` and the expected `outputs`."""
+        self.init_test_params()
+        self.init_test_input()
+        self.init_test_output()
+        self.inputs = {'Input': self.input, 'Image': self.image}
+
+        self.attrs = {
+            'min_sizes': self.min_sizes,
+            'aspect_ratios': self.aspect_ratios,
+            'variances': self.variances,
+            'flip': self.flip,
+            'clip': self.clip,
+            'min_max_aspect_ratios_order': self.min_max_aspect_ratios_order,
+            'step_w': self.step_w,
+            'step_h': self.step_h,
+            'offset': self.offset
+        }
+        if len(self.max_sizes) > 0:
+            self.attrs['max_sizes'] = self.max_sizes
+
+        self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var}
+
+    def test_check_output(self):
+        place = paddle.XPUPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        # Intentionally empty: only the forward outputs are verified here.
+        pass
+
+    def setUp(self):
+        self.op_type = "prior_box"
+        self.use_xpu = True
+        self.set_data()
+
+    def set_max_sizes(self):
+        max_sizes = [5, 10]
+        self.max_sizes = np.array(max_sizes).astype('float32').tolist()
+
+    def set_min_max_aspect_ratios_order(self):
+        self.min_max_aspect_ratios_order = False
+
+    def init_test_params(self):
+        """Set feature/image sizes and the op attributes under test."""
+        self.layer_w = 32
+        self.layer_h = 32
+
+        self.image_w = 40
+        self.image_h = 40
+
+        self.step_w = float(self.image_w) / float(self.layer_w)
+        self.step_h = float(self.image_h) / float(self.layer_h)
+
+        self.input_channels = 2
+        self.image_channels = 3
+        self.batch_size = 10
+
+        self.min_sizes = [2, 4]
+        self.min_sizes = np.array(self.min_sizes).astype('float32').tolist()
+        self.set_max_sizes()
+        self.aspect_ratios = [2.0, 3.0]
+        self.flip = True
+        self.set_min_max_aspect_ratios_order()
+        self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
+        self.aspect_ratios = np.array(
+            self.aspect_ratios, dtype=np.float64).flatten()
+        self.variances = [0.1, 0.1, 0.2, 0.2]
+        self.variances = np.array(self.variances, dtype=np.float64).flatten()
+
+        self.clip = True
+        self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes)
+        if len(self.max_sizes) > 0:
+            self.num_priors += len(self.max_sizes)
+        self.offset = 0.5
+
+    def init_test_input(self):
+        """Create random `Image` and `Input` tensors in NCHW layout."""
+        # NCHW layout: axis 2 is height and axis 3 is width.
+        self.image = np.random.random(
+            (self.batch_size, self.image_channels, self.image_h,
+             self.image_w)).astype('float32')
+
+        self.input = np.random.random(
+            (self.batch_size, self.input_channels, self.layer_h,
+             self.layer_w)).astype('float32')
+
+    def init_test_output(self):
+        """Compute the reference boxes and variances with NumPy."""
+        out_dim = (self.layer_h, self.layer_w, self.num_priors, 4)
+        out_boxes = np.zeros(out_dim).astype('float32')
+        out_var = np.zeros(out_dim).astype('float32')
+
+        for h in range(self.layer_h):
+            for w in range(self.layer_w):
+                c_x = (w + self.offset) * self.step_w
+                c_y = (h + self.offset) * self.step_h
+                idx = 0
+                for s in range(len(self.min_sizes)):
+                    min_size = self.min_sizes[s]
+                    if not self.min_max_aspect_ratios_order:
+                        # rest of priors
+                        for r in range(len(self.real_aspect_ratios)):
+                            ar = self.real_aspect_ratios[r]
+                            c_w = min_size * math.sqrt(ar) / 2
+                            c_h = (min_size / math.sqrt(ar)) / 2
+                            out_boxes[h, w, idx, :] = [
+                                (c_x - c_w) / self.image_w, (c_y - c_h) /
+                                self.image_h, (c_x + c_w) / self.image_w,
+                                (c_y + c_h) / self.image_h
+                            ]
+                            idx += 1
+
+                        if len(self.max_sizes) > 0:
+                            max_size = self.max_sizes[s]
+                            # second prior: aspect_ratio = 1,
+                            c_w = c_h = math.sqrt(min_size * max_size) / 2
+                            out_boxes[h, w, idx, :] = [
+                                (c_x - c_w) / self.image_w, (c_y - c_h) /
+                                self.image_h, (c_x + c_w) / self.image_w,
+                                (c_y + c_h) / self.image_h
+                            ]
+                            idx += 1
+                    else:
+                        c_w = c_h = min_size / 2.
+                        out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w,
+                                                   (c_y - c_h) / self.image_h,
+                                                   (c_x + c_w) / self.image_w,
+                                                   (c_y + c_h) / self.image_h]
+                        idx += 1
+                        if len(self.max_sizes) > 0:
+                            max_size = self.max_sizes[s]
+                            # second prior: aspect_ratio = 1,
+                            c_w = c_h = math.sqrt(min_size * max_size) / 2
+                            out_boxes[h, w, idx, :] = [
+                                (c_x - c_w) / self.image_w, (c_y - c_h) /
+                                self.image_h, (c_x + c_w) / self.image_w,
+                                (c_y + c_h) / self.image_h
+                            ]
+                            idx += 1
+
+                        # rest of priors
+                        for r in range(len(self.real_aspect_ratios)):
+                            ar = self.real_aspect_ratios[r]
+                            if abs(ar - 1.) < 1e-6:
+                                continue
+                            c_w = min_size * math.sqrt(ar) / 2
+                            c_h = (min_size / math.sqrt(ar)) / 2
+                            out_boxes[h, w, idx, :] = [
+                                (c_x - c_w) / self.image_w, (c_y - c_h) /
+                                self.image_h, (c_x + c_w) / self.image_w,
+                                (c_y + c_h) / self.image_h
+                            ]
+                            idx += 1
+
+        # clip the prior's coordinate such that it is within[0, 1]
+        if self.clip:
+            out_boxes = np.clip(out_boxes, 0.0, 1.0)
+        # set the variance.
+        out_var = np.tile(self.variances, (self.layer_h, self.layer_w,
+                                           self.num_priors, 1))
+        self.out_boxes = out_boxes.astype('float32')
+        self.out_var = out_var.astype('float32')
+
+
+class TestPriorBoxOpWithoutMaxSize(TestPriorBoxOp):
+    """Variant with an empty `max_sizes` list."""
+
+    def set_max_sizes(self):
+        self.max_sizes = []
+
+
+class TestPriorBoxOpWithSpecifiedOutOrder(TestPriorBoxOp):
+    """Variant with `min_max_aspect_ratios_order` enabled."""
+
+    def set_min_max_aspect_ratios_order(self):
+        self.min_max_aspect_ratios_order = True
+
+
+if __name__ == '__main__':
+    unittest.main()