未验证 提交 abb49df2 编写于 作者: W wangguanzhong 提交者: GitHub

Enhance yolo_box & yolov3_loss (#24370)

* add scale_x_y for yolo_box, test=develop

* refine eps in iou_similarity, test=develop
上级 d1bb76a2
...@@ -18,7 +18,8 @@ limitations under the License. */ ...@@ -18,7 +18,8 @@ limitations under the License. */
template <typename T> template <typename T>
inline HOSTDEVICE T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2, inline HOSTDEVICE T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2,
T ymin2, T xmax2, T ymax2, bool normalized) { T ymin2, T xmax2, T ymax2, bool normalized,
T eps) {
constexpr T zero = static_cast<T>(0); constexpr T zero = static_cast<T>(0);
T area1; T area1;
T area2; T area2;
...@@ -43,19 +44,21 @@ inline HOSTDEVICE T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2, ...@@ -43,19 +44,21 @@ inline HOSTDEVICE T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2,
inter_height = inter_height > zero ? inter_height : zero; inter_height = inter_height > zero ? inter_height : zero;
inter_width = inter_width > zero ? inter_width : zero; inter_width = inter_width > zero ? inter_width : zero;
T inter_area = inter_width * inter_height; T inter_area = inter_width * inter_height;
T union_area = area1 + area2 - inter_area; T union_area = area1 + area2 - inter_area + eps;
T sim_score = inter_area / union_area; T sim_score = inter_area / union_area;
return sim_score; return sim_score;
} }
template <typename T> template <typename T>
struct IOUSimilarityFunctor { struct IOUSimilarityFunctor {
IOUSimilarityFunctor(const T* x, const T* y, T* z, int cols, bool normalized) IOUSimilarityFunctor(const T* x, const T* y, T* z, int cols, bool normalized,
T eps)
: x_(x), : x_(x),
y_(y), y_(y),
z_(z), z_(z),
cols_(static_cast<size_t>(cols)), cols_(static_cast<size_t>(cols)),
normalized_(normalized) {} normalized_(normalized),
eps_(eps) {}
inline HOSTDEVICE void operator()(size_t tid) const { inline HOSTDEVICE void operator()(size_t tid) const {
size_t row_id = tid / cols_; size_t row_id = tid / cols_;
...@@ -72,7 +75,7 @@ struct IOUSimilarityFunctor { ...@@ -72,7 +75,7 @@ struct IOUSimilarityFunctor {
T y_max2 = y_[col_id * 4 + 3]; T y_max2 = y_[col_id * 4 + 3];
T sim = IOUSimilarity(x_min1, y_min1, x_max1, y_max1, x_min2, y_min2, T sim = IOUSimilarity(x_min1, y_min1, x_max1, y_max1, x_min2, y_min2,
x_max2, y_max2, normalized_); x_max2, y_max2, normalized_, eps_);
z_[row_id * cols_ + col_id] = sim; z_[row_id * cols_ + col_id] = sim;
} }
...@@ -81,6 +84,7 @@ struct IOUSimilarityFunctor { ...@@ -81,6 +84,7 @@ struct IOUSimilarityFunctor {
T* z_; T* z_;
const size_t cols_; const size_t cols_;
bool normalized_; bool normalized_;
T eps_;
}; };
namespace paddle { namespace paddle {
...@@ -97,9 +101,10 @@ class IOUSimilarityKernel : public framework::OpKernel<T> { ...@@ -97,9 +101,10 @@ class IOUSimilarityKernel : public framework::OpKernel<T> {
int x_n = in_x->dims()[0]; int x_n = in_x->dims()[0];
int y_n = in_y->dims()[0]; int y_n = in_y->dims()[0];
T eps = static_cast<T>(1e-10);
IOUSimilarityFunctor<T> functor(in_x->data<T>(), in_y->data<T>(), IOUSimilarityFunctor<T> functor(in_x->data<T>(), in_y->data<T>(),
out->mutable_data<T>(ctx.GetPlace()), y_n, out->mutable_data<T>(ctx.GetPlace()), y_n,
normalized); normalized, eps);
platform::ForRange<DeviceContext> for_range( platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()), x_n * y_n); static_cast<const DeviceContext&>(ctx.device_context()), x_n * y_n);
......
...@@ -136,6 +136,10 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -136,6 +136,10 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"Whether clip output bonding box in Input(ImgSize) " "Whether clip output bonding box in Input(ImgSize) "
"boundary. Default true.") "boundary. Default true.")
.SetDefault(true); .SetDefault(true);
AddAttr<float>("scale_x_y",
"Scale the center point of decoded bounding "
"box. Default 1.0")
.SetDefault(1.);
AddComment(R"DOC( AddComment(R"DOC(
This operator generates YOLO detection boxes from output of YOLOv3 network. This operator generates YOLO detection boxes from output of YOLOv3 network.
......
...@@ -26,7 +26,8 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, ...@@ -26,7 +26,8 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
T* scores, const float conf_thresh, T* scores, const float conf_thresh,
const int* anchors, const int n, const int h, const int* anchors, const int n, const int h,
const int w, const int an_num, const int class_num, const int w, const int an_num, const int class_num,
const int box_num, int input_size, bool clip_bbox) { const int box_num, int input_size, bool clip_bbox,
const float scale, const float bias) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
T box[4]; T box[4];
...@@ -51,7 +52,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, ...@@ -51,7 +52,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int box_idx = int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0); GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx, GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx,
grid_num, img_height, img_width); grid_num, img_height, img_width, scale, bias);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4; box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox); CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox);
...@@ -77,6 +78,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> { ...@@ -77,6 +78,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
float conf_thresh = ctx.Attr<float>("conf_thresh"); float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio"); int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool clip_bbox = ctx.Attr<bool>("clip_bbox"); bool clip_bbox = ctx.Attr<bool>("clip_bbox");
float scale = ctx.Attr<float>("scale_x_y");
float bias = -0.5 * (scale - 1.);
const int n = input->dims()[0]; const int n = input->dims()[0];
const int h = input->dims()[2]; const int h = input->dims()[2];
...@@ -109,7 +112,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> { ...@@ -109,7 +112,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>( KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh, input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, n, h, w, an_num, class_num, box_num, input_size, anchors_data, n, h, w, an_num, class_num, box_num, input_size,
clip_bbox); clip_bbox, scale, bias);
} }
}; };
......
...@@ -29,9 +29,11 @@ template <typename T> ...@@ -29,9 +29,11 @@ template <typename T>
HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i, HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i,
int j, int an_idx, int grid_size, int j, int an_idx, int grid_size,
int input_size, int index, int stride, int input_size, int index, int stride,
int img_height, int img_width) { int img_height, int img_width, float scale,
box[0] = (i + sigmoid<T>(x[index])) * img_width / grid_size; float bias) {
box[1] = (j + sigmoid<T>(x[index + stride])) * img_height / grid_size; box[0] = (i + sigmoid<T>(x[index]) * scale + bias) * img_width / grid_size;
box[1] = (j + sigmoid<T>(x[index + stride]) * scale + bias) * img_height /
grid_size;
box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width / box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
input_size; input_size;
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
...@@ -89,6 +91,8 @@ class YoloBoxKernel : public framework::OpKernel<T> { ...@@ -89,6 +91,8 @@ class YoloBoxKernel : public framework::OpKernel<T> {
float conf_thresh = ctx.Attr<float>("conf_thresh"); float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio"); int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool clip_bbox = ctx.Attr<bool>("clip_bbox"); bool clip_bbox = ctx.Attr<bool>("clip_bbox");
float scale = ctx.Attr<float>("scale_x_y");
float bias = -0.5 * (scale - 1.);
const int n = input->dims()[0]; const int n = input->dims()[0];
const int h = input->dims()[2]; const int h = input->dims()[2];
...@@ -131,7 +135,7 @@ class YoloBoxKernel : public framework::OpKernel<T> { ...@@ -131,7 +135,7 @@ class YoloBoxKernel : public framework::OpKernel<T> {
int box_idx = int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0);
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, input_size, GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, input_size,
box_idx, stride, img_height, img_width); box_idx, stride, img_height, img_width, scale, bias);
box_idx = (i * box_num + j * stride + k * w + l) * 4; box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width, CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width,
clip_bbox); clip_bbox);
......
...@@ -215,6 +215,10 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -215,6 +215,10 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_label_smooth", AddAttr<bool>("use_label_smooth",
"Whether to use label smooth. Default True.") "Whether to use label smooth. Default True.")
.SetDefault(true); .SetDefault(true);
AddAttr<float>("scale_x_y",
"Scale the center point of decoded bounding "
"box. Default 1.0")
.SetDefault(1.);
AddComment(R"DOC( AddComment(R"DOC(
This operator generates yolov3 loss based on given predict result and ground This operator generates yolov3 loss based on given predict result and ground
truth boxes. truth boxes.
......
...@@ -73,10 +73,11 @@ static inline T sigmoid(T x) { ...@@ -73,10 +73,11 @@ static inline T sigmoid(T x) {
template <typename T> template <typename T>
static inline Box<T> GetYoloBox(const T* x, std::vector<int> anchors, int i, static inline Box<T> GetYoloBox(const T* x, std::vector<int> anchors, int i,
int j, int an_idx, int grid_size, int j, int an_idx, int grid_size,
int input_size, int index, int stride) { int input_size, int index, int stride,
float scale, float bias) {
Box<T> b; Box<T> b;
b.x = (i + sigmoid<T>(x[index])) / grid_size; b.x = (i + sigmoid<T>(x[index]) * scale + bias) / grid_size;
b.y = (j + sigmoid<T>(x[index + stride])) / grid_size; b.y = (j + sigmoid<T>(x[index + stride]) * scale + bias) / grid_size;
b.w = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] / input_size; b.w = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] / input_size;
b.h = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] / input_size; b.h = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] / input_size;
return b; return b;
...@@ -267,6 +268,8 @@ class Yolov3LossKernel : public framework::OpKernel<T> { ...@@ -267,6 +268,8 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
float ignore_thresh = ctx.Attr<float>("ignore_thresh"); float ignore_thresh = ctx.Attr<float>("ignore_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio"); int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool use_label_smooth = ctx.Attr<bool>("use_label_smooth"); bool use_label_smooth = ctx.Attr<bool>("use_label_smooth");
float scale = ctx.Attr<float>("scale_x_y");
float bias = -0.5 * (scale - 1.);
const int n = input->dims()[0]; const int n = input->dims()[0];
const int h = input->dims()[2]; const int h = input->dims()[2];
...@@ -325,8 +328,9 @@ class Yolov3LossKernel : public framework::OpKernel<T> { ...@@ -325,8 +328,9 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
// then ignore_thresh, ignore the objectness loss. // then ignore_thresh, ignore the objectness loss.
int box_idx = int box_idx =
GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0); GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0);
Box<T> pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j], Box<T> pred =
h, input_size, box_idx, stride); GetYoloBox(input_data, anchors, l, k, anchor_mask[j], h,
input_size, box_idx, stride, scale, bias);
T best_iou = 0; T best_iou = 0;
for (int t = 0; t < b; t++) { for (int t = 0; t < b; t++) {
if (!gt_valid_mask_data[i * b + t]) { if (!gt_valid_mask_data[i * b + t]) {
......
...@@ -919,7 +919,8 @@ def yolov3_loss(x, ...@@ -919,7 +919,8 @@ def yolov3_loss(x,
downsample_ratio, downsample_ratio,
gt_score=None, gt_score=None,
use_label_smooth=True, use_label_smooth=True,
name=None): name=None,
scale_x_y=1.):
""" """
${comment} ${comment}
...@@ -945,6 +946,7 @@ def yolov3_loss(x, ...@@ -945,6 +946,7 @@ def yolov3_loss(x,
gt_score (Variable): mixup score of ground truth boxes, should be in shape gt_score (Variable): mixup score of ground truth boxes, should be in shape
of [N, B]. Default None. of [N, B]. Default None.
use_label_smooth (bool): ${use_label_smooth_comment} use_label_smooth (bool): ${use_label_smooth_comment}
scale_x_y (float): ${scale_x_y_comment}
Returns: Returns:
Variable: A 1-D tensor with shape [N], the value of yolov3 loss Variable: A 1-D tensor with shape [N], the value of yolov3 loss
...@@ -1017,6 +1019,7 @@ def yolov3_loss(x, ...@@ -1017,6 +1019,7 @@ def yolov3_loss(x,
"ignore_thresh": ignore_thresh, "ignore_thresh": ignore_thresh,
"downsample_ratio": downsample_ratio, "downsample_ratio": downsample_ratio,
"use_label_smooth": use_label_smooth, "use_label_smooth": use_label_smooth,
"scale_x_y": scale_x_y,
} }
helper.append_op( helper.append_op(
...@@ -1039,7 +1042,8 @@ def yolo_box(x, ...@@ -1039,7 +1042,8 @@ def yolo_box(x,
conf_thresh, conf_thresh,
downsample_ratio, downsample_ratio,
clip_bbox=True, clip_bbox=True,
name=None): name=None,
scale_x_y=1.):
""" """
${comment} ${comment}
...@@ -1051,6 +1055,7 @@ def yolo_box(x, ...@@ -1051,6 +1055,7 @@ def yolo_box(x,
conf_thresh (float): ${conf_thresh_comment} conf_thresh (float): ${conf_thresh_comment}
downsample_ratio (int): ${downsample_ratio_comment} downsample_ratio (int): ${downsample_ratio_comment}
clip_bbox (bool): ${clip_bbox_comment} clip_bbox (bool): ${clip_bbox_comment}
scale_x_y (float): ${scale_x_y_comment}
name (string): The default value is None. Normally there is no need name (string): The default value is None. Normally there is no need
for user to set this property. For more information, for user to set this property. For more information,
please refer to :ref:`api_guide_Name` please refer to :ref:`api_guide_Name`
...@@ -1099,6 +1104,7 @@ def yolo_box(x, ...@@ -1099,6 +1104,7 @@ def yolo_box(x,
"conf_thresh": conf_thresh, "conf_thresh": conf_thresh,
"downsample_ratio": downsample_ratio, "downsample_ratio": downsample_ratio,
"clip_bbox": clip_bbox, "clip_bbox": clip_bbox,
"scale_x_y": scale_x_y,
} }
helper.append_op( helper.append_op(
......
...@@ -552,6 +552,36 @@ class TestYoloDetection(unittest.TestCase): ...@@ -552,6 +552,36 @@ class TestYoloDetection(unittest.TestCase):
self.assertIsNotNone(boxes) self.assertIsNotNone(boxes)
self.assertIsNotNone(scores) self.assertIsNotNone(scores)
def test_yolov3_loss_with_scale(self):
    # Build a yolov3_loss layer with a non-default scale_x_y (1.2) and
    # check that the layer call returns an output variable.
    anchor_sizes = [10, 13, 30, 13]
    mask_indices = [0, 1]
    main_prog = Program()
    with program_guard(main_prog):
        feature = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
        boxes_gt = layers.data(name='gt_box', shape=[10, 4], dtype='float32')
        labels_gt = layers.data(name='gt_label', shape=[10], dtype='int32')
        scores_gt = layers.data(name='gt_score', shape=[10], dtype='float32')
        # The positional arguments are kept in their original order; only
        # the trailing options are passed by keyword.
        result = layers.yolov3_loss(
            feature,
            boxes_gt,
            labels_gt,
            anchor_sizes,
            mask_indices,
            10,
            0.7,
            32,
            gt_score=scores_gt,
            use_label_smooth=False,
            scale_x_y=1.2)
        self.assertIsNotNone(result)
def test_yolo_box_with_scale(self):
    # Build a yolo_box layer with a non-default scale_x_y (1.2) and check
    # that both outputs of the layer call are present.
    anchor_sizes = [10, 13, 30, 13]
    main_prog = Program()
    with program_guard(main_prog):
        feature = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
        image_size = layers.data(name='img_size', shape=[2], dtype='int32')
        outputs = layers.yolo_box(
            feature, image_size, anchor_sizes, 10, 0.01, 32, scale_x_y=1.2)
        # The layer call yields a (boxes, scores) pair.
        det_boxes, det_scores = outputs
        self.assertIsNotNone(det_boxes)
        self.assertIsNotNone(det_scores)
class TestBoxClip(unittest.TestCase): class TestBoxClip(unittest.TestCase):
def test_box_clip(self): def test_box_clip(self):
......
...@@ -33,6 +33,8 @@ def YoloBox(x, img_size, attrs): ...@@ -33,6 +33,8 @@ def YoloBox(x, img_size, attrs):
conf_thresh = attrs['conf_thresh'] conf_thresh = attrs['conf_thresh']
downsample = attrs['downsample'] downsample = attrs['downsample']
clip_bbox = attrs['clip_bbox'] clip_bbox = attrs['clip_bbox']
scale_x_y = attrs['scale_x_y']
bias_x_y = -0.5 * (scale_x_y - 1.)
input_size = downsample * h input_size = downsample * h
x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
...@@ -40,8 +42,10 @@ def YoloBox(x, img_size, attrs): ...@@ -40,8 +42,10 @@ def YoloBox(x, img_size, attrs):
pred_box = x[:, :, :, :, :4].copy() pred_box = x[:, :, :, :, :4].copy()
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w pred_box[:, :, :, :, 0] = (
pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h grid_x + sigmoid(pred_box[:, :, :, :, 0]) * scale_x_y + bias_x_y) / w
pred_box[:, :, :, :, 1] = (
grid_y + sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y + bias_x_y) / h
anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)] anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
anchors_s = np.array( anchors_s = np.array(
...@@ -90,6 +94,7 @@ class TestYoloBoxOp(OpTest): ...@@ -90,6 +94,7 @@ class TestYoloBoxOp(OpTest):
"conf_thresh": self.conf_thresh, "conf_thresh": self.conf_thresh,
"downsample": self.downsample, "downsample": self.downsample,
"clip_bbox": self.clip_bbox, "clip_bbox": self.clip_bbox,
"scale_x_y": self.scale_x_y,
} }
self.inputs = { self.inputs = {
...@@ -115,6 +120,7 @@ class TestYoloBoxOp(OpTest): ...@@ -115,6 +120,7 @@ class TestYoloBoxOp(OpTest):
self.clip_bbox = True self.clip_bbox = True
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2) self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.
class TestYoloBoxOpNoClipBbox(TestYoloBoxOp): class TestYoloBoxOpNoClipBbox(TestYoloBoxOp):
...@@ -128,6 +134,21 @@ class TestYoloBoxOpNoClipBbox(TestYoloBoxOp): ...@@ -128,6 +134,21 @@ class TestYoloBoxOpNoClipBbox(TestYoloBoxOp):
self.clip_bbox = False self.clip_bbox = False
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2) self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.
class TestYoloBoxOpScaleXY(TestYoloBoxOp):
    # Variant of the yolo_box op test that sets scale_x_y to 1.2; every
    # other setting is assigned explicitly below.

    def initTestCase(self):
        self.scale_x_y = 1.2
        self.clip_bbox = True
        self.downsample = 32
        self.conf_thresh = 0.5
        self.class_num = 2
        self.batch_size = 32
        self.anchors = [10, 13, 16, 30, 33, 23]
        # Anchors are stored as a flat list of pairs, so the anchor count
        # is half of its length.
        anchor_count = len(self.anchors) // 2
        channels = anchor_count * (5 + self.class_num)
        self.x_shape = (self.batch_size, channels, 13, 13)
        self.imgsize_shape = (self.batch_size, 2)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -77,6 +77,8 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): ...@@ -77,6 +77,8 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs):
ignore_thresh = attrs['ignore_thresh'] ignore_thresh = attrs['ignore_thresh']
downsample_ratio = attrs['downsample_ratio'] downsample_ratio = attrs['downsample_ratio']
use_label_smooth = attrs['use_label_smooth'] use_label_smooth = attrs['use_label_smooth']
scale_x_y = attrs['scale_x_y']
bias_x_y = -0.5 * (scale_x_y - 1.)
input_size = downsample_ratio * h input_size = downsample_ratio * h
x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
loss = np.zeros((n)).astype('float64') loss = np.zeros((n)).astype('float64')
...@@ -88,8 +90,10 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): ...@@ -88,8 +90,10 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs):
pred_box = x[:, :, :, :, :4].copy() pred_box = x[:, :, :, :, :4].copy()
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w pred_box[:, :, :, :, 0] = (
pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h grid_x + sigmoid(pred_box[:, :, :, :, 0]) * scale_x_y + bias_x_y) / w
pred_box[:, :, :, :, 1] = (
grid_y + sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y + bias_x_y) / h
mask_anchors = [] mask_anchors = []
for m in anchor_mask: for m in anchor_mask:
...@@ -180,6 +184,7 @@ class TestYolov3LossOp(OpTest): ...@@ -180,6 +184,7 @@ class TestYolov3LossOp(OpTest):
"ignore_thresh": self.ignore_thresh, "ignore_thresh": self.ignore_thresh,
"downsample_ratio": self.downsample_ratio, "downsample_ratio": self.downsample_ratio,
"use_label_smooth": self.use_label_smooth, "use_label_smooth": self.use_label_smooth,
"scale_x_y": self.scale_x_y,
} }
self.inputs = { self.inputs = {
...@@ -222,6 +227,7 @@ class TestYolov3LossOp(OpTest): ...@@ -222,6 +227,7 @@ class TestYolov3LossOp(OpTest):
self.gtbox_shape = (3, 5, 4) self.gtbox_shape = (3, 5, 4)
self.gtscore = True self.gtscore = True
self.use_label_smooth = True self.use_label_smooth = True
self.scale_x_y = 1.
class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp): class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp):
...@@ -238,6 +244,7 @@ class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp): ...@@ -238,6 +244,7 @@ class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp):
self.gtbox_shape = (3, 5, 4) self.gtbox_shape = (3, 5, 4)
self.gtscore = True self.gtscore = True
self.use_label_smooth = False self.use_label_smooth = False
self.scale_x_y = 1.
class TestYolov3LossNoGTScore(TestYolov3LossOp): class TestYolov3LossNoGTScore(TestYolov3LossOp):
...@@ -254,6 +261,24 @@ class TestYolov3LossNoGTScore(TestYolov3LossOp): ...@@ -254,6 +261,24 @@ class TestYolov3LossNoGTScore(TestYolov3LossOp):
self.gtbox_shape = (3, 5, 4) self.gtbox_shape = (3, 5, 4)
self.gtscore = False self.gtscore = False
self.use_label_smooth = True self.use_label_smooth = True
self.scale_x_y = 1.
class TestYolov3LossWithScaleXY(TestYolov3LossOp):
    # Variant of the yolov3_loss op test that sets scale_x_y to 1.2; every
    # other setting is assigned explicitly below.

    def initTestCase(self):
        self.scale_x_y = 1.2
        self.use_label_smooth = True
        self.gtscore = True
        self.downsample_ratio = 32
        self.ignore_thresh = 0.7
        self.class_num = 5
        self.anchor_mask = [0, 1, 2]
        self.anchors = [
            10, 13,
            16, 30,
            33, 23,
            30, 61,
            62, 45,
            59, 119,
            116, 90,
            156, 198,
            373, 326,
        ]
        # Channel count is one (5 + class_num) group per masked anchor.
        channels = len(self.anchor_mask) * (5 + self.class_num)
        self.x_shape = (3, channels, 5, 5)
        self.gtbox_shape = (3, 5, 4)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册