未验证 提交 ff37e48e 编写于 作者: C cifar10 提交者: GitHub

[MLU] add fluid MLUOps prior_box (#46585)

上级 146d70ca
......@@ -46,7 +46,7 @@ if(WITH_XPU)
elseif(WITH_MLU)
detection_library(iou_similarity_op SRCS iou_similarity_op.cc
iou_similarity_op_mlu.cc)
detection_library(prior_box_op SRCS prior_box_op.cc)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op_mlu.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op_mlu.cc)
elseif(WITH_ASCEND_CL)
detection_library(iou_similarity_op SRCS iou_similarity_op.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/detection/prior_box_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
template <typename T>
class PriorBoxMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<phi::DenseTensor>("Input");
auto* image = ctx.Input<phi::DenseTensor>("Image");
auto* boxes = ctx.Output<phi::DenseTensor>("Boxes");
auto* variances = ctx.Output<phi::DenseTensor>("Variances");
float step_w = ctx.Attr<float>("step_w");
float step_h = ctx.Attr<float>("step_h");
float offset = ctx.Attr<float>("offset");
bool clip = ctx.Attr<bool>("clip");
bool min_max_aspect_ratios_order =
ctx.Attr<bool>("min_max_aspect_ratios_order");
int im_width = image->dims()[3];
int im_height = image->dims()[2];
int width = input->dims()[3];
int height = input->dims()[2];
auto aspect_ratios = ctx.Attr<std::vector<float>>("aspect_ratios");
bool flip = ctx.Attr<bool>("flip");
std::vector<float> new_aspect_ratios;
ExpandAspectRatios(aspect_ratios, flip, &new_aspect_ratios);
auto& dev_ctx = ctx.template device_context<platform::MLUDeviceContext>();
phi::DenseTensor ratios;
paddle::framework::TensorFromVector(new_aspect_ratios, dev_ctx, &ratios);
MLUOpTensorDesc new_aspect_ratios_desc(ratios);
auto min_sizes = ctx.Attr<std::vector<float>>("min_sizes");
phi::DenseTensor min;
paddle::framework::TensorFromVector(min_sizes, dev_ctx, &min);
MLUOpTensorDesc min_sizes_desc(min);
auto max_sizes = ctx.Attr<std::vector<float>>("max_sizes");
phi::DenseTensor max;
paddle::framework::TensorFromVector(max_sizes, dev_ctx, &max);
MLUOpTensorDesc max_sizes_desc(max);
auto variances_attr = ctx.Attr<std::vector<float>>("variances");
phi::DenseTensor var_tensor;
paddle::framework::TensorFromVector(variances_attr, dev_ctx, &var_tensor);
MLUOpTensorDesc variances_attr_desc(var_tensor);
auto place = ctx.GetPlace();
boxes->mutable_data<T>(place);
variances->mutable_data<T>(place);
MLUOpTensorDesc var_desc(*variances);
MLUOpTensorDesc output_desc(*boxes);
MLUOP::OpPriorBox(ctx,
min_sizes_desc.get(),
GetBasePtr(&min),
new_aspect_ratios_desc.get(),
GetBasePtr(&ratios),
variances_attr_desc.get(),
GetBasePtr(&var_tensor),
max_sizes_desc.get(),
GetBasePtr(&max),
height,
width,
im_height,
im_width,
step_h,
step_w,
offset,
clip,
min_max_aspect_ratios_order,
output_desc.get(),
GetBasePtr(boxes),
var_desc.get(),
GetBasePtr(variances));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(prior_box, ops::PriorBoxMLUKernel<float>);
......@@ -5458,5 +5458,54 @@ MLURNNDesc::~MLURNNDesc() {
scores));
}
/* static */ void MLUOP::OpPriorBox(
const ExecutionContext& ctx,
const mluOpTensorDescriptor_t min_sizes_desc,
const void* min_sizes,
const mluOpTensorDescriptor_t aspect_ratios_desc,
const void* aspect_ratios,
const mluOpTensorDescriptor_t variances_desc,
const void* variances,
const mluOpTensorDescriptor_t max_sizes_desc,
const void* max_sizes,
const int height,
const int width,
const int im_height,
const int im_width,
const float step_h,
const float step_w,
const float offset,
const bool clip,
const bool min_max_aspect_ratios_order,
const mluOpTensorDescriptor_t output_desc,
void* output,
const mluOpTensorDescriptor_t var_desc,
void* var) {
mluOpHandle_t handle = GetMLUOpHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(mluOpPriorBox(handle,
min_sizes_desc,
min_sizes,
aspect_ratios_desc,
aspect_ratios,
variances_desc,
variances,
max_sizes_desc,
max_sizes,
height,
width,
im_height,
im_width,
step_h,
step_w,
offset,
clip,
min_max_aspect_ratios_order,
output_desc,
output,
var_desc,
var));
}
} // namespace operators
} // namespace paddle
......@@ -2312,6 +2312,29 @@ class MLUOP {
void* boxes,
const mluOpTensorDescriptor_t scores_desc,
void* scores);
static void OpPriorBox(const ExecutionContext& ctx,
const mluOpTensorDescriptor_t min_sizes_desc,
const void* min_sizes,
const mluOpTensorDescriptor_t aspect_ratios_desc,
const void* aspect_ratios,
const mluOpTensorDescriptor_t variances_desc,
const void* variances,
const mluOpTensorDescriptor_t max_sizes_desc,
const void* max_sizes,
const int height,
const int width,
const int im_height,
const int im_width,
const float step_h,
const float step_w,
const float offset,
const bool clip,
const bool min_max_aspect_ratios_order,
const mluOpTensorDescriptor_t output_desc,
void* output,
const mluOpTensorDescriptor_t var_desc,
void* var);
};
const std::map<const std::string, std::pair<std::vector<int>, std::vector<int>>>
TransPermMap = {
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append('..')
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle
import math
paddle.enable_static()
class TestMLUPriorBox(OpTest):
def setUp(self):
self.op_type = "prior_box"
self.set_mlu()
self.init_dtype()
self.set_data()
def test_check_output(self):
self.check_output_with_place(self.place)
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.MLUPlace(0)
def init_dtype(self):
self.dtype = np.float32
def set_data(self):
self.init_test_params()
self.init_test_input()
self.init_test_output()
self.inputs = {'Input': self.input, 'Image': self.image}
self.attrs = {
'min_sizes': self.min_sizes,
'aspect_ratios': self.aspect_ratios,
'variances': self.variances,
'flip': self.flip,
'clip': self.clip,
'min_max_aspect_ratios_order': self.min_max_aspect_ratios_order,
'step_w': self.step_w,
'step_h': self.step_h,
'offset': self.offset
}
if len(self.max_sizes) > 0:
self.attrs['max_sizes'] = self.max_sizes
self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var}
def set_max_sizes(self):
max_sizes = [5, 10]
self.max_sizes = np.array(max_sizes).astype('float32').tolist()
def set_min_max_aspect_ratios_order(self):
self.min_max_aspect_ratios_order = True
def init_test_params(self):
self.layer_w = 32
self.layer_h = 32
self.image_w = 40
self.image_h = 40
self.step_w = float(self.image_w) / float(self.layer_w)
self.step_h = float(self.image_h) / float(self.layer_h)
self.input_channels = 2
self.image_channels = 3
self.batch_size = 10
self.min_sizes = [2, 4]
self.min_sizes = np.array(self.min_sizes).astype('float32').tolist()
self.set_max_sizes()
self.aspect_ratios = [2.0, 3.0]
self.flip = True
self.set_min_max_aspect_ratios_order()
self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
self.aspect_ratios = np.array(self.aspect_ratios,
dtype=np.float64).flatten()
self.variances = [0.1, 0.1, 0.2, 0.2]
self.variances = np.array(self.variances, dtype=np.float64).flatten()
self.clip = True
self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes)
if len(self.max_sizes) > 0:
self.num_priors += len(self.max_sizes)
self.offset = 0.5
def init_test_input(self):
self.image = np.random.random(
(self.batch_size, self.image_channels, self.image_w,
self.image_h)).astype('float32')
self.input = np.random.random(
(self.batch_size, self.input_channels, self.layer_w,
self.layer_h)).astype('float32')
def init_test_output(self):
out_dim = (self.layer_h, self.layer_w, self.num_priors, 4)
out_boxes = np.zeros(out_dim).astype('float32')
out_var = np.zeros(out_dim).astype('float32')
idx = 0
for h in range(self.layer_h):
for w in range(self.layer_w):
c_x = (w + self.offset) * self.step_w
c_y = (h + self.offset) * self.step_h
idx = 0
for s in range(len(self.min_sizes)):
min_size = self.min_sizes[s]
if not self.min_max_aspect_ratios_order:
# rest of priors
for r in range(len(self.real_aspect_ratios)):
ar = self.real_aspect_ratios[r]
c_w = min_size * math.sqrt(ar) / 2
c_h = (min_size / math.sqrt(ar)) / 2
out_boxes[h, w,
idx, :] = [(c_x - c_w) / self.image_w,
(c_y - c_h) / self.image_h,
(c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h]
idx += 1
if len(self.max_sizes) > 0:
max_size = self.max_sizes[s]
# second prior: aspect_ratio = 1,
c_w = c_h = math.sqrt(min_size * max_size) / 2
out_boxes[h, w,
idx, :] = [(c_x - c_w) / self.image_w,
(c_y - c_h) / self.image_h,
(c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h]
idx += 1
else:
c_w = c_h = min_size / 2.
out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w,
(c_y - c_h) / self.image_h,
(c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h]
idx += 1
if len(self.max_sizes) > 0:
max_size = self.max_sizes[s]
# second prior: aspect_ratio = 1,
c_w = c_h = math.sqrt(min_size * max_size) / 2
out_boxes[h, w,
idx, :] = [(c_x - c_w) / self.image_w,
(c_y - c_h) / self.image_h,
(c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h]
idx += 1
# rest of priors
for r in range(len(self.real_aspect_ratios)):
ar = self.real_aspect_ratios[r]
if abs(ar - 1.) < 1e-6:
continue
c_w = min_size * math.sqrt(ar) / 2
c_h = (min_size / math.sqrt(ar)) / 2
out_boxes[h, w,
idx, :] = [(c_x - c_w) / self.image_w,
(c_y - c_h) / self.image_h,
(c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h]
idx += 1
# clip the prior's coordidate such that it is within[0, 1]
if self.clip:
out_boxes = np.clip(out_boxes, 0.0, 1.0)
# set the variance.
out_var = np.tile(self.variances,
(self.layer_h, self.layer_w, self.num_priors, 1))
self.out_boxes = out_boxes.astype('float32')
self.out_var = out_var.astype('float32')
class TestMLUPriorBoxWithoutMaxSize(TestMLUPriorBox):
def set_max_sizes(self):
self.max_sizes = []
class TestMLUPriorBoxWithoutSpecifiedOutOrder(TestMLUPriorBox):
def set_min_max_aspect_ratios_order(self):
self.min_max_aspect_ratios_order = False
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册