未验证 提交 0ddcf30c 编写于 作者: C Charles-hit 提交者: GitHub

replace fill_zeros_like op with fill_any_like op (#45657)

* relace fill_zeros_like op with fill_any_like op in backward.py and tensor.py

* Remove unnecessary comments

* modify create op_desc param
上级 fcbb307c
......@@ -673,7 +673,7 @@ def _remove_no_grad_branch_(op_descs,
op_desc for op_desc in op_descs
if not _op_can_be_removed_(op_desc, no_grad_set)
]
# Insert fill_zeros_like_op
# Insert fill_any_like_op with value 0
to_insert = []
for idx, op_desc in enumerate(op_descs):
for arg in op_desc.input_arg_names():
......@@ -682,8 +682,11 @@ def _remove_no_grad_branch_(op_descs,
x_in = _strip_grad_suffix_(arg)
# the reason should be: arg can be input of another grad op
# and the op is a not-to-remove op
new_op_desc = _create_op_desc_("fill_zeros_like", {"X": [x_in]},
{"Out": [arg]}, {})
new_op_desc = _create_op_desc_("fill_any_like", {"X": [x_in]},
{"Out": [arg]}, {
'value': 0,
'dtype': -1
})
# update the mapping between fwd and bwd
if grad_op_id_to_fwd_op is not None and grad_op_id_to_fwd_op.get(
op_desc.original_id(), None) is not None:
......
......@@ -1680,10 +1680,9 @@ def zeros_like(x, out=None):
data = fluid.layers.zeros_like(x) # [0.0, 0.0, 0.0]
"""
check_variable_and_dtype(x, "x",
['bool', 'float32', 'float64', 'int32', 'int64'],
'ones_like')
'zeros_like')
helper = LayerHelper("zeros_like", **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......@@ -1691,9 +1690,12 @@ def zeros_like(x, out=None):
check_variable_and_dtype(
out, "out", ['bool', 'float32', 'float64', 'int32', 'int64'],
'zeros_like')
helper.append_op(type='fill_zeros_like',
helper.append_op(type='fill_any_like',
inputs={'X': [x]},
attrs={
'value': 0,
"dtype": x.dtype
},
outputs={'Out': [out]})
out.stop_gradient = True
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册