未验证 提交 3b381aac 编写于 作者: 0 0x45f 提交者: GitHub

Use _C_ops.yolov3_loss in eager mode for test_yolov3.py (#40831)

* Use _C_ops.yolov3_loss in eager mode for test_yolov3.py

* fix code for test_yolov3_loss_op

* remove useless import

* Fix dygraph_mode flag
上级 fe8acb67
......@@ -96,6 +96,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"nce",
{"Input", "Label", "Weight", "Bias", "SampleWeight", "CustomDistProbs",
"CustomDistAlias", "CustomDistAliasProbs"}},
{"yolov3_loss", {"X", "GTBox", "GTLabel", "GTScore"}},
{"check_finite_and_unscale", {"X", "Scale", "FloatStatus"}},
{"group_norm", {"X", "Scale", "Bias"}},
{"linear_chain_crf", {"Emission", "Transition", "Label", "Length"}},
......
......@@ -1072,7 +1072,6 @@ def yolov3_loss(x,
anchor_mask=anchor_mask, class_num=80,
ignore_thresh=0.7, downsample_ratio=32)
"""
helper = LayerHelper('yolov3_loss', **locals())
if not isinstance(x, Variable):
raise TypeError("Input x of yolov3_loss must be Variable")
......@@ -1095,8 +1094,16 @@ def yolov3_loss(x,
raise TypeError(
"Attr use_label_smooth of yolov3_loss must be a bool value")
loss = helper.create_variable_for_type_inference(dtype=x.dtype)
if _non_static_mode():
attrs = ("anchors", anchors, "anchor_mask", anchor_mask, "class_num",
class_num, "ignore_thresh", ignore_thresh, "downsample_ratio",
downsample_ratio, "use_label_smooth", use_label_smooth,
"scale_x_y", scale_x_y)
loss, _, _ = _C_ops.yolov3_loss(x, gt_box, gt_label, gt_score, *attrs)
return loss
helper = LayerHelper('yolov3_loss', **locals())
loss = helper.create_variable_for_type_inference(dtype=x.dtype)
objectness_mask = helper.create_variable_for_type_inference(dtype='int32')
gt_match_mask = helper.create_variable_for_type_inference(dtype='int32')
......
......@@ -194,10 +194,10 @@ def yolo_loss(x,
scale_x_y=1.)
"""
if _non_static_mode() and gt_score is None:
if _non_static_mode():
loss, _, _ = _C_ops.yolov3_loss(
x, gt_box, gt_label, 'anchors', anchors, 'anchor_mask', anchor_mask,
'class_num', class_num, 'ignore_thresh', ignore_thresh,
x, gt_box, gt_label, gt_score, 'anchors', anchors, 'anchor_mask',
anchor_mask, 'class_num', class_num, 'ignore_thresh', ignore_thresh,
'downsample_ratio', downsample_ratio, 'use_label_smooth',
use_label_smooth, 'scale_x_y', scale_x_y)
return loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册