提交 bc48453b 编写于 作者: X xiaoting 提交者: Kaipeng Deng

polish the label_smooth (#17138)

* polish the label_smooth

test=develop

* polish code

test=develop
上级 bf4b21fa
...@@ -282,8 +282,9 @@ class Yolov3LossKernel : public framework::OpKernel<T> { ...@@ -282,8 +282,9 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
T label_pos = 1.0; T label_pos = 1.0;
T label_neg = 0.0; T label_neg = 0.0;
if (use_label_smooth) { if (use_label_smooth) {
label_pos = 1.0 - 1.0 / static_cast<T>(class_num); T smooth_weight = std::min(1.0 / static_cast<T>(class_num), 1.0 / 40);
label_neg = 1.0 / static_cast<T>(class_num); label_pos = 1.0 - smooth_weight;
label_neg = smooth_weight;
} }
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
...@@ -437,8 +438,9 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> { ...@@ -437,8 +438,9 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
T label_pos = 1.0; T label_pos = 1.0;
T label_neg = 0.0; T label_neg = 0.0;
if (use_label_smooth) { if (use_label_smooth) {
label_pos = 1.0 - 1.0 / static_cast<T>(class_num); T smooth_weight = std::min(1.0 / static_cast<T>(class_num), 1.0 / 40);
label_neg = 1.0 / static_cast<T>(class_num); label_pos = 1.0 - smooth_weight;
label_neg = smooth_weight;
} }
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
......
...@@ -81,8 +81,9 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): ...@@ -81,8 +81,9 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs):
x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
loss = np.zeros((n)).astype('float32') loss = np.zeros((n)).astype('float32')
label_pos = 1.0 - 1.0 / class_num if use_label_smooth else 1.0 smooth_weight = min(1.0 / class_num, 1.0 / 40)
label_neg = 1.0 / class_num if use_label_smooth else 0.0 label_pos = 1.0 - smooth_weight if use_label_smooth else 1.0
label_neg = smooth_weight if use_label_smooth else 0.0
pred_box = x[:, :, :, :, :4].copy() pred_box = x[:, :, :, :, :4].copy()
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册