提交 0c4acc83 编写于 作者: D dengkaipeng

imporve yolo loss implement. test=develop

上级 2fbfef2e
...@@ -34,11 +34,12 @@ class Yolov3LossOp : public framework::OperatorWithKernel { ...@@ -34,11 +34,12 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
auto dim_gtbox = ctx->GetInputDim("GTBox"); auto dim_gtbox = ctx->GetInputDim("GTBox");
auto dim_gtlabel = ctx->GetInputDim("GTLabel"); auto dim_gtlabel = ctx->GetInputDim("GTLabel");
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors"); auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
int anchor_num = anchors.size() / 2;
auto class_num = ctx->Attrs().Get<int>("class_num"); auto class_num = ctx->Attrs().Get<int>("class_num");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3], PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3],
"Input(X) dim[3] and dim[4] should be euqal."); "Input(X) dim[3] and dim[4] should be euqal.");
PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num), PADDLE_ENFORCE_EQ(dim_x[1], anchor_num * (5 + class_num),
"Input(X) dim[1] should be equal to (anchor_number * (5 " "Input(X) dim[1] should be equal to (anchor_number * (5 "
"+ class_num))."); "+ class_num)).");
PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3, PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3,
...@@ -105,20 +106,6 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -105,20 +106,6 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(406); .SetDefault(406);
AddAttr<float>("ignore_thresh", AddAttr<float>("ignore_thresh",
"The ignore threshold to ignore confidence loss."); "The ignore threshold to ignore confidence loss.");
AddAttr<float>("loss_weight_xy", "The weight of x, y location loss.")
.SetDefault(1.0);
AddAttr<float>("loss_weight_wh", "The weight of w, h location loss.")
.SetDefault(1.0);
AddAttr<float>(
"loss_weight_conf_target",
"The weight of confidence score loss in locations with target object.")
.SetDefault(1.0);
AddAttr<float>("loss_weight_conf_notarget",
"The weight of confidence score loss in locations without "
"target object.")
.SetDefault(1.0);
AddAttr<float>("loss_weight_class", "The weight of classification loss.")
.SetDefault(1.0);
AddComment(R"DOC( AddComment(R"DOC(
This operator generate yolov3 loss by given predict result and ground This operator generate yolov3 loss by given predict result and ground
truth boxes. truth boxes.
......
...@@ -416,11 +416,6 @@ def yolov3_loss(x, ...@@ -416,11 +416,6 @@ def yolov3_loss(x,
class_num, class_num,
ignore_thresh, ignore_thresh,
input_size, input_size,
loss_weight_xy=None,
loss_weight_wh=None,
loss_weight_conf_target=None,
loss_weight_conf_notarget=None,
loss_weight_class=None,
name=None): name=None):
""" """
${comment} ${comment}
...@@ -438,11 +433,6 @@ def yolov3_loss(x, ...@@ -438,11 +433,6 @@ def yolov3_loss(x,
class_num (int): ${class_num_comment} class_num (int): ${class_num_comment}
ignore_thresh (float): ${ignore_thresh_comment} ignore_thresh (float): ${ignore_thresh_comment}
input_size (int): ${input_size_comment} input_size (int): ${input_size_comment}
loss_weight_xy (float|None): ${loss_weight_xy_comment}
loss_weight_wh (float|None): ${loss_weight_wh_comment}
loss_weight_conf_target (float|None): ${loss_weight_conf_target_comment}
loss_weight_conf_notarget (float|None): ${loss_weight_conf_notarget_comment}
loss_weight_class (float|None): ${loss_weight_class_comment}
name (string): the name of yolov3 loss name (string): the name of yolov3 loss
Returns: Returns:
...@@ -495,18 +485,18 @@ def yolov3_loss(x, ...@@ -495,18 +485,18 @@ def yolov3_loss(x,
"input_size": input_size, "input_size": input_size,
} }
if loss_weight_xy is not None and isinstance(loss_weight_xy, float): # if loss_weight_xy is not None and isinstance(loss_weight_xy, float):
self.attrs['loss_weight_xy'] = loss_weight_xy # self.attrs['loss_weight_xy'] = loss_weight_xy
if loss_weight_wh is not None and isinstance(loss_weight_wh, float): # if loss_weight_wh is not None and isinstance(loss_weight_wh, float):
self.attrs['loss_weight_wh'] = loss_weight_wh # self.attrs['loss_weight_wh'] = loss_weight_wh
if loss_weight_conf_target is not None and isinstance( # if loss_weight_conf_target is not None and isinstance(
loss_weight_conf_target, float): # loss_weight_conf_target, float):
self.attrs['loss_weight_conf_target'] = loss_weight_conf_target # self.attrs['loss_weight_conf_target'] = loss_weight_conf_target
if loss_weight_conf_notarget is not None and isinstance( # if loss_weight_conf_notarget is not None and isinstance(
loss_weight_conf_notarget, float): # loss_weight_conf_notarget, float):
self.attrs['loss_weight_conf_notarget'] = loss_weight_conf_notarget # self.attrs['loss_weight_conf_notarget'] = loss_weight_conf_notarget
if loss_weight_class is not None and isinstance(loss_weight_class, float): # if loss_weight_class is not None and isinstance(loss_weight_class, float):
self.attrs['loss_weight_class'] = loss_weight_class # self.attrs['loss_weight_class'] = loss_weight_class
helper.append_op( helper.append_op(
type='yolov3_loss', type='yolov3_loss',
......
...@@ -470,8 +470,6 @@ class OpTest(unittest.TestCase): ...@@ -470,8 +470,6 @@ class OpTest(unittest.TestCase):
] ]
analytic_grads = self._get_gradient(inputs_to_check, place, analytic_grads = self._get_gradient(inputs_to_check, place,
output_names, no_grad_set) output_names, no_grad_set)
# print(numeric_grads[0][0, 4, :, :])
# print(analytic_grads[0][0, 4, :, :])
self._assert_is_close(numeric_grads, analytic_grads, inputs_to_check, self._assert_is_close(numeric_grads, analytic_grads, inputs_to_check,
max_relative_error, max_relative_error,
......
...@@ -80,8 +80,8 @@ def build_target(gtboxes, gtlabel, attrs, grid_size): ...@@ -80,8 +80,8 @@ def build_target(gtboxes, gtlabel, attrs, grid_size):
class_num = attrs["class_num"] class_num = attrs["class_num"]
input_size = attrs["input_size"] input_size = attrs["input_size"]
an_num = len(anchors) // 2 an_num = len(anchors) // 2
conf_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32')
obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32')
tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
...@@ -114,10 +114,10 @@ def build_target(gtboxes, gtlabel, attrs, grid_size): ...@@ -114,10 +114,10 @@ def build_target(gtboxes, gtlabel, attrs, grid_size):
max_iou = iou max_iou = iou
best_an_index = k best_an_index = k
if iou > ignore_thresh: if iou > ignore_thresh:
noobj_mask[i, best_an_index, gj, gi] = 0 conf_mask[i, best_an_index, gj, gi] = 0
conf_mask[i, best_an_index, gj, gi] = 1
obj_mask[i, best_an_index, gj, gi] = 1 obj_mask[i, best_an_index, gj, gi] = 1
noobj_mask[i, best_an_index, gj, gi] = 0
tx[i, best_an_index, gj, gi] = gx - gi tx[i, best_an_index, gj, gi] = gx - gi
ty[i, best_an_index, gj, gi] = gy - gj ty[i, best_an_index, gj, gi] = gy - gj
tw[i, best_an_index, gj, gi] = np.log(gw / anchors[2 * tw[i, best_an_index, gj, gi] = np.log(gw / anchors[2 *
...@@ -129,7 +129,7 @@ def build_target(gtboxes, gtlabel, attrs, grid_size): ...@@ -129,7 +129,7 @@ def build_target(gtboxes, gtlabel, attrs, grid_size):
tconf[i, best_an_index, gj, gi] = 1 tconf[i, best_an_index, gj, gi] = 1
tcls[i, best_an_index, gj, gi, gt_label] = 1 tcls[i, best_an_index, gj, gi, gt_label] = 1
return (tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask) return (tx, ty, tw, th, tweight, tconf, tcls, conf_mask, obj_mask)
def YoloV3Loss(x, gtbox, gtlabel, attrs): def YoloV3Loss(x, gtbox, gtlabel, attrs):
...@@ -144,11 +144,9 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): ...@@ -144,11 +144,9 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs):
pred_conf = x[:, :, :, :, 4] pred_conf = x[:, :, :, :, 4]
pred_cls = x[:, :, :, :, 5:] pred_cls = x[:, :, :, :, 5:]
tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask = build_target( tx, ty, tw, th, tweight, tconf, tcls, conf_mask, obj_mask = build_target(
gtbox, gtlabel, attrs, x.shape[2]) gtbox, gtlabel, attrs, x.shape[2])
# print("obj_mask: ", obj_mask[0, 0, :, :])
# print("noobj_mask: ", noobj_mask[0, 0, :, :])
obj_weight = obj_mask * tweight obj_weight = obj_mask * tweight
obj_mask_expand = np.tile( obj_mask_expand = np.tile(
np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num'])))
...@@ -156,30 +154,19 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): ...@@ -156,30 +154,19 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs):
loss_y = sce(pred_y, ty, obj_weight) loss_y = sce(pred_y, ty, obj_weight)
loss_w = l1loss(pred_w, tw, obj_weight) loss_w = l1loss(pred_w, tw, obj_weight)
loss_h = l1loss(pred_h, th, obj_weight) loss_h = l1loss(pred_h, th, obj_weight)
loss_conf_target = sce(pred_conf, tconf, obj_mask) loss_obj = sce(pred_conf, tconf, conf_mask)
loss_conf_notarget = sce(pred_conf, tconf, noobj_mask)
loss_class = sce(pred_cls, tcls, obj_mask_expand) loss_class = sce(pred_cls, tcls, obj_mask_expand)
# print("loss_xy: ", loss_x + loss_y) # print("python loss_xy: ", loss_x + loss_y)
# print("loss_wh: ", loss_w + loss_h) # print("python loss_wh: ", loss_w + loss_h)
# print("loss_conf_target: ", loss_conf_target) # print("python loss_obj: ", loss_obj)
# print("loss_conf_notarget: ", loss_conf_notarget) # print("python loss_class: ", loss_class)
# print("loss_class: ", loss_class)
return attrs['loss_weight_xy'] * (loss_x + loss_y) \ return loss_x + loss_y + loss_w + loss_h + loss_obj + loss_class
+ attrs['loss_weight_wh'] * (loss_w + loss_h) \
+ attrs['loss_weight_conf_target'] * loss_conf_target \
+ attrs['loss_weight_conf_notarget'] * loss_conf_notarget \
+ attrs['loss_weight_class'] * loss_class
class TestYolov3LossOp(OpTest): class TestYolov3LossOp(OpTest):
def setUp(self): def setUp(self):
self.loss_weight_xy = 1.0
self.loss_weight_wh = 1.0
self.loss_weight_conf_target = 1.0
self.loss_weight_conf_notarget = 1.0
self.loss_weight_class = 1.0
self.initTestCase() self.initTestCase()
self.op_type = 'yolov3_loss' self.op_type = 'yolov3_loss'
x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32')) x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32'))
...@@ -192,11 +179,6 @@ class TestYolov3LossOp(OpTest): ...@@ -192,11 +179,6 @@ class TestYolov3LossOp(OpTest):
"class_num": self.class_num, "class_num": self.class_num,
"ignore_thresh": self.ignore_thresh, "ignore_thresh": self.ignore_thresh,
"input_size": self.input_size, "input_size": self.input_size,
"loss_weight_xy": self.loss_weight_xy,
"loss_weight_wh": self.loss_weight_wh,
"loss_weight_conf_target": self.loss_weight_conf_target,
"loss_weight_conf_notarget": self.loss_weight_conf_notarget,
"loss_weight_class": self.loss_weight_class,
} }
self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel} self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel}
...@@ -215,17 +197,12 @@ class TestYolov3LossOp(OpTest): ...@@ -215,17 +197,12 @@ class TestYolov3LossOp(OpTest):
max_relative_error=0.31) max_relative_error=0.31)
def initTestCase(self): def initTestCase(self):
self.anchors = [12, 12] self.anchors = [12, 12, 11, 13]
self.class_num = 5 self.class_num = 5
self.ignore_thresh = 0.3 self.ignore_thresh = 0.5
self.input_size = 416 self.input_size = 416
self.x_shape = (3, len(self.anchors) // 2 * (5 + self.class_num), 5, 5) self.x_shape = (3, len(self.anchors) // 2 * (5 + self.class_num), 5, 5)
self.gtbox_shape = (3, 5, 4) self.gtbox_shape = (3, 5, 4)
self.loss_weight_xy = 1.2
self.loss_weight_wh = 0.8
self.loss_weight_conf_target = 2.0
self.loss_weight_conf_notarget = 1.0
self.loss_weight_class = 1.5
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册