Unverified commit 5e1185de, authored by xiongkun, committed by GitHub

[dy2static] bugfix: make stop_gradient a cache key (#50883)

* [dy2static] bugfix: make stop_gradient a cache key
1. Make stop_gradient part of the cache key in dy2static.

* fix ci errors

* fix ci error

* fix ci error

* fix ci error
Parent 3bba4af7
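Background for the fix, in brief: dy2static keys its program cache on the InputSpecs derived from the call arguments. Before this change an InputSpec carried only shape, dtype and name, so two calls whose inputs differ only in stop_gradient mapped to the same cache entry, and the second call could reuse a program traced with the wrong gradient setting. A minimal sketch of that scenario (the function and tensors below are illustrative, not taken from the PR):

```python
import paddle

@paddle.jit.to_static
def double(x):
    return x * 2.0

x = paddle.rand([4, 10])
x.stop_gradient = True
double(x)  # traces and caches a program for stop_gradient=True

y = paddle.rand([4, 10])
y.stop_gradient = False
double(y)  # same shape/dtype/name: previously this hit the same cache entry;
           # with stop_gradient stored on the InputSpec (and included in its
           # __hash__/__eq__), this call now gets its own traced program.
```

The hunks below make that happen: InputSpec gains a stop_gradient field, the spec-derivation helpers propagate it from the real arguments, and equality/hashing take it into account.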
@@ -70,8 +70,8 @@ class TestFunctionSpec(unittest.TestCase):
             foo_spec.unified_args_and_kwargs([10], {'c': 4})

     def test_args_to_input_spec(self):
-        a_spec = InputSpec([None, 10], name='a')
-        b_spec = InputSpec([10], name='b')
+        a_spec = InputSpec([None, 10], name='a', stop_gradient=True)
+        b_spec = InputSpec([10], name='b', stop_gradient=True)

         a_tensor = paddle.static.data(name='a_var', shape=[4, 10])
         b_tensor = paddle.static.data(name='b_var', shape=[4, 10])
@@ -85,7 +85,8 @@ class TestFunctionSpec(unittest.TestCase):
         self.assertTrue(len(input_with_spec) == 4)
         self.assertTrue(input_with_spec[0] == a_spec)  # a
-        self.assertTrue(input_with_spec[1] == b_spec)  # b
+        ans_b_spec = InputSpec([4, 10], name='b', stop_gradient=True)
+        self.assertTrue(input_with_spec[1] == ans_b_spec)  # b
         self.assertTrue(input_with_spec[2] == 1)  # c
         self.assertTrue(input_with_spec[3] == 2)  # d
...
@@ -36,12 +36,12 @@ class GridGenerator(nn.Layer):
             in_channels, 6, weight_attr=param_attr, bias_attr=bias_attr
         )

-    @paddle.jit.to_static(
-        input_spec=[
-            paddle.static.InputSpec(shape=[None, 3, 32, 100], dtype='float32'),
-            paddle.static.InputSpec(shape=[32, 100], dtype='float32'),
-        ]
-    )
+    # @paddle.jit.to_static(
+    #     input_spec=[
+    #         paddle.static.InputSpec(shape=[None, 3, 32, 100], dtype='float32'),
+    #         paddle.static.InputSpec(shape=[32, 100], dtype='float32'),
+    #     ]
+    # )
     def forward(self, batch_C_prime, I_r_size):
         """
         Generate the grid for the grid_sampler.
...
@@ -111,22 +111,6 @@ class FunctionSpec:

         return tuple(args), kwargs

-    def _replace_value_with_input_spec(self, args):
-        args_with_spec = []
-        for idx, input_var in enumerate(flatten(args)):
-            if isinstance(input_var, np.ndarray):
-                input_var = paddle.static.InputSpec.from_numpy(input_var)
-                _set_spec_stop_gradient(input_var, True)
-            elif isinstance(input_var, (core.VarBase, core.eager.Tensor)):
-                stop_gradient = input_var.stop_gradient
-                input_var = paddle.static.InputSpec.from_tensor(input_var)
-                _set_spec_stop_gradient(input_var, stop_gradient)
-            args_with_spec.append(input_var)
-
-        args_with_spec = pack_sequence_as(args, args_with_spec)
-        return args_with_spec
-
     def args_to_input_spec(self, args, kwargs):
         """
         Converts input arguments into InputSpec.
@@ -167,8 +151,8 @@ class FunctionSpec:
             # replace argument with corresponding InputSpec.
            args_with_spec = convert_to_input_spec(args, self._input_spec)
         else:
-            args_with_spec = self._replace_value_with_input_spec(args)
-        kwargs_with_spec = self._replace_value_with_input_spec(kwargs)
+            args_with_spec = _replace_value_with_input_spec(args)
+        kwargs_with_spec = _replace_value_with_input_spec(kwargs)

         # If without specificing name in input_spec, add default name
         # according to argument name from decorated function.
@@ -297,6 +281,28 @@ def get_buffers(layer_instance, include_sublayer=True):
     return buffers

+def _replace_value_with_input_spec(args):
+    args_with_spec = []
+    for idx, input_var in enumerate(flatten(args)):
+        if isinstance(input_var, np.ndarray):
+            input_var = paddle.static.InputSpec.from_numpy(input_var)
+            input_var.stop_gradient = True
+        elif isinstance(input_var, (core.VarBase, core.eager.Tensor)):
+            stop_gradient = input_var.stop_gradient
+            input_var = paddle.static.InputSpec.from_tensor(input_var)
+            input_var.stop_gradient = stop_gradient
+        elif isinstance(input_var, paddle.fluid.framework.Variable):
+            stop_gradient = input_var.stop_gradient
+            input_var = paddle.static.InputSpec(
+                input_var.shape, input_var.dtype, input_var.name
+            )
+            input_var.stop_gradient = stop_gradient
+        args_with_spec.append(input_var)
+
+    args_with_spec = pack_sequence_as(args, args_with_spec)
+    return args_with_spec
+
+
 def convert_to_input_spec(inputs, input_spec):
     """
     Replaces tensor in structured `inputs` by InputSpec in `input_spec`.
@@ -358,7 +364,18 @@ def convert_to_input_spec(inputs, input_spec):
             input_with_spec[name] = input
         return input_with_spec
     elif isinstance(input_spec, paddle.static.InputSpec):
-        return input_spec
+        """we compare input_spec with real_input_spec constructed from arguments."""
+        real_spec = _replace_value_with_input_spec([inputs])[0]
+        if not isinstance(real_spec, paddle.static.InputSpec):
+            raise RuntimeError(
+                "Give input spec into a non-tensorable arguments `{}`.".format(
+                    inputs
+                )
+            )
+        real_spec.name = input_spec.name
+        if spec_greater(input_spec, real_spec):
+            return input_spec
+        return real_spec
     else:
         # NOTE(Aurelius84): Support non-Tensor type as input spec info
         return input_spec
@@ -422,15 +439,6 @@ def _replace_spec_name(name, input_spec):
     return input_spec

-def _set_spec_stop_gradient(spec, stop_gradient):
-    """
-    Set new attribute ``stop_gradient`` for InputSpec to avoid generating redundant grad_op
-    while append_backward.
-    """
-    assert isinstance(spec, paddle.static.InputSpec)
-    spec.stop_gradient = stop_gradient
-
-
 def _hash_spec_names(args_specs, kwargs_specs):
     """
     Generater hash spec with args/kwargs InputSpec names.
@@ -462,3 +470,19 @@ def _hash_spec_names(args_specs, kwargs_specs):
         value = [to_idx(name) for name in spec_names]

     return tuple(value)
+
+
+def spec_greater(first, other):
+    def _shape_greater(first_shape, second_shape):
+        if len(first_shape) != len(second_shape):
+            return False
+        for first_n, second_n in zip(first_shape, second_shape):
+            if first_n != -1 and first_n != second_n:
+                return False
+        return True
+
+    return (
+        other.stop_gradient == first.stop_gradient
+        and other.dtype == first.dtype
+        and _shape_greater(first.shape, other.shape)
+    )
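The new spec_greater helper decides whether a user-supplied InputSpec may stand in for the spec derived from the real argument: stop_gradient and dtype must match, and every concrete dimension must agree, with -1 (or None) in the user spec acting as a wildcard. A small sketch of the two cases, using the values from the updated unit test above; the expected results are stated in comments rather than executed, since the helper's import path may vary across Paddle versions:

```python
from paddle.static import InputSpec

user_spec = InputSpec([None, 10], name='a', stop_gradient=True)  # shape -> (-1, 10)
real_spec = InputSpec([4, 10], name='a', stop_gradient=True)
# spec_greater(user_spec, real_spec) -> True: dtype and stop_gradient match,
# and the -1 in the user spec matches the concrete batch dimension 4,
# so convert_to_input_spec keeps the user-provided spec.

b_spec = InputSpec([10], name='b', stop_gradient=True)
real_b = InputSpec([4, 10], name='b', stop_gradient=True)
# spec_greater(b_spec, real_b) -> False: the ranks differ, so the spec
# derived from the argument wins (renamed to 'b'), which is exactly the
# ans_b_spec asserted in the updated test above.
```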
@@ -1191,6 +1191,7 @@ class ProgramCache:
             )
         if not _in_amp_guard() and not _in_pure_fp16_guard():
             concrete_program._to_prim()
+
         return concrete_program, partial_program_from(concrete_program)

     def __getitem__(self, item):
...
@@ -148,7 +148,7 @@ class InputSpec:
            print(label)  # InputSpec(shape=(-1, 1), dtype=paddle.int64, name=label)
     """

-    def __init__(self, shape, dtype='float32', name=None):
+    def __init__(self, shape, dtype='float32', name=None, stop_gradient=False):
         # replace `None` in shape with -1
         self.shape = self._verify(shape)
         # convert dtype into united represention
@@ -157,13 +157,18 @@ class InputSpec:
             dtype = convert_np_dtype_to_dtype_(dtype)
         self.dtype = dtype
         self.name = name
+        self.stop_gradient = stop_gradient

     def _create_feed_layer(self):
         return data(self.name, shape=self.shape, dtype=self.dtype)

     def __repr__(self):
-        return '{}(shape={}, dtype={}, name={})'.format(
-            type(self).__name__, self.shape, self.dtype, self.name
+        return '{}(shape={}, dtype={}, name={}, stop_gradient={})'.format(
+            type(self).__name__,
+            self.shape,
+            self.dtype,
+            self.name,
+            self.stop_gradient,
         )

     @classmethod
@@ -327,10 +332,10 @@ class InputSpec:
         # foo(x_var)
         # foo(x_np)  # x_np is a numpy.ndarray.
         # x_var and x_np hold same shape and dtype, they should also share a same program.
-        return hash((tuple(self.shape), self.dtype))
+        return hash((tuple(self.shape), self.dtype, self.stop_gradient))

     def __eq__(self, other):
-        slots = ['shape', 'dtype', 'name']
+        slots = ['shape', 'dtype', 'name', 'stop_gradient']
         return type(self) is type(other) and all(
             getattr(self, attr) == getattr(other, attr) for attr in slots
         )
...
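Because stop_gradient now participates in __eq__ and __hash__, two specs that agree on shape, dtype and name but differ in stop_gradient are no longer interchangeable, which is what lets it act as a cache key. A quick illustration of the behavior after this change:

```python
from paddle.static import InputSpec

a = InputSpec([None, 10], name='x')                      # stop_gradient defaults to False
b = InputSpec([None, 10], name='x', stop_gradient=True)

print(a == b)              # False: stop_gradient is now compared in __eq__
print(hash(a) == hash(b))  # False in practice: stop_gradient is hashed as well
print(b)  # InputSpec(shape=(-1, 10), dtype=paddle.float32, name=x, stop_gradient=True)
```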