提交 3c08f620 编写于 作者: D dengkaipeng

add label smooth. test=develop

上级 cc01db60
......@@ -159,7 +159,9 @@ static inline void CalcLabelLoss(T* loss, const T* input, const int index,
const int label, const int class_num,
const int stride) {
for (int i = 0; i < class_num; i++) {
loss[0] += SCE<T>(input[index + i * stride], (i == label) ? 1.0 : 0.0);
T pred = input[index + i * stride] < -0.5 ? input[index + i * stride]
: 1.0 / class_num;
loss[0] += SCE<T>(pred, (i == label) ? 1.0 : 0.0);
}
}
......@@ -169,8 +171,10 @@ static inline void CalcLabelLossGrad(T* input_grad, const T loss,
const int label, const int class_num,
const int stride) {
for (int i = 0; i < class_num; i++) {
T pred = input[index + i * stride] < -0.5 ? input[index + i * stride]
: 1.0 / class_num;
input_grad[index + i * stride] =
SCEGrad<T>(input[index + i * stride], (i == label) ? 1.0 : 0.0) * loss;
SCEGrad<T>(pred, (i == label) ? 1.0 : 0.0) * loss;
}
}
......@@ -406,15 +410,12 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
for (int i = 0; i < n; i++) {
for (int t = 0; t < b; t++) {
int mask_idx = gt_match_mask_data[i * b + t];
if (mask_idx >= 0) {
Box<T> gt = GetGtBox(gt_box_data, i, b, t);
if (LessEqualZero<T>(gt.w) || LessEqualZero<T>(gt.h)) {
continue;
}
int gi = static_cast<int>(gt.x * w);
int gj = static_cast<int>(gt.y * h);
int mask_idx = gt_match_mask_data[i * b + t];
if (mask_idx >= 0) {
int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 0);
CalcBoxLocationLossGrad<T>(
......
......@@ -86,6 +86,10 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs):
pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w
pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h
x[:, :, :, :, 5:] = np.where(x[:, :, :, :, 5:] < -0.5, x[:, :, :, :, 5:],
np.ones_like(x[:, :, :, :, 5:]) * 1.0 /
class_num)
mask_anchors = []
for m in anchor_mask:
mask_anchors.append((anchors[2 * m], anchors[2 * m + 1]))
......@@ -207,7 +211,7 @@ class TestYolov3LossOp(OpTest):
self.ignore_thresh = 0.7
self.downsample = 32
self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5)
self.gtbox_shape = (3, 10, 4)
self.gtbox_shape = (3, 5, 4)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册