未验证 提交 123cf165 编写于 作者: Q qingqing01 提交者: GitHub

Set stop_gradient=True for some variables in SSD API. (#9396)

上级 e0b5691e
...@@ -134,6 +134,7 @@ def detection_output(loc, ...@@ -134,6 +134,7 @@ def detection_output(loc,
scores = nn.softmax(input=scores) scores = nn.softmax(input=scores)
scores = ops.reshape(x=scores, shape=old_shape) scores = ops.reshape(x=scores, shape=old_shape)
scores = nn.transpose(scores, perm=[0, 2, 1]) scores = nn.transpose(scores, perm=[0, 2, 1])
scores.stop_gradient = True
nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype) nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype)
helper.append_op( helper.append_op(
type="multiclass_nms", type="multiclass_nms",
...@@ -148,6 +149,7 @@ def detection_output(loc, ...@@ -148,6 +149,7 @@ def detection_output(loc,
'score_threshold': score_threshold, 'score_threshold': score_threshold,
'nms_eta': 1.0 'nms_eta': 1.0
}) })
nmsed_outs.stop_gradient = True
return nmsed_outs return nmsed_outs
...@@ -837,4 +839,6 @@ def multi_box_head(inputs, ...@@ -837,4 +839,6 @@ def multi_box_head(inputs,
mbox_locs_concat = tensor.concat(mbox_locs, axis=1) mbox_locs_concat = tensor.concat(mbox_locs, axis=1)
mbox_confs_concat = tensor.concat(mbox_confs, axis=1) mbox_confs_concat = tensor.concat(mbox_confs, axis=1)
box.stop_gradient = True
var.stop_gradient = True
return mbox_locs_concat, mbox_confs_concat, box, var return mbox_locs_concat, mbox_confs_concat, box, var
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册