diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py
index c25574c39dafe02ef4a02b8f2c6fc67eb14d86ca..30012fb8666fcb5256efa889de7440f6d709cccd 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py
@@ -394,3 +394,34 @@ def _set_spec_stop_gradient(spec, stop_gradient):
     """
     assert isinstance(spec, paddle.static.InputSpec)
     spec.stop_gradient = stop_gradient
+
+
+def _hash_spec_names(args_specs, kwargs_specs):
+    """
+    Generate a hash spec with args/kwargs InputSpec names.
+    Consider the following InputSpecs with same shape/dtype except for name:
+      1. [InputSpec([3,3], 'float32', 'x'), InputSpec([3,3], 'float32', 'x')]
+      2. [InputSpec([3,3], 'float32', 'x'), InputSpec([3,3], 'float32', 'y')]
+    Under @to_static, we should generate two different programs, not just one, because
+    the former has one input ('x'), but the latter has two inputs ('x', 'y').
+    """
+    spec_names = [
+        spec.name for spec in flatten(args_specs)
+        if isinstance(spec, paddle.static.InputSpec)
+    ]
+    spec_names += [
+        spec.name for spec in flatten(kwargs_specs)
+        if isinstance(spec, paddle.static.InputSpec)
+    ]
+    i, name_ids = 0, {}
+
+    def to_idx(name):
+        nonlocal i
+        if name not in name_ids:
+            name_ids[name] = i
+            i += 1
+        return name_ids[name]
+
+    value = [to_idx(name) for name in spec_names]
+
+    return tuple(value)
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
index 19479a190c3b9e83e267fa1a7acbfc007f34ec58..f8800f3037b408b4ad6a8b33beb1282cff185f5e 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
@@ -43,7 +43,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import input_specs_compatible
 from paddle.fluid.dygraph.dygraph_to_static.utils import type_name
 from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap
 from paddle.fluid.dygraph.dygraph_to_static.utils import make_hashable
-from paddle.fluid.dygraph.dygraph_to_static.function_spec import FunctionSpec
+from paddle.fluid.dygraph.dygraph_to_static.function_spec import FunctionSpec, _hash_spec_names
 from paddle.fluid.dygraph.dygraph_to_static.function_spec import get_buffers, get_parameters
 from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
 
@@ -147,7 +147,7 @@
     """
     __slots__ = [
         'function_spec', 'input_args_with_spec', 'input_kwargs_with_spec',
-        'class_instance', 'kwargs'
+        'class_instance', 'kwargs', '_spec_names_id'
     ]
 
     def __init__(self, function_spec, input_args_with_spec,
@@ -168,6 +168,8 @@
         self.class_instance = class_instance
         # NOTE: `kwargs` is usually not considered as basic member for `__hash__`
         self.kwargs = kwargs
+        self._spec_names_id = _hash_spec_names(input_args_with_spec,
+                                               input_kwargs_with_spec)
 
     @classmethod
     def from_func_and_args(cls, function_spec, args, kwargs, class_instance):
@@ -197,7 +199,7 @@
         return hash((id(self.function_spec),
                      make_hashable(self.input_args_with_spec, error_msg),
                      make_hashable(self.input_kwargs_with_spec, error_msg),
-                     self.class_instance))
+                     self._spec_names_id, self.class_instance))
 
     def __eq__(self, other):
         return (type(self) is type(other)) and hash(self) == hash(other)
@@ -703,6 +705,7 @@
     """
 
     def __init__(self):
+        # {hash_id : (concrete_program, partial_layer)}
         self._caches = collections.OrderedDict()
 
     def _build_once(self, cache_key):
@@ -718,9 +721,9 @@
         if not isinstance(item, CacheKey):
             raise ValueError('type(item) should be CacheKey, but received %s' %
                              type_name(item))
-
-        if item not in self._caches:
-            self._caches[item] = self._build_once(item)
+        item_id = hash(item)
+        if item_id not in self._caches:
+            self._caches[item_id] = self._build_once(item)
             # Note: raise warnings if number of traced program is more than `max_tracing_count`
             current_tracing_count = len(self._caches)
             if current_tracing_count > MAX_TRACED_PROGRAM_COUNT:
@@ -729,18 +732,19 @@
                     "The reason may be: (1) passing tensors with different shapes, (2) passing python objects instead of tensors.".
                     format(current_tracing_count, MAX_TRACED_PROGRAM_COUNT))
 
-        return self._caches[item]
+        return self._caches[item_id]
 
     def get_program(self, item):
         if not isinstance(item, CacheKey):
             raise ValueError(
                 "Input item's type should be FunctionSpec, but received %s" %
                 type_name(item))
-        if item not in self._caches:
+        item_id = hash(item)
+        if item_id not in self._caches:
             raise RuntimeError(
                 "Failed to find program for input item, please decorate input function by `@paddle.jit.to_static`."
             )
-        return self._caches[item]
+        return self._caches[item_id]
 
     def last(self):
         assert len(
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_spec_names.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_spec_names.py
new file mode 100644
index 0000000000000000000000000000000000000000..361fcbf9c73f5b056389e45c58b83e5e6308d000
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_spec_names.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.nn import Layer
+import numpy as np
+import unittest
+
+
+class Net(Layer):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.fc = paddle.nn.Linear(16, 3)
+
+    def forward(self, x, y, m, n):
+        inputs = [x, y, m, n]
+        outs = []
+        for var in inputs:
+            out = paddle.reshape(x, [-1, 16])
+            out = self.fc(out)
+            outs.append(out)
+
+        out = paddle.stack(outs)
+        return paddle.sum(out)
+
+
+class TestArgsSpecName(unittest.TestCase):
+    def read_from_dataset(self):
+        self.x = paddle.randn([4, 2, 8])
+        self.y = paddle.randn([4, 2, 8])
+        self.m = paddle.randn([4, 2, 8])
+        self.n = paddle.randn([4, 2, 8])
+
+    def test_spec_name_hash(self):
+        net = Net()
+        net = paddle.jit.to_static(net)
+        # Convert into program with four inputs
+        self.read_from_dataset()
+        self.run_test(net, [self.x, self.y, self.m, self.n], 1, [0, 1, 2, 3])
+
+        # Convert into program with three inputs
+        self.read_from_dataset()
+        self.run_test(net, [self.x, self.x, self.m, self.n], 2, [0, 0, 1, 2])
+
+        # Convert into program with two inputs
+        self.read_from_dataset()
+        self.run_test(net, [self.x, self.x, self.m, self.m], 3, [0, 0, 1, 1])
+
+        # Use Cache Program
+        self.read_from_dataset()
+        self.run_test(net, [self.n, self.n, self.y, self.y], 3, [0, 0, 1, 1])
+
+        # Convert into program with two inputs
+        self.read_from_dataset()
+        self.run_test(net, [self.x, self.y, self.x, self.y], 4, [0, 1, 0, 1])
+
+        # Use Cache Program
+        self.read_from_dataset()
+        self.run_test(net, [self.m, self.n, self.m, self.n], 4, [0, 1, 0, 1])
+
+        # Convert into program with one input
+        self.read_from_dataset()
+        self.run_test(net, [self.x, self.x, self.x, self.x], 5, [0, 0, 0, 0])
+
+        # Use Cache Program
+        self.read_from_dataset()
+        self.run_test(net, [self.m, self.m, self.m, self.m], 5, [0, 0, 0, 0])
+
+    def run_test(self, net, inputs, trace_count, mode):
+        out = net(*inputs)
+        self.assertEqual(net.forward.get_traced_count(), trace_count)
+        self.assert_feed_mode(net.forward.inputs, mode)
+
+    def assert_feed_mode(self, inputs, expect_mode):
+        assert isinstance(inputs, list)
+        assert isinstance(expect_mode, list)
+        in_names = [var.name for var in inputs]
+
+        i, name_ids = 0, {}
+
+        def to_idx(name):
+            nonlocal i
+            if name not in name_ids:
+                name_ids[name] = i
+                i += 1
+            return name_ids[name]
+
+        mode = [to_idx(name) for name in in_names]
+        self.assertEquals(mode, expect_mode)
+
+
+if __name__ == '__main__':
+    unittest.main()