未验证 提交 35b50b7f 编写于 作者: F FlyingQianMM 提交者: GitHub

[cherry pick] fix sample code diff for OP(retinanet_detection_output,...

[cherry pick] fix sample code diff for OP(retinanet_detection_output, retinanet_target_assign), fix  default value for OP(sigmoid_focal_loss) (#23744). test=release/2.0-beta (#23858)
上级 8566b26b
...@@ -233,7 +233,7 @@ def retinanet_target_assign(bbox_pred, ...@@ -233,7 +233,7 @@ def retinanet_target_assign(bbox_pred,
dtype='float32') dtype='float32')
is_crowd = fluid.data(name='is_crowd', shape=[1], is_crowd = fluid.data(name='is_crowd', shape=[1],
dtype='float32') dtype='float32')
im_info = fluid.data(name='im_infoss', shape=[1, 3], im_info = fluid.data(name='im_info', shape=[1, 3],
dtype='float32') dtype='float32')
score_pred, loc_pred, score_target, loc_target, bbox_inside_weight, fg_num = \\ score_pred, loc_pred, score_target, loc_target, bbox_inside_weight, fg_num = \\
fluid.layers.retinanet_target_assign(bbox_pred, cls_logits, anchor_box, fluid.layers.retinanet_target_assign(bbox_pred, cls_logits, anchor_box,
...@@ -452,7 +452,7 @@ def rpn_target_assign(bbox_pred, ...@@ -452,7 +452,7 @@ def rpn_target_assign(bbox_pred,
return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight
def sigmoid_focal_loss(x, label, fg_num, gamma=2, alpha=0.25): def sigmoid_focal_loss(x, label, fg_num, gamma=2.0, alpha=0.25):
""" """
**Sigmoid Focal Loss Operator.** **Sigmoid Focal Loss Operator.**
...@@ -493,9 +493,9 @@ def sigmoid_focal_loss(x, label, fg_num, gamma=2, alpha=0.25): ...@@ -493,9 +493,9 @@ def sigmoid_focal_loss(x, label, fg_num, gamma=2, alpha=0.25):
is int32. is int32.
fg_num(Variable): A 1-D tensor with shape [1] represents the number of positive samples in a fg_num(Variable): A 1-D tensor with shape [1] represents the number of positive samples in a
mini-batch, which should be obtained before this OP. The data type of :attr:`fg_num` is int32. mini-batch, which should be obtained before this OP. The data type of :attr:`fg_num` is int32.
gamma(float): Hyper-parameter to balance the easy and hard examples. Default value is gamma(int|float): Hyper-parameter to balance the easy and hard examples. Default value is
set to 2.0. set to 2.0.
alpha(float): Hyper-parameter to balance the positive and negative example. Default value alpha(int|float): Hyper-parameter to balance the positive and negative example. Default value
is set to 0.25. is set to 0.25.
Returns: Returns:
...@@ -514,7 +514,7 @@ def sigmoid_focal_loss(x, label, fg_num, gamma=2, alpha=0.25): ...@@ -514,7 +514,7 @@ def sigmoid_focal_loss(x, label, fg_num, gamma=2, alpha=0.25):
loss = fluid.layers.sigmoid_focal_loss(x=input, loss = fluid.layers.sigmoid_focal_loss(x=input,
label=label, label=label,
fg_num=fg_num, fg_num=fg_num,
gamma=2., gamma=2.0,
alpha=0.25) alpha=0.25)
""" """
...@@ -2914,7 +2914,7 @@ def retinanet_detection_output(bboxes, ...@@ -2914,7 +2914,7 @@ def retinanet_detection_output(bboxes,
nms_top_k=1000, nms_top_k=1000,
keep_top_k=100, keep_top_k=100,
nms_threshold=0.3, nms_threshold=0.3,
nms_eta=1.): nms_eta=1.0):
""" """
**Detection Output Layer for the detector RetinaNet.** **Detection Output Layer for the detector RetinaNet.**
...@@ -3010,15 +3010,15 @@ def retinanet_detection_output(bboxes, ...@@ -3010,15 +3010,15 @@ def retinanet_detection_output(bboxes,
im_info = fluid.data( im_info = fluid.data(
name="im_info", shape=[1, 3], dtype='float32') name="im_info", shape=[1, 3], dtype='float32')
nmsed_outs = fluid.layers.retinanet_detection_output( nmsed_outs = fluid.layers.retinanet_detection_output(
bboxes=[bboxes_low, bboxes_high], bboxes=[bboxes_low, bboxes_high],
scores=[scores_low, scores_high], scores=[scores_low, scores_high],
anchors=[anchors_low, anchors_high], anchors=[anchors_low, anchors_high],
im_info=im_info, im_info=im_info,
score_threshold=0.05, score_threshold=0.05,
nms_top_k=1000, nms_top_k=1000,
keep_top_k=100, keep_top_k=100,
nms_threshold=0.45, nms_threshold=0.45,
nms_eta=1.) nms_eta=1.0)
""" """
check_type(bboxes, 'bboxes', (list), 'retinanet_detection_output') check_type(bboxes, 'bboxes', (list), 'retinanet_detection_output')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册