未验证 提交 832b0a15 编写于 作者: 光明和真理's avatar 光明和真理 提交者: GitHub

[MLU] add_fluid_mluop_yolo_box (#46573)

上级 d16360c8
......@@ -47,6 +47,7 @@ elseif(WITH_MLU)
detection_library(iou_similarity_op SRCS iou_similarity_op.cc
iou_similarity_op_mlu.cc)
detection_library(prior_box_op SRCS prior_box_op.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op_mlu.cc)
elseif(WITH_ASCEND_CL)
detection_library(iou_similarity_op SRCS iou_similarity_op.cc
iou_similarity_op_npu.cc)
......@@ -55,6 +56,7 @@ else()
detection_library(iou_similarity_op SRCS iou_similarity_op.cc
iou_similarity_op.cu)
detection_library(prior_box_op SRCS prior_box_op.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc)
# detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc)
endif()
......@@ -73,7 +75,6 @@ detection_library(locality_aware_nms_op SRCS locality_aware_nms_op.cc DEPS gpc)
detection_library(matrix_nms_op SRCS matrix_nms_op.cc DEPS gpc)
detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc)
detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc
box_decoder_and_assign_op.cu)
detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
template <typename T>
class YoloBoxMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<phi::DenseTensor>("X");
auto* img_size = ctx.Input<phi::DenseTensor>("ImgSize");
auto* boxes = ctx.Output<phi::DenseTensor>("Boxes");
auto* scores = ctx.Output<phi::DenseTensor>("Scores");
const std::vector<int> anchors = ctx.Attr<std::vector<int>>("anchors");
auto class_num = ctx.Attr<int>("class_num");
auto conf_thresh = ctx.Attr<float>("conf_thresh");
auto downsample_ratio = ctx.Attr<int>("downsample_ratio");
auto clip_bbox = ctx.Attr<bool>("clip_bbox");
auto scale = ctx.Attr<float>("scale_x_y");
auto iou_aware = ctx.Attr<bool>("iou_aware");
auto iou_aware_factor = ctx.Attr<float>("iou_aware_factor");
int anchor_num = anchors.size() / 2;
int64_t size = anchors.size();
auto dim_x = x->dims();
int n = dim_x[0];
int s = anchor_num;
int h = dim_x[2];
int w = dim_x[3];
// The output of mluOpYoloBox: A 4-D tensor with shape [N, anchor_num, 4,
// H*W], the coordinates of boxes, and a 4-D tensor with shape [N,
// anchor_num, :attr:`class_num`, H*W], the classification scores of boxes.
std::vector<int64_t> boxes_dim_mluops({n, s, 4, h * w});
std::vector<int64_t> scores_dim_mluops({n, s, class_num, h * w});
// In Paddle framework: A 3-D tensor with shape [N, M, 4], the coordinates
// of boxes, and a 3-D tensor with shape [N, M, :attr:`class_num`], the
// classification scores of boxes.
std::vector<int64_t> boxes_out_dim({n, s, h * w, 4});
std::vector<int64_t> scores_out_dim({n, s, h * w, class_num});
auto& dev_ctx = ctx.template device_context<MLUDeviceContext>();
phi::DenseTensor boxes_tensor_mluops =
ctx.AllocateTmpTensor<T, MLUDeviceContext>({n, s, 4, h * w}, dev_ctx);
phi::DenseTensor scores_tensor_mluops =
ctx.AllocateTmpTensor<T, MLUDeviceContext>({n, s, class_num, h * w},
dev_ctx);
MLUOpTensorDesc boxes_trans_desc_mluops(
4, boxes_dim_mluops.data(), ToMluOpDataType<T>());
MLUCnnlTensorDesc boxes_trans_desc_cnnl(
4, boxes_dim_mluops.data(), ToCnnlDataType<T>());
MLUOpTensorDesc scores_trans_desc_mluops(
4, scores_dim_mluops.data(), ToMluOpDataType<T>());
MLUCnnlTensorDesc scores_trans_desc_cnnl(
4, scores_dim_mluops.data(), ToCnnlDataType<T>());
boxes->mutable_data<T>(ctx.GetPlace());
scores->mutable_data<T>(ctx.GetPlace());
FillMLUTensorWithHostValue(ctx, static_cast<T>(0), boxes);
FillMLUTensorWithHostValue(ctx, static_cast<T>(0), scores);
MLUOpTensorDesc x_desc(*x, MLUOP_LAYOUT_ARRAY, ToMluOpDataType<T>());
MLUOpTensorDesc img_size_desc(
*img_size, MLUOP_LAYOUT_ARRAY, ToMluOpDataType<int32_t>());
Tensor anchors_temp(framework::TransToPhiDataType(VT::INT32));
anchors_temp.Resize({size});
paddle::framework::TensorFromVector(
anchors, ctx.device_context(), &anchors_temp);
MLUOpTensorDesc anchors_desc(anchors_temp);
MLUCnnlTensorDesc boxes_desc_cnnl(
4, boxes_out_dim.data(), ToCnnlDataType<T>());
MLUCnnlTensorDesc scores_desc_cnnl(
4, scores_out_dim.data(), ToCnnlDataType<T>());
MLUOP::OpYoloBox(ctx,
x_desc.get(),
GetBasePtr(x),
img_size_desc.get(),
GetBasePtr(img_size),
anchors_desc.get(),
GetBasePtr(&anchors_temp),
class_num,
conf_thresh,
downsample_ratio,
clip_bbox,
scale,
iou_aware,
iou_aware_factor,
boxes_trans_desc_mluops.get(),
GetBasePtr(&boxes_tensor_mluops),
scores_trans_desc_mluops.get(),
GetBasePtr(&scores_tensor_mluops));
const std::vector<int> perm = {0, 1, 3, 2};
// transpose the boxes from [N, S, 4, H*W] to [N, S, H*W, 4]
MLUCnnl::Transpose(ctx,
perm,
4,
boxes_trans_desc_cnnl.get(),
GetBasePtr(&boxes_tensor_mluops),
boxes_desc_cnnl.get(),
GetBasePtr(boxes));
// transpose the scores from [N, S, class_num, H*W] to [N, S, H*W,
// class_num]
MLUCnnl::Transpose(ctx,
perm,
4,
scores_trans_desc_cnnl.get(),
GetBasePtr(&scores_tensor_mluops),
scores_desc_cnnl.get(),
GetBasePtr(scores));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(yolo_box, ops::YoloBoxMLUKernel<float>);
......@@ -5418,5 +5418,45 @@ MLURNNDesc::~MLURNNDesc() {
diff_x));
}
/* static */ void MLUOP::OpYoloBox(const ExecutionContext& ctx,
const mluOpTensorDescriptor_t x_desc,
const void* x,
const mluOpTensorDescriptor_t img_size_desc,
const void* img_size,
const mluOpTensorDescriptor_t anchors_desc,
const void* anchors,
const int class_num,
const float conf_thresh,
const int downsample_ratio,
const bool clip_bbox,
const float scale,
const bool iou_aware,
const float iou_aware_factor,
const mluOpTensorDescriptor_t boxes_desc,
void* boxes,
const mluOpTensorDescriptor_t scores_desc,
void* scores) {
mluOpHandle_t handle = GetMLUOpHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(mluOpYoloBox(handle,
x_desc,
x,
img_size_desc,
img_size,
anchors_desc,
anchors,
class_num,
conf_thresh,
downsample_ratio,
clip_bbox,
scale,
iou_aware,
iou_aware_factor,
boxes_desc,
boxes,
scores_desc,
scores));
}
} // namespace operators
} // namespace paddle
......@@ -2292,6 +2292,27 @@ class MLUCnnl {
void* diff_x);
};
class MLUOP {
public:
static void OpYoloBox(const ExecutionContext& ctx,
const mluOpTensorDescriptor_t x_desc,
const void* x,
const mluOpTensorDescriptor_t img_size_desc,
const void* img_size,
const mluOpTensorDescriptor_t anchors_desc,
const void* anchors,
const int class_num,
const float conf_thresh,
const int downsample_ratio,
const bool clip_bbox,
const float scale,
const bool iou_aware,
const float iou_aware_factor,
const mluOpTensorDescriptor_t boxes_desc,
void* boxes,
const mluOpTensorDescriptor_t scores_desc,
void* scores);
};
const std::map<const std::string, std::pair<std::vector<int>, std::vector<int>>>
TransPermMap = {
// trans_mode, (forward_perm, backward_perm)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import sys
sys.path.append("..")
import unittest
import numpy as np
from op_test import OpTest
import paddle
from paddle.fluid import core
import paddle.fluid as fluid
from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor
from paddle.fluid.framework import _test_eager_guard
paddle.enable_static()
def sigmoid(x):
return (1.0 / (1.0 + np.exp(((-1.0) * x))))
def YoloBox(x, img_size, attrs):
(n, c, h, w) = x.shape
anchors = attrs['anchors']
an_num = int((len(anchors) // 2))
class_num = attrs['class_num']
conf_thresh = attrs['conf_thresh']
downsample = attrs['downsample_ratio']
clip_bbox = attrs['clip_bbox']
scale_x_y = attrs['scale_x_y']
iou_aware = attrs['iou_aware']
iou_aware_factor = attrs['iou_aware_factor']
bias_x_y = ((-0.5) * (scale_x_y - 1.0))
input_h = (downsample * h)
input_w = (downsample * w)
if iou_aware:
ioup = x[:, :an_num, :, :]
ioup = np.expand_dims(ioup, axis=(-1))
x = x[:, an_num:, :, :]
x = x.reshape((n, an_num, (5 + class_num), h, w)).transpose((0, 1, 3, 4, 2))
pred_box = x[:, :, :, :, :4].copy()
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
pred_box[:, :, :, :, 0] = ((
(grid_x + (sigmoid(pred_box[:, :, :, :, 0]) * scale_x_y)) + bias_x_y) /
w)
pred_box[:, :, :, :, 1] = ((
(grid_y + (sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y)) + bias_x_y) /
h)
anchors = [(anchors[i], anchors[(i + 1)])
for i in range(0, len(anchors), 2)]
anchors_s = np.array([((an_w / input_w), (an_h / input_h))
for (an_w, an_h) in anchors])
anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1))
anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1))
pred_box[:, :, :, :, 2] = (np.exp(pred_box[:, :, :, :, 2]) * anchor_w)
pred_box[:, :, :, :, 3] = (np.exp(pred_box[:, :, :, :, 3]) * anchor_h)
if iou_aware:
pred_conf = ((sigmoid(x[:, :, :, :, 4:5])**(1 - iou_aware_factor)) *
(sigmoid(ioup)**iou_aware_factor))
else:
pred_conf = sigmoid(x[:, :, :, :, 4:5])
pred_conf[(pred_conf < conf_thresh)] = 0.0
pred_score = (sigmoid(x[:, :, :, :, 5:]) * pred_conf)
pred_box = (pred_box * (pred_conf > 0.0).astype('float32'))
pred_box = pred_box.reshape((n, (-1), 4))
(pred_box[:, :, :2],
pred_box[:, :, 2:4]) = ((pred_box[:, :, :2] - (pred_box[:, :, 2:4] / 2.0)),
(pred_box[:, :, :2] + (pred_box[:, :, 2:4] / 2.0)))
pred_box[:, :, 0] = (pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis])
pred_box[:, :, 1] = (pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis])
pred_box[:, :, 2] = (pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis])
pred_box[:, :, 3] = (pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis])
if clip_bbox:
for i in range(len(pred_box)):
pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf)
pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf)
pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], (-np.inf),
(img_size[(i, 1)] - 1))
pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], (-np.inf),
(img_size[(i, 0)] - 1))
return (pred_box, pred_score.reshape((n, (-1), class_num)))
class TestYoloBoxOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'yolo_box'
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
self.__class__.no_need_check_grad = True
self.python_api = paddle.vision.ops.yolo_box
x = np.random.random(self.x_shape).astype('float32')
img_size = np.random.randint(10, 20, self.imgsize_shape).astype('int32')
self.attrs = {
'anchors': self.anchors,
'class_num': self.class_num,
'conf_thresh': self.conf_thresh,
'downsample_ratio': self.downsample,
'clip_bbox': self.clip_bbox,
'scale_x_y': self.scale_x_y,
'iou_aware': self.iou_aware,
'iou_aware_factor': self.iou_aware_factor
}
self.inputs = {'X': x, 'ImgSize': img_size}
(boxes, scores) = YoloBox(x, img_size, self.attrs)
self.outputs = {'Boxes': boxes, 'Scores': scores}
def test_check_output(self):
self.check_output_with_place(self.place, check_eager=False, atol=1e-5)
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int((len(self.anchors) // 2))
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = True
self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13,
13)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.0
self.iou_aware = False
self.iou_aware_factor = 0.5
class TestYoloBoxOpNoClipBbox(TestYoloBoxOp):
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int((len(self.anchors) // 2))
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = False
self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13,
13)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.0
self.iou_aware = False
self.iou_aware_factor = 0.5
class TestYoloBoxOpScaleXY(TestYoloBoxOp):
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int((len(self.anchors) // 2))
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = True
self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13,
13)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.2
self.iou_aware = False
self.iou_aware_factor = 0.5
class TestYoloBoxOpIoUAware(TestYoloBoxOp):
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int((len(self.anchors) // 2))
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = True
self.x_shape = (self.batch_size, (an_num * (6 + self.class_num)), 13,
13)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.0
self.iou_aware = True
self.iou_aware_factor = 0.5
class TestYoloBoxDygraph(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
img_size = np.ones((2, 2)).astype('int32')
img_size = paddle.to_tensor(img_size)
x1 = np.random.random([2, 14, 8, 8]).astype('float32')
x1 = paddle.to_tensor(x1)
(boxes, scores) = paddle.vision.ops.yolo_box(x1,
img_size=img_size,
anchors=[10, 13, 16, 30],
class_num=2,
conf_thresh=0.01,
downsample_ratio=8,
clip_bbox=True,
scale_x_y=1.0)
assert ((boxes is not None) and (scores is not None))
x2 = np.random.random([2, 16, 8, 8]).astype('float32')
x2 = paddle.to_tensor(x2)
(boxes, scores) = paddle.vision.ops.yolo_box(x2,
img_size=img_size,
anchors=[10, 13, 16, 30],
class_num=2,
conf_thresh=0.01,
downsample_ratio=8,
clip_bbox=True,
scale_x_y=1.0,
iou_aware=True,
iou_aware_factor=0.5)
paddle.enable_static()
class TestYoloBoxStatic(unittest.TestCase):
def test_static(self):
x1 = paddle.static.data('x1', [2, 14, 8, 8], 'float32')
img_size = paddle.static.data('img_size', [2, 2], 'int32')
(boxes, scores) = paddle.vision.ops.yolo_box(x1,
img_size=img_size,
anchors=[10, 13, 16, 30],
class_num=2,
conf_thresh=0.01,
downsample_ratio=8,
clip_bbox=True,
scale_x_y=1.0)
assert ((boxes is not None) and (scores is not None))
x2 = paddle.static.data('x2', [2, 16, 8, 8], 'float32')
(boxes, scores) = paddle.vision.ops.yolo_box(x2,
img_size=img_size,
anchors=[10, 13, 16, 30],
class_num=2,
conf_thresh=0.01,
downsample_ratio=8,
clip_bbox=True,
scale_x_y=1.0,
iou_aware=True,
iou_aware_factor=0.5)
assert ((boxes is not None) and (scores is not None))
class TestYoloBoxOpHW(TestYoloBoxOp):
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int((len(self.anchors) // 2))
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = False
self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13, 9)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.0
self.iou_aware = False
self.iou_aware_factor = 0.5
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册