未验证 提交 b154470c 编写于 作者: W wangxinxin08 提交者: GitHub

add two attributes for yolo box (#33400)

* add two attributes for yolo box
上级 ddc95a01
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "paddle/fluid/operators/detection/yolo_box_op.h" #include "paddle/fluid/operators/detection/yolo_box_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -31,11 +32,35 @@ class YoloBoxOp : public framework::OperatorWithKernel { ...@@ -31,11 +32,35 @@ class YoloBoxOp : public framework::OperatorWithKernel {
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors"); auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
int anchor_num = anchors.size() / 2; int anchor_num = anchors.size() / 2;
auto class_num = ctx->Attrs().Get<int>("class_num"); auto class_num = ctx->Attrs().Get<int>("class_num");
auto iou_aware = ctx->Attrs().Get<bool>("iou_aware");
auto iou_aware_factor = ctx->Attrs().Get<float>("iou_aware_factor");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, platform::errors::InvalidArgument( PADDLE_ENFORCE_EQ(dim_x.size(), 4, platform::errors::InvalidArgument(
"Input(X) should be a 4-D tensor." "Input(X) should be a 4-D tensor."
"But received X dimension(%s)", "But received X dimension(%s)",
dim_x.size())); dim_x.size()));
if (iou_aware) {
PADDLE_ENFORCE_EQ(
dim_x[1], anchor_num * (6 + class_num),
platform::errors::InvalidArgument(
"Input(X) dim[1] should be equal to (anchor_mask_number * (6 "
"+ class_num)) while iou_aware is true."
"But received dim[1](%s) != (anchor_mask_number * "
"(6+class_num)(%s).",
dim_x[1], anchor_num * (6 + class_num)));
PADDLE_ENFORCE_GE(
iou_aware_factor, 0,
platform::errors::InvalidArgument(
"Attr(iou_aware_factor) should greater than or equal to 0."
"But received iou_aware_factor (%s)",
iou_aware_factor));
PADDLE_ENFORCE_LE(
iou_aware_factor, 1,
platform::errors::InvalidArgument(
"Attr(iou_aware_factor) should less than or equal to 1."
"But received iou_aware_factor (%s)",
iou_aware_factor));
} else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dim_x[1], anchor_num * (5 + class_num), dim_x[1], anchor_num * (5 + class_num),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -44,6 +69,7 @@ class YoloBoxOp : public framework::OperatorWithKernel { ...@@ -44,6 +69,7 @@ class YoloBoxOp : public framework::OperatorWithKernel {
"But received dim[1](%s) != (anchor_mask_number * " "But received dim[1](%s) != (anchor_mask_number * "
"(5+class_num)(%s).", "(5+class_num)(%s).",
dim_x[1], anchor_num * (5 + class_num))); dim_x[1], anchor_num * (5 + class_num)));
}
PADDLE_ENFORCE_EQ(dim_imgsize.size(), 2, PADDLE_ENFORCE_EQ(dim_imgsize.size(), 2,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input(ImgSize) should be a 2-D tensor." "Input(ImgSize) should be a 2-D tensor."
...@@ -140,6 +166,10 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -140,6 +166,10 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"Scale the center point of decoded bounding " "Scale the center point of decoded bounding "
"box. Default 1.0") "box. Default 1.0")
.SetDefault(1.); .SetDefault(1.);
AddAttr<bool>("iou_aware", "Whether use iou aware. Default false.")
.SetDefault(false);
AddAttr<float>("iou_aware_factor", "iou aware factor. Default 0.5.")
.SetDefault(0.5);
AddComment(R"DOC( AddComment(R"DOC(
This operator generates YOLO detection boxes from output of YOLOv3 network. This operator generates YOLO detection boxes from output of YOLOv3 network.
...@@ -147,7 +177,8 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -147,7 +177,8 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
should be the same, H and W specify the grid size, each grid point predict should be the same, H and W specify the grid size, each grid point predict
given number boxes, this given number, which following will be represented as S, given number boxes, this given number, which following will be represented as S,
is specified by the number of anchors. In the second dimension(the channel is specified by the number of anchors. In the second dimension(the channel
dimension), C should be equal to S * (5 + class_num), class_num is the object dimension), C should be equal to S * (5 + class_num) if :attr:`iou_aware` is false,
otherwise C should be equal to S * (6 + class_num). class_num is the object
category number of source dataset(such as 80 in coco dataset), so the category number of source dataset(such as 80 in coco dataset), so the
second(channel) dimension, apart from 4 box location coordinates x, y, w, h, second(channel) dimension, apart from 4 box location coordinates x, y, w, h,
also includes confidence score of the box and class one-hot key of each anchor also includes confidence score of the box and class one-hot key of each anchor
...@@ -183,6 +214,15 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -183,6 +214,15 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
score_{pred} = score_{conf} * score_{class} score_{pred} = score_{conf} * score_{class}
$$ $$
where the confidence scores follow the formula bellow
.. math::
score_{conf} = \begin{case}
obj, \text{if } iou_aware == flase \\
obj^{1 - iou_aware_factor} * iou^{iou_aware_factor}, \text{otherwise}
\end{case}
)DOC"); )DOC");
} }
}; };
...@@ -197,3 +237,12 @@ REGISTER_OPERATOR( ...@@ -197,3 +237,12 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(yolo_box, ops::YoloBoxKernel<float>, REGISTER_OP_CPU_KERNEL(yolo_box, ops::YoloBoxKernel<float>,
ops::YoloBoxKernel<double>); ops::YoloBoxKernel<double>);
REGISTER_OP_VERSION(yolo_box)
.AddCheckpoint(
R"ROC(
Upgrade yolo box to add new attribute [iou_aware, iou_aware_factor].
)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("iou_aware", "Whether use iou aware", false)
.NewAttr("iou_aware_factor", "iou aware factor", 0.5f));
...@@ -28,7 +28,8 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, ...@@ -28,7 +28,8 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
const int w, const int an_num, const int class_num, const int w, const int an_num, const int class_num,
const int box_num, int input_size_h, const int box_num, int input_size_h,
int input_size_w, bool clip_bbox, const float scale, int input_size_w, bool clip_bbox, const float scale,
const float bias) { const float bias, bool iou_aware,
const float iou_aware_factor) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
T box[4]; T box[4];
...@@ -43,23 +44,29 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, ...@@ -43,23 +44,29 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int img_height = imgsize[2 * i]; int img_height = imgsize[2 * i];
int img_width = imgsize[2 * i + 1]; int img_width = imgsize[2 * i + 1];
int obj_idx = int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4,
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4); iou_aware);
T conf = sigmoid<T>(input[obj_idx]); T conf = sigmoid<T>(input[obj_idx]);
if (iou_aware) {
int iou_idx = GetIoUIndex(i, j, k * w + l, an_num, an_stride, grid_num);
T iou = sigmoid<T>(input[iou_idx]);
conf = pow(conf, static_cast<T>(1. - iou_aware_factor)) *
pow(iou, static_cast<T>(iou_aware_factor));
}
if (conf < conf_thresh) { if (conf < conf_thresh) {
continue; continue;
} }
int box_idx = int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0,
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0); iou_aware);
GetYoloBox<T>(box, input, anchors, l, k, j, h, w, input_size_h, GetYoloBox<T>(box, input, anchors, l, k, j, h, w, input_size_h,
input_size_w, box_idx, grid_num, img_height, img_width, scale, input_size_w, box_idx, grid_num, img_height, img_width, scale,
bias); bias);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4; box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox); CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox);
int label_idx = int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num,
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); 5, iou_aware);
int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num; int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num;
CalcLabelScore<T>(scores, input, label_idx, score_idx, class_num, conf, CalcLabelScore<T>(scores, input, label_idx, score_idx, class_num, conf,
grid_num); grid_num);
...@@ -80,6 +87,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> { ...@@ -80,6 +87,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
float conf_thresh = ctx.Attr<float>("conf_thresh"); float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio"); int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool clip_bbox = ctx.Attr<bool>("clip_bbox"); bool clip_bbox = ctx.Attr<bool>("clip_bbox");
bool iou_aware = ctx.Attr<bool>("iou_aware");
float iou_aware_factor = ctx.Attr<float>("iou_aware_factor");
float scale = ctx.Attr<float>("scale_x_y"); float scale = ctx.Attr<float>("scale_x_y");
float bias = -0.5 * (scale - 1.); float bias = -0.5 * (scale - 1.);
...@@ -115,7 +124,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> { ...@@ -115,7 +124,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
ctx.cuda_device_context().stream()>>>( ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh, input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, n, h, w, an_num, class_num, box_num, input_size_h, anchors_data, n, h, w, an_num, class_num, box_num, input_size_h,
input_size_w, clip_bbox, scale, bias); input_size_w, clip_bbox, scale, bias, iou_aware, iou_aware_factor);
} }
}; };
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h" #include "paddle/fluid/platform/hostdevice.h"
namespace paddle { namespace paddle {
...@@ -43,8 +44,19 @@ HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i, ...@@ -43,8 +44,19 @@ HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i,
HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
int an_num, int an_stride, int stride, int an_num, int an_stride, int stride,
int entry) { int entry, bool iou_aware) {
if (iou_aware) {
return (batch * an_num + an_idx) * an_stride +
(batch * an_num + an_num + entry) * stride + hw_idx;
} else {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
}
}
HOSTDEVICE inline int GetIoUIndex(int batch, int an_idx, int hw_idx, int an_num,
int an_stride, int stride) {
return batch * an_num * an_stride + (batch * an_num + an_idx) * stride +
hw_idx;
} }
template <typename T> template <typename T>
...@@ -92,6 +104,8 @@ class YoloBoxKernel : public framework::OpKernel<T> { ...@@ -92,6 +104,8 @@ class YoloBoxKernel : public framework::OpKernel<T> {
float conf_thresh = ctx.Attr<float>("conf_thresh"); float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio"); int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool clip_bbox = ctx.Attr<bool>("clip_bbox"); bool clip_bbox = ctx.Attr<bool>("clip_bbox");
bool iou_aware = ctx.Attr<bool>("iou_aware");
float iou_aware_factor = ctx.Attr<float>("iou_aware_factor");
float scale = ctx.Attr<float>("scale_x_y"); float scale = ctx.Attr<float>("scale_x_y");
float bias = -0.5 * (scale - 1.); float bias = -0.5 * (scale - 1.);
...@@ -127,15 +141,22 @@ class YoloBoxKernel : public framework::OpKernel<T> { ...@@ -127,15 +141,22 @@ class YoloBoxKernel : public framework::OpKernel<T> {
for (int j = 0; j < an_num; j++) { for (int j = 0; j < an_num; j++) {
for (int k = 0; k < h; k++) { for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) { for (int l = 0; l < w; l++) {
int obj_idx = int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride,
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 4); stride, 4, iou_aware);
T conf = sigmoid<T>(input_data[obj_idx]); T conf = sigmoid<T>(input_data[obj_idx]);
if (iou_aware) {
int iou_idx =
GetIoUIndex(i, j, k * w + l, an_num, an_stride, stride);
T iou = sigmoid<T>(input_data[iou_idx]);
conf = pow(conf, static_cast<T>(1. - iou_aware_factor)) *
pow(iou, static_cast<T>(iou_aware_factor));
}
if (conf < conf_thresh) { if (conf < conf_thresh) {
continue; continue;
} }
int box_idx = int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride,
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); stride, 0, iou_aware);
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, w, GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, w,
input_size_h, input_size_w, box_idx, stride, input_size_h, input_size_w, box_idx, stride,
img_height, img_width, scale, bias); img_height, img_width, scale, bias);
...@@ -143,8 +164,8 @@ class YoloBoxKernel : public framework::OpKernel<T> { ...@@ -143,8 +164,8 @@ class YoloBoxKernel : public framework::OpKernel<T> {
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width, CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width,
clip_bbox); clip_bbox);
int label_idx = int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride,
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5); stride, 5, iou_aware);
int score_idx = (i * box_num + j * stride + k * w + l) * class_num; int score_idx = (i * box_num + j * stride + k * w + l) * class_num;
CalcLabelScore<T>(scores_data, input_data, label_idx, score_idx, CalcLabelScore<T>(scores_data, input_data, label_idx, score_idx,
class_num, conf, stride); class_num, conf, stride);
......
...@@ -1139,7 +1139,9 @@ def yolo_box(x, ...@@ -1139,7 +1139,9 @@ def yolo_box(x,
downsample_ratio, downsample_ratio,
clip_bbox=True, clip_bbox=True,
name=None, name=None,
scale_x_y=1.): scale_x_y=1.,
iou_aware=False,
iou_aware_factor=0.5):
""" """
${comment} ${comment}
...@@ -1156,6 +1158,8 @@ def yolo_box(x, ...@@ -1156,6 +1158,8 @@ def yolo_box(x,
name (string): The default value is None. Normally there is no need name (string): The default value is None. Normally there is no need
for user to set this property. For more information, for user to set this property. For more information,
please refer to :ref:`api_guide_Name` please refer to :ref:`api_guide_Name`
iou_aware (bool): ${iou_aware_comment}
iou_aware_factor (float): ${iou_aware_factor_comment}
Returns: Returns:
Variable: A 3-D tensor with shape [N, M, 4], the coordinates of boxes, Variable: A 3-D tensor with shape [N, M, 4], the coordinates of boxes,
...@@ -1204,6 +1208,8 @@ def yolo_box(x, ...@@ -1204,6 +1208,8 @@ def yolo_box(x,
"downsample_ratio": downsample_ratio, "downsample_ratio": downsample_ratio,
"clip_bbox": clip_bbox, "clip_bbox": clip_bbox,
"scale_x_y": scale_x_y, "scale_x_y": scale_x_y,
"iou_aware": iou_aware,
"iou_aware_factor": iou_aware_factor
} }
helper.append_op( helper.append_op(
......
...@@ -35,10 +35,16 @@ def YoloBox(x, img_size, attrs): ...@@ -35,10 +35,16 @@ def YoloBox(x, img_size, attrs):
downsample = attrs['downsample'] downsample = attrs['downsample']
clip_bbox = attrs['clip_bbox'] clip_bbox = attrs['clip_bbox']
scale_x_y = attrs['scale_x_y'] scale_x_y = attrs['scale_x_y']
iou_aware = attrs['iou_aware']
iou_aware_factor = attrs['iou_aware_factor']
bias_x_y = -0.5 * (scale_x_y - 1.) bias_x_y = -0.5 * (scale_x_y - 1.)
input_h = downsample * h input_h = downsample * h
input_w = downsample * w input_w = downsample * w
if iou_aware:
ioup = x[:, :an_num, :, :]
ioup = np.expand_dims(ioup, axis=-1)
x = x[:, an_num:, :, :]
x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
pred_box = x[:, :, :, :, :4].copy() pred_box = x[:, :, :, :, :4].copy()
...@@ -57,6 +63,10 @@ def YoloBox(x, img_size, attrs): ...@@ -57,6 +63,10 @@ def YoloBox(x, img_size, attrs):
pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w
pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h
if iou_aware:
pred_conf = sigmoid(x[:, :, :, :, 4:5])**(
1 - iou_aware_factor) * sigmoid(ioup)**iou_aware_factor
else:
pred_conf = sigmoid(x[:, :, :, :, 4:5]) pred_conf = sigmoid(x[:, :, :, :, 4:5])
pred_conf[pred_conf < conf_thresh] = 0. pred_conf[pred_conf < conf_thresh] = 0.
pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf
...@@ -97,6 +107,8 @@ class TestYoloBoxOp(OpTest): ...@@ -97,6 +107,8 @@ class TestYoloBoxOp(OpTest):
"downsample": self.downsample, "downsample": self.downsample,
"clip_bbox": self.clip_bbox, "clip_bbox": self.clip_bbox,
"scale_x_y": self.scale_x_y, "scale_x_y": self.scale_x_y,
"iou_aware": self.iou_aware,
"iou_aware_factor": self.iou_aware_factor
} }
self.inputs = { self.inputs = {
...@@ -123,6 +135,8 @@ class TestYoloBoxOp(OpTest): ...@@ -123,6 +135,8 @@ class TestYoloBoxOp(OpTest):
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2) self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1. self.scale_x_y = 1.
self.iou_aware = False
self.iou_aware_factor = 0.5
class TestYoloBoxOpNoClipBbox(TestYoloBoxOp): class TestYoloBoxOpNoClipBbox(TestYoloBoxOp):
...@@ -137,6 +151,8 @@ class TestYoloBoxOpNoClipBbox(TestYoloBoxOp): ...@@ -137,6 +151,8 @@ class TestYoloBoxOpNoClipBbox(TestYoloBoxOp):
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2) self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1. self.scale_x_y = 1.
self.iou_aware = False
self.iou_aware_factor = 0.5
class TestYoloBoxOpScaleXY(TestYoloBoxOp): class TestYoloBoxOpScaleXY(TestYoloBoxOp):
...@@ -151,19 +167,36 @@ class TestYoloBoxOpScaleXY(TestYoloBoxOp): ...@@ -151,19 +167,36 @@ class TestYoloBoxOpScaleXY(TestYoloBoxOp):
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2) self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.2 self.scale_x_y = 1.2
self.iou_aware = False
self.iou_aware_factor = 0.5
class TestYoloBoxOpIoUAware(TestYoloBoxOp):
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2)
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = True
self.x_shape = (self.batch_size, an_num * (6 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.
self.iou_aware = True
self.iou_aware_factor = 0.5
class TestYoloBoxDygraph(unittest.TestCase): class TestYoloBoxDygraph(unittest.TestCase):
def test_dygraph(self): def test_dygraph(self):
paddle.disable_static() paddle.disable_static()
x = np.random.random([2, 14, 8, 8]).astype('float32')
img_size = np.ones((2, 2)).astype('int32') img_size = np.ones((2, 2)).astype('int32')
x = paddle.to_tensor(x)
img_size = paddle.to_tensor(img_size) img_size = paddle.to_tensor(img_size)
x1 = np.random.random([2, 14, 8, 8]).astype('float32')
x1 = paddle.to_tensor(x1)
boxes, scores = paddle.vision.ops.yolo_box( boxes, scores = paddle.vision.ops.yolo_box(
x, x1,
img_size=img_size, img_size=img_size,
anchors=[10, 13, 16, 30], anchors=[10, 13, 16, 30],
class_num=2, class_num=2,
...@@ -172,16 +205,30 @@ class TestYoloBoxDygraph(unittest.TestCase): ...@@ -172,16 +205,30 @@ class TestYoloBoxDygraph(unittest.TestCase):
clip_bbox=True, clip_bbox=True,
scale_x_y=1.) scale_x_y=1.)
assert boxes is not None and scores is not None assert boxes is not None and scores is not None
x2 = np.random.random([2, 16, 8, 8]).astype('float32')
x2 = paddle.to_tensor(x2)
boxes, scores = paddle.vision.ops.yolo_box(
x2,
img_size=img_size,
anchors=[10, 13, 16, 30],
class_num=2,
conf_thresh=0.01,
downsample_ratio=8,
clip_bbox=True,
scale_x_y=1.,
iou_aware=True,
iou_aware_factor=0.5)
paddle.enable_static() paddle.enable_static()
class TestYoloBoxStatic(unittest.TestCase): class TestYoloBoxStatic(unittest.TestCase):
def test_static(self): def test_static(self):
x = paddle.static.data('x', [2, 14, 8, 8], 'float32') x1 = paddle.static.data('x1', [2, 14, 8, 8], 'float32')
img_size = paddle.static.data('img_size', [2, 2], 'int32') img_size = paddle.static.data('img_size', [2, 2], 'int32')
boxes, scores = paddle.vision.ops.yolo_box( boxes, scores = paddle.vision.ops.yolo_box(
x, x1,
img_size=img_size, img_size=img_size,
anchors=[10, 13, 16, 30], anchors=[10, 13, 16, 30],
class_num=2, class_num=2,
...@@ -191,6 +238,20 @@ class TestYoloBoxStatic(unittest.TestCase): ...@@ -191,6 +238,20 @@ class TestYoloBoxStatic(unittest.TestCase):
scale_x_y=1.) scale_x_y=1.)
assert boxes is not None and scores is not None assert boxes is not None and scores is not None
x2 = paddle.static.data('x2', [2, 16, 8, 8], 'float32')
boxes, scores = paddle.vision.ops.yolo_box(
x2,
img_size=img_size,
anchors=[10, 13, 16, 30],
class_num=2,
conf_thresh=0.01,
downsample_ratio=8,
clip_bbox=True,
scale_x_y=1.,
iou_aware=True,
iou_aware_factor=0.5)
assert boxes is not None and scores is not None
class TestYoloBoxOpHW(TestYoloBoxOp): class TestYoloBoxOpHW(TestYoloBoxOp):
def initTestCase(self): def initTestCase(self):
...@@ -204,6 +265,8 @@ class TestYoloBoxOpHW(TestYoloBoxOp): ...@@ -204,6 +265,8 @@ class TestYoloBoxOpHW(TestYoloBoxOp):
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 9) self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 9)
self.imgsize_shape = (self.batch_size, 2) self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1. self.scale_x_y = 1.
self.iou_aware = False
self.iou_aware_factor = 0.5
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -247,7 +247,9 @@ def yolo_box(x, ...@@ -247,7 +247,9 @@ def yolo_box(x,
downsample_ratio, downsample_ratio,
clip_bbox=True, clip_bbox=True,
name=None, name=None,
scale_x_y=1.): scale_x_y=1.,
iou_aware=False,
iou_aware_factor=0.5):
r""" r"""
This operator generates YOLO detection boxes from output of YOLOv3 network. This operator generates YOLO detection boxes from output of YOLOv3 network.
...@@ -256,7 +258,8 @@ def yolo_box(x, ...@@ -256,7 +258,8 @@ def yolo_box(x,
should be the same, H and W specify the grid size, each grid point predict should be the same, H and W specify the grid size, each grid point predict
given number boxes, this given number, which following will be represented as S, given number boxes, this given number, which following will be represented as S,
is specified by the number of anchors. In the second dimension(the channel is specified by the number of anchors. In the second dimension(the channel
dimension), C should be equal to S * (5 + class_num), class_num is the object dimension), C should be equal to S * (5 + class_num) if :attr:`iou_aware` is false,
otherwise C should be equal to S * (6 + class_num). class_num is the object
category number of source dataset(such as 80 in coco dataset), so the category number of source dataset(such as 80 in coco dataset), so the
second(channel) dimension, apart from 4 box location coordinates x, y, w, h, second(channel) dimension, apart from 4 box location coordinates x, y, w, h,
also includes confidence score of the box and class one-hot key of each anchor also includes confidence score of the box and class one-hot key of each anchor
...@@ -292,6 +295,15 @@ def yolo_box(x, ...@@ -292,6 +295,15 @@ def yolo_box(x,
score_{pred} = score_{conf} * score_{class} score_{pred} = score_{conf} * score_{class}
$$ $$
where the confidence scores follow the formula bellow
.. math::
score_{conf} = \begin{case}
obj, \text{if } iou_aware == flase \\
obj^{1 - iou_aware_factor} * iou^{iou_aware_factor}, \text{otherwise}
\end{case}
Args: Args:
x (Tensor): The input tensor of YoloBox operator is a 4-D tensor with x (Tensor): The input tensor of YoloBox operator is a 4-D tensor with
shape of [N, C, H, W]. The second dimension(C) stores box shape of [N, C, H, W]. The second dimension(C) stores box
...@@ -313,13 +325,14 @@ def yolo_box(x, ...@@ -313,13 +325,14 @@ def yolo_box(x,
should be set for the first, second, and thrid should be set for the first, second, and thrid
:attr:`yolo_box` layer. :attr:`yolo_box` layer.
clip_bbox (bool): Whether clip output bonding box in :attr:`img_size` clip_bbox (bool): Whether clip output bonding box in :attr:`img_size`
boundary. Default true." boundary. Default true.
"
scale_x_y (float): Scale the center point of decoded bounding box. scale_x_y (float): Scale the center point of decoded bounding box.
Default 1.0 Default 1.0
name (string): The default value is None. Normally there is no need name (string): The default value is None. Normally there is no need
for user to set this property. For more information, for user to set this property. For more information,
please refer to :ref:`api_guide_Name` please refer to :ref:`api_guide_Name`
iou_aware (bool): Whether use iou aware. Default false
iou_aware_factor (float): iou aware factor. Default 0.5
Returns: Returns:
Tensor: A 3-D tensor with shape [N, M, 4], the coordinates of boxes, Tensor: A 3-D tensor with shape [N, M, 4], the coordinates of boxes,
...@@ -358,7 +371,8 @@ def yolo_box(x, ...@@ -358,7 +371,8 @@ def yolo_box(x,
boxes, scores = core.ops.yolo_box( boxes, scores = core.ops.yolo_box(
x, img_size, 'anchors', anchors, 'class_num', class_num, x, img_size, 'anchors', anchors, 'class_num', class_num,
'conf_thresh', conf_thresh, 'downsample_ratio', downsample_ratio, 'conf_thresh', conf_thresh, 'downsample_ratio', downsample_ratio,
'clip_bbox', clip_bbox, 'scale_x_y', scale_x_y) 'clip_bbox', clip_bbox, 'scale_x_y', scale_x_y, 'iou_aware',
iou_aware, 'iou_aware_factor', iou_aware_factor)
return boxes, scores return boxes, scores
helper = LayerHelper('yolo_box', **locals()) helper = LayerHelper('yolo_box', **locals())
...@@ -378,6 +392,8 @@ def yolo_box(x, ...@@ -378,6 +392,8 @@ def yolo_box(x,
"downsample_ratio": downsample_ratio, "downsample_ratio": downsample_ratio,
"clip_bbox": clip_bbox, "clip_bbox": clip_bbox,
"scale_x_y": scale_x_y, "scale_x_y": scale_x_y,
"iou_aware": iou_aware,
"iou_aware_factor": iou_aware_factor
} }
helper.append_op( helper.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册