diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index f8800f3037b408b4ad6a8b33beb1282cff185f5e..dc1095849a3d8fa5de689a518934e4dea8dff99f 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -707,6 +707,8 @@ class ProgramCache(object): def __init__(self): # {hash_id : (concrete_program, partial_layer)} self._caches = collections.OrderedDict() + # trace mostly recent used program + self._recent_key = None def _build_once(self, cache_key): concrete_program = ConcreteProgram.from_func_spec( @@ -722,6 +724,7 @@ class ProgramCache(object): raise ValueError('type(item) should be CacheKey, but received %s' % type_name(item)) item_id = hash(item) + self._recent_key = item_id if item_id not in self._caches: self._caches[item_id] = self._build_once(item) # Note: raise warnings if number of traced program is more than `max_tracing_count` @@ -749,8 +752,8 @@ class ProgramCache(object): def last(self): assert len( self._caches) >= 1, "No valid cached program in ProgramCache." - key = next(reversed(self._caches.keys())) - return key, self._caches[key] + assert self._recent_key is not None + return self._recent_key, self._caches[self._recent_key] def __len__(self): return len(self._caches) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py index d18c691325094e10dc181ad7778a6ba1ab81a57f..67091f5fabb2ede1b589ba863c86b86607514dbb 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py @@ -214,6 +214,7 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase): self.assertTrue(np.allclose(x_data + y_data, out_1.numpy())) self.assertTrue(len(foo.program_cache) == 1) self.assertTrue(len(foo.program_cache.concrete_programs()) == 1) + first_program = foo.program_cache.last() # [16, 10] + [10] (numpy) out_2 = foo(to_variable(x_data), y_data) @@ -232,6 +233,11 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase): # create a new program self.assertTrue(len(foo.program_cache) == 2) + # test for recent program + foo(to_variable(x_data), y_data) + recent_program = foo.program_cache.last() + self.assertTrue(first_program == recent_program) + def test_get_concrete_program(self): foo = declarative(foo_func)