未验证 提交 b77ebb2a 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #15919 from heavengate/yolo_box

add yolo_box for detection box calc in YOLOv3
......@@ -331,6 +331,7 @@ paddle.fluid.layers.iou_similarity (ArgSpec(args=['x', 'y', 'name'], varargs=Non
paddle.fluid.layers.box_coder (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0)), ('document', '032d0f4b7d8f6235ee5d91e473344f0e'))
paddle.fluid.layers.polygon_box_transform (ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '0e5ac2507723a0b5adec473f9556799b'))
paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'gtscore', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(None, True, None)), ('document', '57fa96922e42db8f064c3fb77f2255e8'))
paddle.fluid.layers.yolo_box (ArgSpec(args=['x', 'img_size', 'anchors', 'class_num', 'conf_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '5566169a5ab993d177792c023c7fb340'))
paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '397e9e02b451d99c56e20f268fa03f2e'))
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))
paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d'))
......
......@@ -33,6 +33,7 @@ detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu)
detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu)
if(WITH_GPU)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/yolo_box_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class YoloBoxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of YoloBoxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ImgSize"),
"Input(ImgSize) of YoloBoxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Boxes"),
"Output(Boxes) of YoloBoxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Scores"),
"Output(Scores) of YoloBoxOp should not be null.");
auto dim_x = ctx->GetInputDim("X");
auto dim_imgsize = ctx->GetInputDim("ImgSize");
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
int anchor_num = anchors.size() / 2;
auto class_num = ctx->Attrs().Get<int>("class_num");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
PADDLE_ENFORCE_EQ(
dim_x[1], anchor_num * (5 + class_num),
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num)).");
PADDLE_ENFORCE_EQ(dim_imgsize.size(), 2,
"Input(ImgSize) should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(
dim_imgsize[0], dim_x[0],
"Input(ImgSize) dim[0] and Input(X) dim[0] should be same.");
PADDLE_ENFORCE_EQ(dim_imgsize[1], 2, "Input(ImgSize) dim[1] should be 2.");
PADDLE_ENFORCE_GT(anchors.size(), 0,
"Attr(anchors) length should be greater than 0.");
PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
"Attr(anchors) length should be even integer.");
PADDLE_ENFORCE_GT(class_num, 0,
"Attr(class_num) should be an integer greater than 0.");
int box_num = dim_x[2] * dim_x[3] * anchor_num;
std::vector<int64_t> dim_boxes({dim_x[0], box_num, 4});
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_boxes));
std::vector<int64_t> dim_scores({dim_x[0], box_num, class_num});
ctx->SetOutputDim("Scores", framework::make_ddim(dim_scores));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
}
};
class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input tensor of YoloBox operator is a 4-D tensor with "
"shape of [N, C, H, W]. The second dimension(C) stores "
"box locations, confidence score and classification one-hot "
"keys of each anchor box. Generally, X should be the output "
"of YOLOv3 network.");
AddInput("ImgSize",
"The image size tensor of YoloBox operator, "
"This is a 2-D tensor with shape of [N, 2]. This tensor holds "
"height and width of each input image used for resizing output "
"box in input image scale.");
AddOutput("Boxes",
"The output tensor of detection boxes of YoloBox operator, "
"This is a 3-D tensor with shape of [N, M, 4], N is the "
"batch num, M is output box number, and the 3rd dimension "
"stores [xmin, ymin, xmax, ymax] coordinates of boxes.");
AddOutput("Scores",
"The output tensor of detection boxes scores of YoloBox "
"operator, This is a 3-D tensor with shape of "
"[N, M, :attr:`class_num`], N is the batch num, M is "
"output box number.");
AddAttr<int>("class_num", "The number of classes to predict.");
AddAttr<std::vector<int>>("anchors",
"The anchor width and height, "
"it will be parsed pair by pair.")
.SetDefault(std::vector<int>{});
AddAttr<int>("downsample_ratio",
"The downsample ratio from network input to YoloBox operator "
"input, so 32, 16, 8 should be set for the first, second, "
"and thrid YoloBox operators.")
.SetDefault(32);
AddAttr<float>("conf_thresh",
"The confidence scores threshold of detection boxes. "
"Boxes with confidence scores under threshold should "
"be ignored.")
.SetDefault(0.01);
AddComment(R"DOC(
This operator generates YOLO detection boxes from output of YOLOv3 network.
The output of previous network is in shape [N, C, H, W], while H and W
should be the same, H and W specify the grid size, each grid point predict
given number boxes, this given number, which following will be represented as S,
is specified by the number of anchors. In the second dimension(the channel
dimension), C should be equal to S * (5 + class_num), class_num is the object
category number of source dataset(such as 80 in coco dataset), so the
second(channel) dimension, apart from 4 box location coordinates x, y, w, h,
also includes confidence score of the box and class one-hot key of each anchor
box.
Assume the 4 location coordinates are :math:`t_x, t_y, t_w, t_h`, the box
predictions should be as follows:
$$
b_x = \\sigma(t_x) + c_x
$$
$$
b_y = \\sigma(t_y) + c_y
$$
$$
b_w = p_w e^{t_w}
$$
$$
b_h = p_h e^{t_h}
$$
in the equation above, :math:`c_x, c_y` is the left top corner of current grid
and :math:`p_w, p_h` is specified by anchors.
The logistic regression value of the 5th channel of each anchor prediction boxes
represents the confidence score of each prediction box, and the logistic
regression value of the last :attr:`class_num` channels of each anchor prediction
boxes represents the classifcation scores. Boxes with confidence scores less than
:attr:`conf_thresh` should be ignored, and box final scores is the product of
confidence scores and classification scores.
$$
score_{pred} = score_{conf} * score_{class}
$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(yolo_box, ops::YoloBoxOp, ops::YoloBoxOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(yolo_box, ops::YoloBoxKernel<float>,
ops::YoloBoxKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/yolo_box_op.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
T* scores, const float conf_thresh,
const int* anchors, const int n, const int h,
const int w, const int an_num, const int class_num,
const int box_num, int input_size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
T box[4];
for (; tid < n * box_num; tid += stride) {
int grid_num = h * w;
int i = tid / box_num;
int j = (tid % box_num) / grid_num;
int k = (tid % grid_num) / w;
int l = tid % w;
int an_stride = (5 + class_num) * grid_num;
int img_height = imgsize[2 * i];
int img_width = imgsize[2 * i + 1];
int obj_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4);
T conf = sigmoid<T>(input[obj_idx]);
if (conf < conf_thresh) {
continue;
}
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx,
grid_num, img_height, img_width);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width);
int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num;
CalcLabelScore<T>(scores, input, label_idx, score_idx, class_num, conf,
grid_num);
}
}
template <typename T>
class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* img_size = ctx.Input<Tensor>("ImgSize");
auto* boxes = ctx.Output<Tensor>("Boxes");
auto* scores = ctx.Output<Tensor>("Scores");
auto anchors = ctx.Attr<std::vector<int>>("anchors");
int class_num = ctx.Attr<int>("class_num");
float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
const int n = input->dims()[0];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int box_num = boxes->dims()[1];
const int an_num = anchors.size() / 2;
int input_size = downsample_ratio * h;
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = sizeof(int) * anchors.size();
auto anchors_ptr = allocator.Allocate(sizeof(int) * anchors.size());
int* anchors_data = reinterpret_cast<int*>(anchors_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
const auto cplace = platform::CPUPlace();
memory::Copy(gplace, anchors_data, cplace, anchors.data(), bytes,
dev_ctx.stream());
const T* input_data = input->data<T>();
const int* imgsize_data = img_size->data<int>();
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
T* scores_data =
scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
set_zero(dev_ctx, boxes, static_cast<T>(0));
set_zero(dev_ctx, scores, static_cast<T>(0));
int grid_dim = (n * box_num + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, n, h, w, an_num, class_num, box_num, input_size);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(yolo_box, ops::YoloBoxOpCUDAKernel<float>,
ops::YoloBoxOpCUDAKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
HOSTDEVICE inline T sigmoid(T x) {
return 1.0 / (1.0 + std::exp(-x));
}
template <typename T>
HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i,
int j, int an_idx, int grid_size,
int input_size, int index, int stride,
int img_height, int img_width) {
box[0] = (i + sigmoid<T>(x[index])) * img_width / grid_size;
box[1] = (j + sigmoid<T>(x[index + stride])) * img_height / grid_size;
box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
input_size;
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
img_height / input_size;
}
HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
int an_num, int an_stride, int stride,
int entry) {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
}
template <typename T>
HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx,
const int img_height,
const int img_width) {
boxes[box_idx] = box[0] - box[2] / 2;
boxes[box_idx + 1] = box[1] - box[3] / 2;
boxes[box_idx + 2] = box[0] + box[2] / 2;
boxes[box_idx + 3] = box[1] + box[3] / 2;
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<T>(0);
boxes[box_idx + 1] =
boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0);
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
? boxes[box_idx + 2]
: static_cast<T>(img_width - 1);
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
? boxes[box_idx + 3]
: static_cast<T>(img_height - 1);
}
template <typename T>
HOSTDEVICE inline void CalcLabelScore(T* scores, const T* input,
const int label_idx, const int score_idx,
const int class_num, const T conf,
const int stride) {
for (int i = 0; i < class_num; i++) {
scores[score_idx + i] = conf * sigmoid<T>(input[label_idx + i * stride]);
}
}
template <typename T>
class YoloBoxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* imgsize = ctx.Input<Tensor>("ImgSize");
auto* boxes = ctx.Output<Tensor>("Boxes");
auto* scores = ctx.Output<Tensor>("Scores");
auto anchors = ctx.Attr<std::vector<int>>("anchors");
int class_num = ctx.Attr<int>("class_num");
float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
const int n = input->dims()[0];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int box_num = boxes->dims()[1];
const int an_num = anchors.size() / 2;
int input_size = downsample_ratio * h;
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
Tensor anchors_;
auto anchors_data =
anchors_.mutable_data<int>({an_num * 2}, ctx.GetPlace());
std::copy(anchors.begin(), anchors.end(), anchors_data);
const T* input_data = input->data<T>();
const int* imgsize_data = imgsize->data<int>();
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
memset(boxes_data, 0, boxes->numel() * sizeof(T));
T* scores_data =
scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
memset(scores_data, 0, scores->numel() * sizeof(T));
T box[4];
for (int i = 0; i < n; i++) {
int img_height = imgsize_data[2 * i];
int img_width = imgsize_data[2 * i + 1];
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
int obj_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 4);
T conf = sigmoid<T>(input_data[obj_idx]);
if (conf < conf_thresh) {
continue;
}
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0);
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, input_size,
box_idx, stride, img_height, img_width);
box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height,
img_width);
int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5);
int score_idx = (i * box_num + j * stride + k * w + l) * class_num;
CalcLabelScore<T>(scores_data, input_data, label_idx, score_idx,
class_num, conf, stride);
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -49,6 +49,7 @@ __all__ = [
'box_coder',
'polygon_box_transform',
'yolov3_loss',
'yolo_box',
'box_clip',
'multiclass_nms',
'distribute_fpn_proposals',
......@@ -628,6 +629,83 @@ def yolov3_loss(x,
return loss
@templatedoc(op_type="yolo_box")
def yolo_box(x,
img_size,
anchors,
class_num,
conf_thresh,
downsample_ratio,
name=None):
"""
${comment}
Args:
x (Variable): ${x_comment}
img_size (Variable): ${img_size_comment}
anchors (list|tuple): ${anchors_comment}
class_num (int): ${class_num_comment}
conf_thresh (float): ${conf_thresh_comment}
downsample_ratio (int): ${downsample_ratio_comment}
name (string): the name of yolo box layer. Default None.
Returns:
Variable: A 3-D tensor with shape [N, M, 4], the coordinates of boxes,
and a 3-D tensor with shape [N, M, :attr:`class_num`], the classification
scores of boxes.
Raises:
TypeError: Input x of yolov_box must be Variable
TypeError: Attr anchors of yolo box must be list or tuple
TypeError: Attr class_num of yolo box must be an integer
TypeError: Attr conf_thresh of yolo box must be a float number
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')
anchors = [10, 13, 16, 30, 33, 23]
loss = fluid.layers.yolo_box(x=x, class_num=80, anchors=anchors,
conf_thresh=0.01, downsample_ratio=32)
"""
helper = LayerHelper('yolo_box', **locals())
if not isinstance(x, Variable):
raise TypeError("Input x of yolo_box must be Variable")
if not isinstance(img_size, Variable):
raise TypeError("Input img_size of yolo_box must be Variable")
if not isinstance(anchors, list) and not isinstance(anchors, tuple):
raise TypeError("Attr anchors of yolo_box must be list or tuple")
if not isinstance(class_num, int):
raise TypeError("Attr class_num of yolo_box must be an integer")
if not isinstance(conf_thresh, float):
raise TypeError("Attr ignore_thresh of yolo_box must be a float number")
boxes = helper.create_variable_for_type_inference(dtype=x.dtype)
scores = helper.create_variable_for_type_inference(dtype=x.dtype)
attrs = {
"anchors": anchors,
"class_num": class_num,
"conf_thresh": conf_thresh,
"downsample_ratio": downsample_ratio,
}
helper.append_op(
type='yolo_box',
inputs={
"X": x,
"ImgSize": img_size,
},
outputs={
'Boxes': boxes,
'Scores': scores,
},
attrs=attrs)
return boxes, scores
@templatedoc()
def detection_map(detect_res,
label,
......
......@@ -489,6 +489,16 @@ class TestYoloDetection(unittest.TestCase):
self.assertIsNotNone(loss)
def test_yolo_box(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
img_size = layers.data(name='img_size', shape=[2], dtype='int32')
boxes, scores = layers.yolo_box(x, img_size, [10, 13, 30, 13], 10,
0.01, 32)
self.assertIsNotNone(boxes)
self.assertIsNotNone(scores)
class TestBoxClip(unittest.TestCase):
def test_box_clip(self):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
def sigmoid(x):
return 1.0 / (1.0 + np.exp(-1.0 * x))
def YoloBox(x, img_size, attrs):
n, c, h, w = x.shape
anchors = attrs['anchors']
an_num = int(len(anchors) // 2)
class_num = attrs['class_num']
conf_thresh = attrs['conf_thresh']
downsample = attrs['downsample']
input_size = downsample * h
x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
pred_box = x[:, :, :, :, :4].copy()
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w
pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h
anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
anchors_s = np.array(
[(an_w / input_size, an_h / input_size) for an_w, an_h in anchors])
anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1))
anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1))
pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w
pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h
pred_conf = sigmoid(x[:, :, :, :, 4:5])
pred_conf[pred_conf < conf_thresh] = 0.
pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf
pred_box = pred_box * (pred_conf > 0.).astype('float32')
pred_box = pred_box.reshape((n, -1, 4))
pred_box[:, :, :2], pred_box[:, :, 2:4] = \
pred_box[:, :, :2] - pred_box[:, :, 2:4] / 2., \
pred_box[:, :, :2] + pred_box[:, :, 2:4] / 2.0
pred_box[:, :, 0] = pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis]
pred_box[:, :, 1] = pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis]
pred_box[:, :, 2] = pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis]
pred_box[:, :, 3] = pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis]
for i in range(len(pred_box)):
pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf)
pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf)
pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], -np.inf,
img_size[i, 1] - 1)
pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], -np.inf,
img_size[i, 0] - 1)
return pred_box, pred_score.reshape((n, -1, class_num))
class TestYoloBoxOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'yolo_box'
x = np.random.random(self.x_shape).astype('float32')
img_size = np.random.randint(10, 20, self.imgsize_shape).astype('int32')
self.attrs = {
"anchors": self.anchors,
"class_num": self.class_num,
"conf_thresh": self.conf_thresh,
"downsample": self.downsample,
}
self.inputs = {
'X': x,
'ImgSize': img_size,
}
boxes, scores = YoloBox(x, img_size, self.attrs)
self.outputs = {
"Boxes": boxes,
"Scores": scores,
}
def test_check_output(self):
self.check_output()
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2)
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册