未验证 提交 9aa6bfc7 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix yolov3 return value in dygraph mode. test=develop (#40185)

上级 2ce007ca
...@@ -305,6 +305,7 @@ class TestYolov3LossDygraph(unittest.TestCase): ...@@ -305,6 +305,7 @@ class TestYolov3LossDygraph(unittest.TestCase):
use_label_smooth=True, use_label_smooth=True,
scale_x_y=1.) scale_x_y=1.)
assert loss is not None assert loss is not None
assert loss.shape == [2]
paddle.enable_static() paddle.enable_static()
......
...@@ -195,7 +195,7 @@ def yolo_loss(x, ...@@ -195,7 +195,7 @@ def yolo_loss(x,
""" """
if in_dygraph_mode() and gt_score is None: if in_dygraph_mode() and gt_score is None:
loss = _C_ops.yolov3_loss( loss, _, _ = _C_ops.yolov3_loss(
x, gt_box, gt_label, 'anchors', anchors, 'anchor_mask', anchor_mask, x, gt_box, gt_label, 'anchors', anchors, 'anchor_mask', anchor_mask,
'class_num', class_num, 'ignore_thresh', ignore_thresh, 'class_num', class_num, 'ignore_thresh', ignore_thresh,
'downsample_ratio', downsample_ratio, 'use_label_smooth', 'downsample_ratio', downsample_ratio, 'use_label_smooth',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册