diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index db7c03bb255d412f0e519b32b88f37c06741dcab..93d6d798dc4ebd794b469cea567b5aecd79f11ac 100755
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -673,7 +673,7 @@ def _remove_no_grad_branch_(op_descs,
         op_desc for op_desc in op_descs
         if not _op_can_be_removed_(op_desc, no_grad_set)
     ]
-    # Insert fill_zeros_like_op
+    # Insert fill_any_like_op with value 0
     to_insert = []
     for idx, op_desc in enumerate(op_descs):
         for arg in op_desc.input_arg_names():
@@ -682,8 +682,11 @@ def _remove_no_grad_branch_(op_descs,
                 x_in = _strip_grad_suffix_(arg)
                 # the reason should be: arg can be input of another grad op
                 # and the op is a not-to-remove op
-                new_op_desc = _create_op_desc_("fill_zeros_like", {"X": [x_in]},
-                                               {"Out": [arg]}, {})
+                new_op_desc = _create_op_desc_("fill_any_like", {"X": [x_in]},
+                                               {"Out": [arg]}, {
+                                                   'value': 0,
+                                                   'dtype': -1
+                                               })
                 # update the mapping between fwd and bwd
                 if grad_op_id_to_fwd_op is not None and grad_op_id_to_fwd_op.get(
                         op_desc.original_id(), None) is not None:
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index 3df027931ccc58792eebbc37677fe98c3b016760..2910f4187a73e3d588862281eb12c8943f51523e 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -1680,10 +1680,9 @@ def zeros_like(x, out=None):
             data = fluid.layers.zeros_like(x) # [0.0, 0.0, 0.0]
 
     """
-
     check_variable_and_dtype(x, "x",
                              ['bool', 'float32', 'float64', 'int32', 'int64'],
-                             'ones_like')
+                             'zeros_like')
     helper = LayerHelper("zeros_like", **locals())
     if out is None:
         out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -1691,9 +1690,12 @@ def zeros_like(x, out=None):
         check_variable_and_dtype(
             out, "out", ['bool', 'float32', 'float64', 'int32', 'int64'],
             'zeros_like')
-
-    helper.append_op(type='fill_zeros_like',
+    helper.append_op(type='fill_any_like',
                      inputs={'X': [x]},
+                     attrs={
+                         'value': 0,
+                         "dtype": x.dtype
+                     },
                      outputs={'Out': [out]})
     out.stop_gradient = True
     return out
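
A minimal sanity check of the user-visible behaviour is sketched below. It assumes a Paddle build where the legacy static-graph API (paddle.fluid) is still available, as on this branch; it exercises the public zeros_like contract (zeros with x's dtype) and then inspects the program to confirm that a fill_any_like op, not fill_zeros_like, was appended. The variable names and CPU-only setup are illustrative, not part of the change.

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[3], dtype='float32')
    out = fluid.layers.zeros_like(x)

# After this patch, zeros_like should lower to fill_any_like(value=0).
print([op.type for op in main_prog.global_block().ops])
# expect ['fill_any_like'] rather than ['fill_zeros_like']

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)
res, = exe.run(main_prog,
               feed={'x': np.array([1.0, 2.0, 3.0], dtype='float32')},
               fetch_list=[out])
print(res, res.dtype)  # expect [0. 0. 0.] float32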