未验证 提交 4157579e 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Refine ProgramCache.last and Return recent one (#39541)

* Refine ProgramCache.last and Return recent one

* add comment

* fix unittest
上级 40d2b7c6
...@@ -707,6 +707,8 @@ class ProgramCache(object): ...@@ -707,6 +707,8 @@ class ProgramCache(object):
def __init__(self): def __init__(self):
# {hash_id : (concrete_program, partial_layer)} # {hash_id : (concrete_program, partial_layer)}
self._caches = collections.OrderedDict() self._caches = collections.OrderedDict()
# trace mostly recent used program
self._recent_key = None
def _build_once(self, cache_key): def _build_once(self, cache_key):
concrete_program = ConcreteProgram.from_func_spec( concrete_program = ConcreteProgram.from_func_spec(
...@@ -722,6 +724,7 @@ class ProgramCache(object): ...@@ -722,6 +724,7 @@ class ProgramCache(object):
raise ValueError('type(item) should be CacheKey, but received %s' % raise ValueError('type(item) should be CacheKey, but received %s' %
type_name(item)) type_name(item))
item_id = hash(item) item_id = hash(item)
self._recent_key = item_id
if item_id not in self._caches: if item_id not in self._caches:
self._caches[item_id] = self._build_once(item) self._caches[item_id] = self._build_once(item)
# Note: raise warnings if number of traced program is more than `max_tracing_count` # Note: raise warnings if number of traced program is more than `max_tracing_count`
...@@ -749,8 +752,8 @@ class ProgramCache(object): ...@@ -749,8 +752,8 @@ class ProgramCache(object):
def last(self): def last(self):
assert len( assert len(
self._caches) >= 1, "No valid cached program in ProgramCache." self._caches) >= 1, "No valid cached program in ProgramCache."
key = next(reversed(self._caches.keys())) assert self._recent_key is not None
return key, self._caches[key] return self._recent_key, self._caches[self._recent_key]
def __len__(self): def __len__(self):
return len(self._caches) return len(self._caches)
......
...@@ -214,6 +214,7 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase): ...@@ -214,6 +214,7 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
self.assertTrue(np.allclose(x_data + y_data, out_1.numpy())) self.assertTrue(np.allclose(x_data + y_data, out_1.numpy()))
self.assertTrue(len(foo.program_cache) == 1) self.assertTrue(len(foo.program_cache) == 1)
self.assertTrue(len(foo.program_cache.concrete_programs()) == 1) self.assertTrue(len(foo.program_cache.concrete_programs()) == 1)
first_program = foo.program_cache.last()
# [16, 10] + [10] (numpy) # [16, 10] + [10] (numpy)
out_2 = foo(to_variable(x_data), y_data) out_2 = foo(to_variable(x_data), y_data)
...@@ -232,6 +233,11 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase): ...@@ -232,6 +233,11 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
# create a new program # create a new program
self.assertTrue(len(foo.program_cache) == 2) self.assertTrue(len(foo.program_cache) == 2)
# test for recent program
foo(to_variable(x_data), y_data)
recent_program = foo.program_cache.last()
self.assertTrue(first_program == recent_program)
def test_get_concrete_program(self): def test_get_concrete_program(self):
foo = declarative(foo_func) foo = declarative(foo_func)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册