Unverified · Commit 5e1185de · Authored by: xiongkun · Committed by: GitHub

[dy2static] bugfix: make stop_gradient a cache key (#50883)

* [dy2static] bugfix: make stop_gradient a cache key
1. make stop_gradient a cache key in dy2static.

* fix ci errors

* fix ci error

* fix ci error

* fix ci error
Parent 3bba4af7
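This commit threads `stop_gradient` through `InputSpec` and the dy2static program cache, so that traces differing only in whether an input needs gradients no longer share one cached static program. A minimal sketch of the scenario being fixed (the function, shapes, and values below are illustrative, not taken from the commit):

```python
import paddle

@paddle.jit.to_static
def scale(x):
    return x * 2.0

a = paddle.ones([2, 3])
a.stop_gradient = True    # this trace needs no grad ops

b = paddle.ones([2, 3])
b.stop_gradient = False   # this trace must keep grad ops

# With stop_gradient folded into InputSpec.__hash__/__eq__ (see the
# InputSpec hunk at the end of this diff), these two calls now map to two
# different cache entries instead of reusing a single cached program.
scale(a)
scale(b)
```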
......@@ -70,8 +70,8 @@ class TestFunctionSpec(unittest.TestCase):
foo_spec.unified_args_and_kwargs([10], {'c': 4})
def test_args_to_input_spec(self):
a_spec = InputSpec([None, 10], name='a')
b_spec = InputSpec([10], name='b')
a_spec = InputSpec([None, 10], name='a', stop_gradient=True)
b_spec = InputSpec([10], name='b', stop_gradient=True)
a_tensor = paddle.static.data(name='a_var', shape=[4, 10])
b_tensor = paddle.static.data(name='b_var', shape=[4, 10])
......@@ -85,7 +85,8 @@ class TestFunctionSpec(unittest.TestCase):
self.assertTrue(len(input_with_spec) == 4)
self.assertTrue(input_with_spec[0] == a_spec) # a
self.assertTrue(input_with_spec[1] == b_spec) # b
ans_b_spec = InputSpec([4, 10], name='b', stop_gradient=True)
self.assertTrue(input_with_spec[1] == ans_b_spec) # b
self.assertTrue(input_with_spec[2] == 1) # c
self.assertTrue(input_with_spec[3] == 2) # d
......
......@@ -36,12 +36,12 @@ class GridGenerator(nn.Layer):
in_channels, 6, weight_attr=param_attr, bias_attr=bias_attr
)
@paddle.jit.to_static(
input_spec=[
paddle.static.InputSpec(shape=[None, 3, 32, 100], dtype='float32'),
paddle.static.InputSpec(shape=[32, 100], dtype='float32'),
]
)
# @paddle.jit.to_static(
# input_spec=[
# paddle.static.InputSpec(shape=[None, 3, 32, 100], dtype='float32'),
# paddle.static.InputSpec(shape=[32, 100], dtype='float32'),
# ]
# )
def forward(self, batch_C_prime, I_r_size):
"""
Generate the grid for the grid_sampler.
......
......@@ -111,22 +111,6 @@ class FunctionSpec:
return tuple(args), kwargs
def _replace_value_with_input_spec(self, args):
args_with_spec = []
for idx, input_var in enumerate(flatten(args)):
if isinstance(input_var, np.ndarray):
input_var = paddle.static.InputSpec.from_numpy(input_var)
_set_spec_stop_gradient(input_var, True)
elif isinstance(input_var, (core.VarBase, core.eager.Tensor)):
stop_gradient = input_var.stop_gradient
input_var = paddle.static.InputSpec.from_tensor(input_var)
_set_spec_stop_gradient(input_var, stop_gradient)
args_with_spec.append(input_var)
args_with_spec = pack_sequence_as(args, args_with_spec)
return args_with_spec
def args_to_input_spec(self, args, kwargs):
"""
Converts input arguments into InputSpec.
......@@ -167,8 +151,8 @@ class FunctionSpec:
# replace argument with corresponding InputSpec.
args_with_spec = convert_to_input_spec(args, self._input_spec)
else:
args_with_spec = self._replace_value_with_input_spec(args)
kwargs_with_spec = self._replace_value_with_input_spec(kwargs)
args_with_spec = _replace_value_with_input_spec(args)
kwargs_with_spec = _replace_value_with_input_spec(kwargs)
# If no name is specified in input_spec, add a default name
# according to the argument name of the decorated function.
......@@ -297,6 +281,28 @@ def get_buffers(layer_instance, include_sublayer=True):
return buffers
def _replace_value_with_input_spec(args):
args_with_spec = []
for idx, input_var in enumerate(flatten(args)):
if isinstance(input_var, np.ndarray):
input_var = paddle.static.InputSpec.from_numpy(input_var)
input_var.stop_gradient = True
elif isinstance(input_var, (core.VarBase, core.eager.Tensor)):
stop_gradient = input_var.stop_gradient
input_var = paddle.static.InputSpec.from_tensor(input_var)
input_var.stop_gradient = stop_gradient
elif isinstance(input_var, paddle.fluid.framework.Variable):
stop_gradient = input_var.stop_gradient
input_var = paddle.static.InputSpec(
input_var.shape, input_var.dtype, input_var.name
)
input_var.stop_gradient = stop_gradient
args_with_spec.append(input_var)
args_with_spec = pack_sequence_as(args, args_with_spec)
return args_with_spec
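A small usage sketch (values are illustrative) of the module-level helper added above, assuming it is called from within the same module: numpy inputs are treated as constants, while eager tensors keep whatever `stop_gradient` flag they already carry.

```python
import numpy as np
import paddle

x_np = np.ones([4, 10], dtype='float32')
x_t = paddle.ones([4, 10])
x_t.stop_gradient = False

specs = _replace_value_with_input_spec([x_np, x_t])
print(specs[0].stop_gradient)  # True: numpy arrays never need gradients
print(specs[1].stop_gradient)  # False: copied from the tensor
```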
def convert_to_input_spec(inputs, input_spec):
"""
Replaces tensor in structured `inputs` by InputSpec in `input_spec`.
......@@ -358,7 +364,18 @@ def convert_to_input_spec(inputs, input_spec):
input_with_spec[name] = input
return input_with_spec
elif isinstance(input_spec, paddle.static.InputSpec):
return input_spec
"""we compare input_spec with real_input_spec constructed from arguments."""
real_spec = _replace_value_with_input_spec([inputs])[0]
if not isinstance(real_spec, paddle.static.InputSpec):
raise RuntimeError(
"Give input spec into a non-tensorable arguments `{}`.".format(
inputs
)
)
real_spec.name = input_spec.name
if spec_greater(input_spec, real_spec):
return input_spec
return real_spec
else:
# NOTE(Aurelius84): Support non-Tensor type as input spec info
return input_spec
......@@ -422,15 +439,6 @@ def _replace_spec_name(name, input_spec):
return input_spec
def _set_spec_stop_gradient(spec, stop_gradient):
"""
Set new attribute ``stop_gradient`` for InputSpec to avoid generating redundant grad_op
while append_backward.
"""
assert isinstance(spec, paddle.static.InputSpec)
spec.stop_gradient = stop_gradient
def _hash_spec_names(args_specs, kwargs_specs):
"""
Generate a hash spec from the InputSpec names of args/kwargs.
......@@ -462,3 +470,19 @@ def _hash_spec_names(args_specs, kwargs_specs):
value = [to_idx(name) for name in spec_names]
return tuple(value)
def spec_greater(first, other):
def _shape_greater(first_shape, second_shape):
if len(first_shape) != len(second_shape):
return False
for first_n, second_n in zip(first_shape, second_shape):
if first_n != -1 and first_n != second_n:
return False
return True
return (
other.stop_gradient == first.stop_gradient
and other.dtype == first.dtype
and _shape_greater(first.shape, other.shape)
)
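`spec_greater` decides whether a user-supplied spec "covers" the spec derived from the real argument; `convert_to_input_spec` above keeps the user spec only in that case and otherwise falls back to the derived spec. A minimal, self-contained sketch of the shape rule, using a hypothetical helper name `shape_covers`:

```python
def shape_covers(first_shape, other_shape):
    # A dim of -1 in the user spec acts as a wildcard; every concrete dim
    # must match exactly, and the ranks must agree.
    if len(first_shape) != len(other_shape):
        return False
    return all(f == -1 or f == o for f, o in zip(first_shape, other_shape))

assert shape_covers([-1, 10], [4, 10])         # wildcard batch dim
assert not shape_covers([8, 10], [4, 10])      # concrete dims must match
assert not shape_covers([-1, 10], [4, 10, 3])  # rank mismatch
```

Besides the shape, `spec_greater` also requires matching dtype and stop_gradient, so a user spec cannot silently override the gradient requirement of the actual input.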
......@@ -1191,6 +1191,7 @@ class ProgramCache:
)
if not _in_amp_guard() and not _in_pure_fp16_guard():
concrete_program._to_prim()
return concrete_program, partial_program_from(concrete_program)
def __getitem__(self, item):
......
......@@ -148,7 +148,7 @@ class InputSpec:
print(label) # InputSpec(shape=(-1, 1), dtype=paddle.int64, name=label)
"""
def __init__(self, shape, dtype='float32', name=None):
def __init__(self, shape, dtype='float32', name=None, stop_gradient=False):
# replace `None` in shape with -1
self.shape = self._verify(shape)
# convert dtype into a unified representation
......@@ -157,13 +157,18 @@ class InputSpec:
dtype = convert_np_dtype_to_dtype_(dtype)
self.dtype = dtype
self.name = name
self.stop_gradient = stop_gradient
def _create_feed_layer(self):
return data(self.name, shape=self.shape, dtype=self.dtype)
def __repr__(self):
return '{}(shape={}, dtype={}, name={})'.format(
type(self).__name__, self.shape, self.dtype, self.name
return '{}(shape={}, dtype={}, name={}, stop_gradient={})'.format(
type(self).__name__,
self.shape,
self.dtype,
self.name,
self.stop_gradient,
)
@classmethod
......@@ -327,10 +332,10 @@ class InputSpec:
# foo(x_var)
# foo(x_np) # x_np is a numpy.ndarray.
# x_var and x_np hold the same shape and dtype, so they should share the same program.
return hash((tuple(self.shape), self.dtype))
return hash((tuple(self.shape), self.dtype, self.stop_gradient))
def __eq__(self, other):
slots = ['shape', 'dtype', 'name']
slots = ['shape', 'dtype', 'name', 'stop_gradient']
return type(self) is type(other) and all(
getattr(self, attr) == getattr(other, attr) for attr in slots
)
......
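A quick illustration of why the `__hash__`/`__eq__` change matters for caching, assuming the patched `InputSpec` above: specs that differ only in `stop_gradient` no longer compare or hash as equal, so dy2static keeps separate program cache entries for them.

```python
import paddle

s1 = paddle.static.InputSpec([None, 10], dtype='float32', name='x', stop_gradient=True)
s2 = paddle.static.InputSpec([None, 10], dtype='float32', name='x', stop_gradient=False)

print(s1 == s2)              # False: __eq__ now also compares stop_gradient
print(hash(s1) == hash(s2))  # almost surely False: __hash__ now includes stop_gradient
```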