Unverified commit 1127fecb authored by F furnace, committed by GitHub

[NPU] add NPU kernel for prior_box op (#37519)

* [NPU] add NPU kernel for prior_box op

* [NPU] delete debug codes
Parent 65056742
@@ -18,14 +18,15 @@ endfunction()
if (WITH_ASCEND_CL)
detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu box_coder_op_npu.cc)
detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu density_prior_box_op_npu.cc)
+detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu prior_box_op_npu.cc)
else()
detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu)
detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu)
+detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
endif()
detection_library(bipartite_match_op SRCS bipartite_match_op.cc)
detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc)
-detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
detection_library(anchor_generator_op SRCS anchor_generator_op.cc
anchor_generator_op.cu)
detection_library(target_assign_op SRCS target_assign_op.cc
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/prior_box_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class PriorBoxNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto* image = ctx.Input<Tensor>("Image");
auto* boxes = ctx.Output<Tensor>("Boxes");
auto* variances = ctx.Output<Tensor>("Variances");
PADDLE_ENFORCE_EQ(boxes->dims(), variances->dims(),
platform::errors::Unimplemented(
"the shape of boxes and variances must be same in "
"the npu kernel of prior_box, but got boxes->dims() "
"= [%s], variances->dims() = [%s]",
boxes->dims(), variances->dims()));
auto min_sizes = ctx.Attr<std::vector<float>>("min_sizes");
auto max_sizes = ctx.Attr<std::vector<float>>("max_sizes");
auto aspect_ratios = ctx.Attr<std::vector<float>>("aspect_ratios");
auto variances_attr = ctx.Attr<std::vector<float>>("variances");
bool flip = ctx.Attr<bool>("flip");
bool clip = ctx.Attr<bool>("clip");
float step_w = ctx.Attr<float>("step_w");
float step_h = ctx.Attr<float>("step_h");
float offset = ctx.Attr<float>("offset");
auto place = ctx.GetPlace();
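    // The NPU PriorBox op emits boxes and variances stacked in one output, so
    // allocate a temporary tensor with an extra leading dimension of 2 and
    // split it into the two op outputs afterwards.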
Tensor out(input->type());
auto out_dims = framework::vectorize(boxes->dims());
out_dims.insert(out_dims.begin(), 2);
out.Resize(framework::make_ddim(out_dims));
out.mutable_data<T>(place);
framework::NPUAttributeMap attr_input = {{"min_size", min_sizes},
{"max_size", max_sizes},
{"aspect_ratio", aspect_ratios},
{"step_h", step_h},
{"step_w", step_w},
{"flip", flip},
{"clip", clip},
{"offset", offset},
{"variance", variances_attr}};
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
const auto& runner =
NpuOpRunner("PriorBox", {*input, *image}, {out}, attr_input);
runner.Run(stream);
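    // Flatten the temporary output and split it: the first boxes->numel()
    // elements hold the prior boxes, the remaining elements the variances.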
out.Resize(framework::make_ddim({out.numel()}));
Tensor out_boxes = out.Slice(0, boxes->numel());
Tensor out_variances = out.Slice(boxes->numel(), out.numel());
out_boxes.Resize(boxes->dims());
out_variances.Resize(variances->dims());
boxes->mutable_data<T>(place);
variances->mutable_data<T>(place);
framework::TensorCopy(
out_boxes, place,
ctx.template device_context<platform::NPUDeviceContext>(), boxes);
framework::TensorCopy(
out_variances, place,
ctx.template device_context<platform::NPUDeviceContext>(), variances);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
prior_box, ops::PriorBoxNPUKernel<plat::NPUDeviceContext, float>,
ops::PriorBoxNPUKernel<plat::NPUDeviceContext, plat::float16>);
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import sys
import math
from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator
paddle.enable_static()
class TestNPUPriorBox(OpTest):
def setUp(self):
self.op_type = "prior_box"
self.set_npu()
self.init_dtype()
self.set_data()
def test_check_output(self):
self.check_output_with_place(self.place, atol=self.atol)
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def init_dtype(self):
self.dtype = np.float32
def set_data(self):
self.init_test_params()
self.init_test_input()
self.init_test_output()
self.inputs = {'Input': self.input, 'Image': self.image}
self.attrs = {
'min_sizes': self.min_sizes,
'aspect_ratios': self.aspect_ratios,
'variances': self.variances,
'flip': self.flip,
'clip': self.clip,
'min_max_aspect_ratios_order': self.min_max_aspect_ratios_order,
'step_w': self.step_w,
'step_h': self.step_h,
'offset': self.offset
}
if len(self.max_sizes) > 0:
self.attrs['max_sizes'] = self.max_sizes
self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var}
def set_max_sizes(self):
max_sizes = [5, 10]
self.max_sizes = np.array(max_sizes).astype('float32').tolist()
def set_min_max_aspect_ratios_order(self):
self.min_max_aspect_ratios_order = True
self.atol = 1e-3
def init_test_params(self):
self.layer_w = 32
self.layer_h = 32
self.image_w = 40
self.image_h = 40
self.step_w = float(self.image_w) / float(self.layer_w)
self.step_h = float(self.image_h) / float(self.layer_h)
self.input_channels = 2
self.image_channels = 3
self.batch_size = 10
self.min_sizes = [2, 4]
self.min_sizes = np.array(self.min_sizes).astype('float32').tolist()
self.set_max_sizes()
self.aspect_ratios = [2.0, 3.0]
self.flip = True
self.set_min_max_aspect_ratios_order()
self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
        self.aspect_ratios = np.array(
            self.aspect_ratios, dtype=np.float64).flatten()
self.variances = [0.1, 0.1, 0.2, 0.2]
        self.variances = np.array(self.variances, dtype=np.float64).flatten()
self.clip = True
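        # Each min_size yields one prior per (expanded) aspect ratio, plus one
        # extra prior for every max_size.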
self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes)
if len(self.max_sizes) > 0:
self.num_priors += len(self.max_sizes)
self.offset = 0.5
def init_test_input(self):
self.image = np.random.random(
(self.batch_size, self.image_channels, self.image_w,
self.image_h)).astype('float32')
self.input = np.random.random(
(self.batch_size, self.input_channels, self.layer_w,
self.layer_h)).astype('float32')
def init_test_output(self):
out_dim = (self.layer_h, self.layer_w, self.num_priors, 4)
out_boxes = np.zeros(out_dim).astype('float32')
out_var = np.zeros(out_dim).astype('float32')
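        # Compute the expected prior boxes on the host to serve as the
        # reference output for the NPU kernel.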
idx = 0
for h in range(self.layer_h):
for w in range(self.layer_w):
c_x = (w + self.offset) * self.step_w
c_y = (h + self.offset) * self.step_h
idx = 0
for s in range(len(self.min_sizes)):
min_size = self.min_sizes[s]
if not self.min_max_aspect_ratios_order:
# rest of priors
for r in range(len(self.real_aspect_ratios)):
ar = self.real_aspect_ratios[r]
c_w = min_size * math.sqrt(ar) / 2
c_h = (min_size / math.sqrt(ar)) / 2
out_boxes[h, w, idx, :] = [
(c_x - c_w) / self.image_w, (c_y - c_h) /
self.image_h, (c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h
]
idx += 1
if len(self.max_sizes) > 0:
max_size = self.max_sizes[s]
# second prior: aspect_ratio = 1,
c_w = c_h = math.sqrt(min_size * max_size) / 2
out_boxes[h, w, idx, :] = [
(c_x - c_w) / self.image_w, (c_y - c_h) /
self.image_h, (c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h
]
idx += 1
else:
c_w = c_h = min_size / 2.
out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w,
(c_y - c_h) / self.image_h,
(c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h]
idx += 1
if len(self.max_sizes) > 0:
max_size = self.max_sizes[s]
# second prior: aspect_ratio = 1,
c_w = c_h = math.sqrt(min_size * max_size) / 2
out_boxes[h, w, idx, :] = [
(c_x - c_w) / self.image_w, (c_y - c_h) /
self.image_h, (c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h
]
idx += 1
# rest of priors
for r in range(len(self.real_aspect_ratios)):
ar = self.real_aspect_ratios[r]
if abs(ar - 1.) < 1e-6:
continue
c_w = min_size * math.sqrt(ar) / 2
c_h = (min_size / math.sqrt(ar)) / 2
out_boxes[h, w, idx, :] = [
(c_x - c_w) / self.image_w, (c_y - c_h) /
self.image_h, (c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h
]
idx += 1
        # clip the prior's coordinates so that they are within [0, 1]
if self.clip:
out_boxes = np.clip(out_boxes, 0.0, 1.0)
# set the variance.
out_var = np.tile(self.variances, (self.layer_h, self.layer_w,
self.num_priors, 1))
self.out_boxes = out_boxes.astype('float32')
self.out_var = out_var.astype('float32')
class TestNPUPriorBoxWithoutMaxSize(TestNPUPriorBox):
def set_max_sizes(self):
self.max_sizes = []
class TestNPUPriorBoxWithoutSpecifiedOutOrder(TestNPUPriorBox):
def set_min_max_aspect_ratios_order(self):
self.min_max_aspect_ratios_order = False
self.atol = 1e-1
if __name__ == '__main__':
unittest.main()