未验证 提交 0ddcf30c 编写于 作者: C Charles-hit 提交者: GitHub

replace fill_zeros_like op with fill_any_like op (#45657)

* relace fill_zeros_like op with fill_any_like op in backward.py and tensor.py

* Remove unnecessary comments

* modify create op_desc param
上级 fcbb307c
...@@ -673,7 +673,7 @@ def _remove_no_grad_branch_(op_descs, ...@@ -673,7 +673,7 @@ def _remove_no_grad_branch_(op_descs,
op_desc for op_desc in op_descs op_desc for op_desc in op_descs
if not _op_can_be_removed_(op_desc, no_grad_set) if not _op_can_be_removed_(op_desc, no_grad_set)
] ]
# Insert fill_zeros_like_op # Insert fill_any_like_op with value 0
to_insert = [] to_insert = []
for idx, op_desc in enumerate(op_descs): for idx, op_desc in enumerate(op_descs):
for arg in op_desc.input_arg_names(): for arg in op_desc.input_arg_names():
...@@ -682,8 +682,11 @@ def _remove_no_grad_branch_(op_descs, ...@@ -682,8 +682,11 @@ def _remove_no_grad_branch_(op_descs,
x_in = _strip_grad_suffix_(arg) x_in = _strip_grad_suffix_(arg)
# the reason should be: arg can be input of another grad op # the reason should be: arg can be input of another grad op
# and the op is a not-to-remove op # and the op is a not-to-remove op
new_op_desc = _create_op_desc_("fill_zeros_like", {"X": [x_in]}, new_op_desc = _create_op_desc_("fill_any_like", {"X": [x_in]},
{"Out": [arg]}, {}) {"Out": [arg]}, {
'value': 0,
'dtype': -1
})
# update the mapping between fwd and bwd # update the mapping between fwd and bwd
if grad_op_id_to_fwd_op is not None and grad_op_id_to_fwd_op.get( if grad_op_id_to_fwd_op is not None and grad_op_id_to_fwd_op.get(
op_desc.original_id(), None) is not None: op_desc.original_id(), None) is not None:
......
...@@ -1680,10 +1680,9 @@ def zeros_like(x, out=None): ...@@ -1680,10 +1680,9 @@ def zeros_like(x, out=None):
data = fluid.layers.zeros_like(x) # [0.0, 0.0, 0.0] data = fluid.layers.zeros_like(x) # [0.0, 0.0, 0.0]
""" """
check_variable_and_dtype(x, "x", check_variable_and_dtype(x, "x",
['bool', 'float32', 'float64', 'int32', 'int64'], ['bool', 'float32', 'float64', 'int32', 'int64'],
'ones_like') 'zeros_like')
helper = LayerHelper("zeros_like", **locals()) helper = LayerHelper("zeros_like", **locals())
if out is None: if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
...@@ -1691,9 +1690,12 @@ def zeros_like(x, out=None): ...@@ -1691,9 +1690,12 @@ def zeros_like(x, out=None):
check_variable_and_dtype( check_variable_and_dtype(
out, "out", ['bool', 'float32', 'float64', 'int32', 'int64'], out, "out", ['bool', 'float32', 'float64', 'int32', 'int64'],
'zeros_like') 'zeros_like')
helper.append_op(type='fill_any_like',
helper.append_op(type='fill_zeros_like',
inputs={'X': [x]}, inputs={'X': [x]},
attrs={
'value': 0,
"dtype": x.dtype
},
outputs={'Out': [out]}) outputs={'Out': [out]})
out.stop_gradient = True out.stop_gradient = True
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册