From 1127fecb49a33801f20898f9b896db4254f31eac Mon Sep 17 00:00:00 2001 From: furnace <34057289+windstamp@users.noreply.github.com> Date: Thu, 25 Nov 2021 16:49:30 +0800 Subject: [PATCH] [NPU] add NPU kernel for prior_box op (#37519) * [NPU] add NPU kernel for prior_box op * [NPU] delete debug codes --- .../fluid/operators/detection/CMakeLists.txt | 3 +- .../operators/detection/prior_box_op_npu.cc | 102 +++++++++ .../unittests/npu/test_prior_box_op_npu.py | 206 ++++++++++++++++++ 3 files changed, 310 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/detection/prior_box_op_npu.cc create mode 100644 python/paddle/fluid/tests/unittests/npu/test_prior_box_op_npu.py diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 506ae56a126..08c44a2d39e 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -18,14 +18,15 @@ endfunction() if (WITH_ASCEND_CL) detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu box_coder_op_npu.cc) detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu density_prior_box_op_npu.cc) + detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu prior_box_op_npu.cc) else() detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu) detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu) + detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu) endif() detection_library(bipartite_match_op SRCS bipartite_match_op.cc) detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc) -detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu) detection_library(anchor_generator_op SRCS anchor_generator_op.cc anchor_generator_op.cu) detection_library(target_assign_op SRCS target_assign_op.cc diff --git a/paddle/fluid/operators/detection/prior_box_op_npu.cc b/paddle/fluid/operators/detection/prior_box_op_npu.cc new file mode 100644 index 00000000000..84840e77ed2 --- /dev/null +++ b/paddle/fluid/operators/detection/prior_box_op_npu.cc @@ -0,0 +1,102 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detection/prior_box_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class PriorBoxNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* image = ctx.Input("Image"); + auto* boxes = ctx.Output("Boxes"); + auto* variances = ctx.Output("Variances"); + + PADDLE_ENFORCE_EQ(boxes->dims(), variances->dims(), + platform::errors::Unimplemented( + "the shape of boxes and variances must be same in " + "the npu kernel of prior_box, but got boxes->dims() " + "= [%s], variances->dims() = [%s]", + boxes->dims(), variances->dims())); + + auto min_sizes = ctx.Attr>("min_sizes"); + auto max_sizes = ctx.Attr>("max_sizes"); + auto aspect_ratios = ctx.Attr>("aspect_ratios"); + auto variances_attr = ctx.Attr>("variances"); + bool flip = ctx.Attr("flip"); + bool clip = ctx.Attr("clip"); + float step_w = ctx.Attr("step_w"); + float step_h = ctx.Attr("step_h"); + float offset = ctx.Attr("offset"); + + auto place = ctx.GetPlace(); + + Tensor out(input->type()); + auto out_dims = framework::vectorize(boxes->dims()); + out_dims.insert(out_dims.begin(), 2); + out.Resize(framework::make_ddim(out_dims)); + out.mutable_data(place); + + framework::NPUAttributeMap attr_input = {{"min_size", min_sizes}, + {"max_size", max_sizes}, + {"aspect_ratio", aspect_ratios}, + {"step_h", step_h}, + {"step_w", step_w}, + {"flip", flip}, + {"clip", clip}, + {"offset", offset}, + {"variance", variances_attr}}; + + auto stream = + ctx.template device_context() + .stream(); + + const auto& runner = + NpuOpRunner("PriorBox", {*input, *image}, {out}, attr_input); + runner.Run(stream); + + out.Resize(framework::make_ddim({out.numel()})); + Tensor out_boxes = out.Slice(0, boxes->numel()); + Tensor out_variances = out.Slice(boxes->numel(), out.numel()); + + out_boxes.Resize(boxes->dims()); + out_variances.Resize(variances->dims()); + + boxes->mutable_data(place); + variances->mutable_data(place); + + framework::TensorCopy( + out_boxes, place, + ctx.template device_context(), boxes); + framework::TensorCopy( + out_variances, place, + ctx.template device_context(), variances); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_NPU_KERNEL( + prior_box, ops::PriorBoxNPUKernel, + ops::PriorBoxNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_prior_box_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_prior_box_op_npu.py new file mode 100644 index 00000000000..47b78d30820 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_prior_box_op_npu.py @@ -0,0 +1,206 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import sys +import math + +from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator + +paddle.enable_static() + + +class TestNPUPriorBox(OpTest): + def setUp(self): + self.op_type = "prior_box" + self.set_npu() + self.init_dtype() + self.set_data() + + def test_check_output(self): + self.check_output_with_place(self.place, atol=self.atol) + + def set_npu(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + def set_data(self): + self.init_test_params() + self.init_test_input() + self.init_test_output() + self.inputs = {'Input': self.input, 'Image': self.image} + + self.attrs = { + 'min_sizes': self.min_sizes, + 'aspect_ratios': self.aspect_ratios, + 'variances': self.variances, + 'flip': self.flip, + 'clip': self.clip, + 'min_max_aspect_ratios_order': self.min_max_aspect_ratios_order, + 'step_w': self.step_w, + 'step_h': self.step_h, + 'offset': self.offset + } + if len(self.max_sizes) > 0: + self.attrs['max_sizes'] = self.max_sizes + + self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var} + + def set_max_sizes(self): + max_sizes = [5, 10] + self.max_sizes = np.array(max_sizes).astype('float32').tolist() + + def set_min_max_aspect_ratios_order(self): + self.min_max_aspect_ratios_order = True + self.atol = 1e-3 + + def init_test_params(self): + self.layer_w = 32 + self.layer_h = 32 + + self.image_w = 40 + self.image_h = 40 + + self.step_w = float(self.image_w) / float(self.layer_w) + self.step_h = float(self.image_h) / float(self.layer_h) + + self.input_channels = 2 + self.image_channels = 3 + self.batch_size = 10 + + self.min_sizes = [2, 4] + self.min_sizes = np.array(self.min_sizes).astype('float32').tolist() + self.set_max_sizes() + self.aspect_ratios = [2.0, 3.0] + self.flip = True + self.set_min_max_aspect_ratios_order() + self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0] + self.aspect_ratios = np.array( + self.aspect_ratios, dtype=np.float).flatten() + self.variances = [0.1, 0.1, 0.2, 0.2] + self.variances = np.array(self.variances, dtype=np.float).flatten() + + self.clip = True + self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes) + if len(self.max_sizes) > 0: + self.num_priors += len(self.max_sizes) + self.offset = 0.5 + + def init_test_input(self): + self.image = np.random.random( + (self.batch_size, self.image_channels, self.image_w, + self.image_h)).astype('float32') + + self.input = np.random.random( + (self.batch_size, self.input_channels, self.layer_w, + self.layer_h)).astype('float32') + + def init_test_output(self): + out_dim = (self.layer_h, self.layer_w, self.num_priors, 4) + out_boxes = np.zeros(out_dim).astype('float32') + out_var = np.zeros(out_dim).astype('float32') + + idx = 0 + for h in range(self.layer_h): + for w in range(self.layer_w): + c_x = (w + self.offset) * self.step_w + c_y = (h + self.offset) * self.step_h + idx = 0 + for s in range(len(self.min_sizes)): + min_size = self.min_sizes[s] + if not self.min_max_aspect_ratios_order: + # rest of priors + for r in range(len(self.real_aspect_ratios)): + ar = self.real_aspect_ratios[r] + c_w = min_size * math.sqrt(ar) / 2 + c_h = (min_size / math.sqrt(ar)) / 2 + out_boxes[h, w, idx, :] = [ + (c_x - c_w) / self.image_w, (c_y - c_h) / + self.image_h, (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h + ] + idx += 1 + + if len(self.max_sizes) > 0: + max_size = self.max_sizes[s] + # second prior: aspect_ratio = 1, + c_w = c_h = math.sqrt(min_size * max_size) / 2 + out_boxes[h, w, idx, :] = [ + (c_x - c_w) / self.image_w, (c_y - c_h) / + self.image_h, (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h + ] + idx += 1 + else: + c_w = c_h = min_size / 2. + out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w, + (c_y - c_h) / self.image_h, + (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h] + idx += 1 + if len(self.max_sizes) > 0: + max_size = self.max_sizes[s] + # second prior: aspect_ratio = 1, + c_w = c_h = math.sqrt(min_size * max_size) / 2 + out_boxes[h, w, idx, :] = [ + (c_x - c_w) / self.image_w, (c_y - c_h) / + self.image_h, (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h + ] + idx += 1 + + # rest of priors + for r in range(len(self.real_aspect_ratios)): + ar = self.real_aspect_ratios[r] + if abs(ar - 1.) < 1e-6: + continue + c_w = min_size * math.sqrt(ar) / 2 + c_h = (min_size / math.sqrt(ar)) / 2 + out_boxes[h, w, idx, :] = [ + (c_x - c_w) / self.image_w, (c_y - c_h) / + self.image_h, (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h + ] + idx += 1 + + # clip the prior's coordidate such that it is within[0, 1] + if self.clip: + out_boxes = np.clip(out_boxes, 0.0, 1.0) + # set the variance. + out_var = np.tile(self.variances, (self.layer_h, self.layer_w, + self.num_priors, 1)) + self.out_boxes = out_boxes.astype('float32') + self.out_var = out_var.astype('float32') + + +class TestNPUPriorBoxWithoutMaxSize(TestNPUPriorBox): + def set_max_sizes(self): + self.max_sizes = [] + + +class TestNPUPriorBoxWithoutSpecifiedOutOrder(TestNPUPriorBox): + def set_min_max_aspect_ratios_order(self): + self.min_max_aspect_ratios_order = False + self.atol = 1e-1 + + +if __name__ == '__main__': + unittest.main() -- GitLab