未验证 提交 f9c1b6f9 编写于 作者: Q qingqing01 提交者: GitHub

Fix the backward transpiler bug in ssd_loss API. (#8648)

上级 ace512a3
...@@ -496,10 +496,15 @@ def ssd_loss(location, ...@@ -496,10 +496,15 @@ def ssd_loss(location,
# 5.1 Compute confidence loss. # 5.1 Compute confidence loss.
target_label = __reshape_to_2d(target_label) target_label = __reshape_to_2d(target_label)
target_label = tensor.cast(x=target_label, dtype='int64') target_label = tensor.cast(x=target_label, dtype='int64')
conf_loss = nn.softmax_with_cross_entropy(confidence, target_label) conf_loss = nn.softmax_with_cross_entropy(confidence, target_label)
target_conf_weight = __reshape_to_2d(target_conf_weight) target_conf_weight = __reshape_to_2d(target_conf_weight)
conf_loss = conf_loss * target_conf_weight conf_loss = conf_loss * target_conf_weight
# the target_label and target_conf_weight do not have gradient.
target_label.stop_gradient = True
target_conf_weight.stop_gradient = True
# 5.2 Compute regression loss. # 5.2 Compute regression loss.
location = __reshape_to_2d(location) location = __reshape_to_2d(location)
target_bbox = __reshape_to_2d(target_bbox) target_bbox = __reshape_to_2d(target_bbox)
...@@ -508,6 +513,10 @@ def ssd_loss(location, ...@@ -508,6 +513,10 @@ def ssd_loss(location,
target_loc_weight = __reshape_to_2d(target_loc_weight) target_loc_weight = __reshape_to_2d(target_loc_weight)
loc_loss = loc_loss * target_loc_weight loc_loss = loc_loss * target_loc_weight
# the target_bbox and target_loc_weight do not have gradient.
target_bbox.stop_gradient = True
target_loc_weight.stop_gradient = True
# 5.3 Compute overall weighted loss. # 5.3 Compute overall weighted loss.
loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss
# reshape to [N, Np], N is the batch size and Np is the prior box number. # reshape to [N, Np], N is the batch size and Np is the prior box number.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册