提交 b5cc8c77 编写于 作者: F Faizan Muhammad 提交者: TensorFlower Gardener

Use FunctionType in FunctionCacheKey

PiperOrigin-RevId: 481006100
上级 e4a52e12
......@@ -36,6 +36,7 @@ pytype_strict_library(
"//tensorflow/python/eager:__subpackages__",
],
deps = [
"//tensorflow/core/function/function_type",
"//tensorflow/core/function/polymorphism:type_dispatch",
"//tensorflow/core/function/trace_type",
"//tensorflow/python/types",
......@@ -51,6 +52,7 @@ py_strict_test(
visibility = ["//learning/brain/contrib/eager/python/examples:__pkg__"],
deps = [
":function_cache",
"//tensorflow/core/function/function_type",
"//tensorflow/core/function/trace_type",
"//tensorflow/python:array_ops",
"//tensorflow/python/eager/polymorphic_function:function_context",
......
......@@ -15,9 +15,10 @@
"""Cache to manage concrete functions and their signatures."""
import collections
from typing import Optional, Hashable, Sequence, Any, NamedTuple, Dict
from typing import Any, Dict, Hashable, NamedTuple, Optional, Sequence
from tensorflow.core.function import trace_type
from tensorflow.core.function.function_type import function_type as function_type_lib
from tensorflow.core.function.polymorphism import type_dispatch
from tensorflow.python.types import trace
......@@ -30,6 +31,7 @@ class FunctionContext(NamedTuple):
context: Any
# TODO(fmuham): Move into FunctionType.
class CaptureSnapshot(trace.TraceType):
"""Store tf.function captures to accommodate its specific tracing logic.
......@@ -86,8 +88,7 @@ class CaptureSnapshot(trace.TraceType):
for key, item in query.mapping.items())
def most_specific_common_supertype(
self,
types: Sequence[trace.TraceType]) -> Optional["CaptureSnapshot"]:
self, types: Sequence[trace.TraceType]) -> Optional["CaptureSnapshot"]:
"""See base class."""
common_keys = set(self.mapping.keys())
for other in types:
......@@ -118,22 +119,21 @@ class CaptureSnapshot(trace.TraceType):
return hash(frozenset(self.mapping.keys()))
# TODO(panzf): Rename `FunctionCacheKey` to `FunctionType`
# TODO(fmuham): Remove inheritance from TraceType.
class FunctionCacheKey(trace.TraceType):
"""The unique key associated with a concrete function.
Attributes:
args_signature: A TraceType corresponding to the function arguments.
function_type: A FunctionType corresponding to the function arguments.
captures_signature: A CaptureSnapshot corresponding to the function
captures.
call_context: The FunctionContext for when the args_signature was
generated.
call_context: The FunctionContext for when the function was called.
"""
def __init__(self, args_signature: trace.TraceType,
def __init__(self, function_type: function_type_lib.FunctionType,
captures_signature: CaptureSnapshot,
call_context: FunctionContext):
self.args_signature = args_signature
self.function_type = function_type
self.captures_signature = captures_signature
self.call_context = call_context
......@@ -144,8 +144,8 @@ class FunctionCacheKey(trace.TraceType):
if self.call_context != other.call_context:
return False
return (self.args_signature.is_subtype_of(other.args_signature)
and self.captures_signature.is_subtype_of(other.captures_signature))
return (self.function_type.is_supertype_of(other.function_type) and
self.captures_signature.is_subtype_of(other.captures_signature))
def most_specific_common_supertype(
self, others: Sequence[trace.TraceType]) -> Optional["FunctionCacheKey"]:
......@@ -154,27 +154,28 @@ class FunctionCacheKey(trace.TraceType):
self.call_context == other.call_context for other in others):
return None
# `args` and `captures` are independent when finding common supertypes.
args_common = self.args_signature.most_specific_common_supertype(
[other.args_signature for other in others])
function_type_common = self.function_type.most_specific_common_subtype(
[other.function_type for other in others])
if args_common is None:
if function_type_common is None:
return None
captures_common = self.captures_signature.most_specific_common_supertype(
[other.captures_signature for other in others])
return FunctionCacheKey(args_common, captures_common, self.call_context)
return FunctionCacheKey(function_type_common, captures_common,
self.call_context)
def _placeholder_value(self) -> Any:
"""Value used for tracing a function signature with this TraceType."""
return {"args": self.args_signature._placeholder_value(), # pylint: disable=protected-access
"captures": self.captures_signature._placeholder_value()} # pylint: disable=protected-access
return {
"args": self.function_type.placeholder_arguments().args[0],
"captures": self.captures_signature._placeholder_value() # pylint: disable=protected-access
}
def __hash__(self) -> int:
return hash((self.call_context,
self.args_signature,
self.captures_signature))
return hash(
(self.call_context, self.function_type, self.captures_signature))
def __eq__(self, other) -> bool:
if not isinstance(other, trace.TraceType):
......@@ -184,23 +185,20 @@ class FunctionCacheKey(trace.TraceType):
return False
return (self.call_context == other.call_context and
self.args_signature == other.args_signature and
self.function_type == other.function_type and
self.captures_signature == other.captures_signature)
def __repr__(self) -> str:
return (
f"{type(self).__name__}(args_signature={repr(self.args_signature)},"
f"(captures_signature={repr(self.captures_signature)},"
f" call_context={repr(self.call_context)})")
return (f"{type(self).__name__}(function_type={repr(self.function_type)},"
f"(captures_signature={repr(self.captures_signature)},"
f" call_context={repr(self.call_context)})")
# TODO(fmuham): Rename to FunctionLibrary.
class FunctionCache:
"""A container for managing concrete functions."""
__slots__ = [
"_primary", "_dispatch_table", "_garbage_collectors"
]
__slots__ = ["_primary", "_dispatch_table", "_garbage_collectors"]
def __init__(self):
# The primary cache, mapping FunctionCacheKey to a concrete function.
......@@ -236,8 +234,7 @@ class FunctionCache:
return True
def add(self, key: FunctionCacheKey,
deletion_observer: trace_type.WeakrefDeletionObserver,
concrete):
deletion_observer: trace_type.WeakrefDeletionObserver, concrete: ...):
"""Adds a new concrete function alongside its key.
Args:
......
......@@ -19,6 +19,7 @@ import timeit
from typing import Optional
from tensorflow.core.function import trace_type
from tensorflow.core.function.function_type import function_type
from tensorflow.core.function.polymorphism import function_cache
from tensorflow.python.eager.polymorphic_function import function_context
from tensorflow.python.ops import array_ops
......@@ -81,7 +82,7 @@ class MockShape(trace.TraceType):
def __init__(self, *shape: Optional[int]):
self.shape = shape
def is_subtype_of(self, other: "MockShape") ->bool:
def is_subtype_of(self, other: "MockShape") -> bool:
if len(self.shape) != len(other.shape):
return False
......@@ -112,6 +113,13 @@ class MockEmptyCaptureSnapshot(function_cache.CaptureSnapshot):
self.mapping = {}
def make_single_param_type(type_constraint):
return function_type.FunctionType([
function_type.Parameter("x", function_type.Parameter.POSITIONAL_ONLY,
False, type_constraint)
])
class FunctionCacheTest(test.TestCase):
def testConcreteFunctionDictRetainsInsertedKeys(self):
......@@ -158,14 +166,15 @@ class FunctionCacheTest(test.TestCase):
cache.delete(key_1)
self.assertIsNone(cache.lookup(key_1, False))
key_2 = function_cache.FunctionCacheKey(MockSubtypeOf2(2),
MockEmptyCaptureSnapshot(), None)
cache.add(key_2, trace_type.WeakrefDeletionObserver(),
"test_2")
key_2 = function_cache.FunctionCacheKey(
make_single_param_type(MockSubtypeOf2(2)), MockEmptyCaptureSnapshot(),
None)
cache.add(key_2, trace_type.WeakrefDeletionObserver(), "test_2")
self.assertEqual(cache.lookup(key_2, False), "test_2")
key_3 = function_cache.FunctionCacheKey(MockSubtypeOf2(3),
MockEmptyCaptureSnapshot(), None)
key_3 = function_cache.FunctionCacheKey(
make_single_param_type(MockSubtypeOf2(3)), MockEmptyCaptureSnapshot(),
None)
self.assertEqual(cache.lookup(key_3, True), "test_2")
cache.delete(key_2)
......@@ -175,12 +184,12 @@ class FunctionCacheTest(test.TestCase):
def testFunctionCacheKeyRespectsEquality(self):
ctx = function_cache.FunctionContext(0)
generic = MockGenericType
key_a = function_cache.FunctionCacheKey(generic(1),
MockEmptyCaptureSnapshot(), ctx)
key_b = function_cache.FunctionCacheKey(generic(2),
MockEmptyCaptureSnapshot(), ctx)
key_c = function_cache.FunctionCacheKey(generic(1),
MockEmptyCaptureSnapshot(), ctx)
key_a = function_cache.FunctionCacheKey(
make_single_param_type(generic(1)), MockEmptyCaptureSnapshot(), ctx)
key_b = function_cache.FunctionCacheKey(
make_single_param_type(generic(2)), MockEmptyCaptureSnapshot(), ctx)
key_c = function_cache.FunctionCacheKey(
make_single_param_type(generic(1)), MockEmptyCaptureSnapshot(), ctx)
self.assertNotEqual(key_a, key_b)
self.assertEqual(key_a, key_c)
......@@ -188,12 +197,15 @@ class FunctionCacheTest(test.TestCase):
def testFunctionCacheKeyRespectsSubtype(self):
ctx = function_cache.FunctionContext(0)
key_a = function_cache.FunctionCacheKey(MockSubtypeOf2(1),
MockEmptyCaptureSnapshot(), ctx)
key_b = function_cache.FunctionCacheKey(MockSubtypeOf2(2),
MockEmptyCaptureSnapshot(), ctx)
key_c = function_cache.FunctionCacheKey(MockSubtypeOf2(1),
MockEmptyCaptureSnapshot(), ctx)
key_a = function_cache.FunctionCacheKey(
make_single_param_type(MockSubtypeOf2(1)), MockEmptyCaptureSnapshot(),
ctx)
key_b = function_cache.FunctionCacheKey(
make_single_param_type(MockSubtypeOf2(2)), MockEmptyCaptureSnapshot(),
ctx)
key_c = function_cache.FunctionCacheKey(
make_single_param_type(MockSubtypeOf2(1)), MockEmptyCaptureSnapshot(),
ctx)
self.assertTrue(key_a.is_subtype_of(key_b))
self.assertFalse(key_b.is_subtype_of(key_a))
......@@ -201,68 +213,74 @@ class FunctionCacheTest(test.TestCase):
def testFunctionCacheKeyRespectsSupertype(self):
ctx = function_cache.FunctionContext(0)
key_a = function_cache.FunctionCacheKey(MockSupertypes2With3(1),
MockEmptyCaptureSnapshot(), ctx)
key_b = function_cache.FunctionCacheKey(MockSupertypes2With3(2),
MockEmptyCaptureSnapshot(), ctx)
key_a = function_cache.FunctionCacheKey(
make_single_param_type(MockSupertypes2With3(1)),
MockEmptyCaptureSnapshot(), ctx)
key_b = function_cache.FunctionCacheKey(
make_single_param_type(MockSupertypes2With3(2)),
MockEmptyCaptureSnapshot(), ctx)
self.assertEqual(
key_b.most_specific_common_supertype([key_a]),
function_cache.FunctionCacheKey(MockSupertypes2With3(3),
MockEmptyCaptureSnapshot(), ctx))
function_cache.FunctionCacheKey(
make_single_param_type(MockSupertypes2With3(3)),
MockEmptyCaptureSnapshot(), ctx))
self.assertIsNone(key_a.most_specific_common_supertype([key_b]))
def testMostSpecificFunctionCacheKeyIsLookedUp(self):
ctx = function_cache.FunctionContext(0)
cache = function_cache.FunctionCache()
cache.add(
function_cache.FunctionCacheKey(MockShape(1, 2, None),
MockEmptyCaptureSnapshot(), ctx),
function_cache.FunctionCacheKey(
make_single_param_type(MockShape(1, 2, None)),
MockEmptyCaptureSnapshot(), ctx),
trace_type.WeakrefDeletionObserver(), "a")
cache.add(
function_cache.FunctionCacheKey(MockShape(1, 2, 3),
MockEmptyCaptureSnapshot(), ctx),
function_cache.FunctionCacheKey(
make_single_param_type(MockShape(1, 2, 3)),
MockEmptyCaptureSnapshot(), ctx),
trace_type.WeakrefDeletionObserver(), "b")
self.assertEqual(
cache.lookup(
function_cache.FunctionCacheKey(MockShape(1, 2, 3),
MockEmptyCaptureSnapshot(),
ctx), True),
"b")
function_cache.FunctionCacheKey(
make_single_param_type(MockShape(1, 2, 3)),
MockEmptyCaptureSnapshot(), ctx), True), "b")
def testFirstMostSpecificFunctionCacheKeyIsLookedUp(self):
ctx = function_cache.FunctionContext(0)
cache = function_cache.FunctionCache()
cache.add(
function_cache.FunctionCacheKey(MockShape(1, 2, None),
MockEmptyCaptureSnapshot(), ctx),
function_cache.FunctionCacheKey(
make_single_param_type(MockShape(1, 2, None)),
MockEmptyCaptureSnapshot(), ctx),
trace_type.WeakrefDeletionObserver(), "a")
cache.add(
function_cache.FunctionCacheKey(MockShape(1, None, 3),
MockEmptyCaptureSnapshot(), ctx),
function_cache.FunctionCacheKey(
make_single_param_type(MockShape(1, None, 3)),
MockEmptyCaptureSnapshot(), ctx),
trace_type.WeakrefDeletionObserver(), "b")
self.assertEqual(
cache.lookup(
function_cache.FunctionCacheKey(
MockShape(1, 2, 3), MockEmptyCaptureSnapshot(), ctx), True),
"a")
make_single_param_type(MockShape(1, 2, 3)),
MockEmptyCaptureSnapshot(), ctx), True), "a")
def testMostSpecificFunctionCacheKeyIsOrderAgnostic(self):
ctx = function_cache.FunctionContext(0)
keys = [(function_cache.FunctionCacheKey(MockShape(1, 1, 1),
MockEmptyCaptureSnapshot(),
ctx), "a"),
(function_cache.FunctionCacheKey(MockShape(1, None, 1),
MockEmptyCaptureSnapshot(),
ctx), "b"),
(function_cache.FunctionCacheKey(MockShape(None, None, 1),
MockEmptyCaptureSnapshot(),
ctx), "c"),
(function_cache.FunctionCacheKey(MockShape(None, None, None),
MockEmptyCaptureSnapshot(),
ctx), "d")]
keys = [(function_cache.FunctionCacheKey(
make_single_param_type(MockShape(1, 1, 1)), MockEmptyCaptureSnapshot(),
ctx), "a"),
(function_cache.FunctionCacheKey(
make_single_param_type(MockShape(1, None, 1)),
MockEmptyCaptureSnapshot(), ctx), "b"),
(function_cache.FunctionCacheKey(
make_single_param_type(MockShape(None, None, 1)),
MockEmptyCaptureSnapshot(), ctx), "c"),
(function_cache.FunctionCacheKey(
make_single_param_type(MockShape(None, None, None)),
MockEmptyCaptureSnapshot(), ctx), "d")]
for permutation in itertools.permutations(keys):
cache = function_cache.FunctionCache()
......@@ -277,28 +295,24 @@ class FunctionCacheTest(test.TestCase):
self.assertEqual(
cache.lookup(
function_cache.FunctionCacheKey(MockShape(1, 1, 1),
MockEmptyCaptureSnapshot(),
ctx), True),
"a")
function_cache.FunctionCacheKey(
make_single_param_type(MockShape(1, 1, 1)),
MockEmptyCaptureSnapshot(), ctx), True), "a")
self.assertEqual(
cache.lookup(
function_cache.FunctionCacheKey(MockShape(1, 2, 1),
MockEmptyCaptureSnapshot(),
ctx), True),
"b")
function_cache.FunctionCacheKey(
make_single_param_type(MockShape(1, 2, 1)),
MockEmptyCaptureSnapshot(), ctx), True), "b")
self.assertEqual(
cache.lookup(
function_cache.FunctionCacheKey(MockShape(2, 2, 1),
MockEmptyCaptureSnapshot(),
ctx), True),
"c")
function_cache.FunctionCacheKey(
make_single_param_type(MockShape(2, 2, 1)),
MockEmptyCaptureSnapshot(), ctx), True), "c")
self.assertEqual(
cache.lookup(
function_cache.FunctionCacheKey(MockShape(2, 2, 2),
MockEmptyCaptureSnapshot(),
ctx), True),
"d")
function_cache.FunctionCacheKey(
make_single_param_type(MockShape(2, 2, 2)),
MockEmptyCaptureSnapshot(), ctx), True), "d")
def testWeakRefDeletionAlsoDeletesConcreteFunction(self):
if not function_cache.DELETE_WITH_WEAKREF:
......@@ -447,17 +461,19 @@ class FunctionCacheBenchmark(test.Benchmark):
for key in keys:
cache.add(*key, "testing")
cache.add(
function_cache.FunctionCacheKey(MockSubtypeOf2(2),
MockEmptyCaptureSnapshot(), None),
function_cache.FunctionCacheKey(
make_single_param_type(MockSubtypeOf2(2)),
MockEmptyCaptureSnapshot(), None),
trace_type.WeakrefDeletionObserver(), "testing")
cache.lookup(function_cache.FunctionCacheKey(MockSubtypeOf2(3),
MockEmptyCaptureSnapshot(),
None), True)
cache.lookup(
function_cache.FunctionCacheKey(
make_single_param_type(MockSubtypeOf2(3)),
MockEmptyCaptureSnapshot(), None), True)
iterations = 10000
lookup_key = function_cache.FunctionCacheKey(MockSubtypeOf2(2),
MockEmptyCaptureSnapshot(),
None)
lookup_key = function_cache.FunctionCacheKey(
make_single_param_type(MockSubtypeOf2(2)), MockEmptyCaptureSnapshot(),
None)
subtyping_time = timeit.timeit(
lambda: cache.lookup(lookup_key, True), number=iterations)
......@@ -490,14 +506,15 @@ class FunctionCacheBenchmark(test.Benchmark):
for key in keys:
cache.add(*key, "testing")
cache.add(
function_cache.FunctionCacheKey(MockSubtypeOf2(3),
MockEmptyCaptureSnapshot(), None),
function_cache.FunctionCacheKey(
make_single_param_type(MockSubtypeOf2(3)),
MockEmptyCaptureSnapshot(), None),
trace_type.WeakrefDeletionObserver(), "testing")
iterations = 10000
lookup_key = function_cache.FunctionCacheKey(MockSubtypeOf2(2),
MockEmptyCaptureSnapshot(),
None)
lookup_key = function_cache.FunctionCacheKey(
make_single_param_type(MockSubtypeOf2(2)), MockEmptyCaptureSnapshot(),
None)
subtyping_time = sum(
timeit.repeat(
stmt=lambda: cache.lookup(lookup_key, True),
......@@ -539,9 +556,7 @@ class CaptureSnapshotTest(test.TestCase):
"b": MockIntGenericType(1),
"c": MockIntGenericType(2)
})
self.snapshot_e = snapshot_type({
"d": MockIntGenericType(1)
})
self.snapshot_e = snapshot_type({"d": MockIntGenericType(1)})
def testCaptureSnapshotSubtype(self):
self.assertFalse(self.snapshot_a.is_subtype_of(self.snapshot_b))
......
......@@ -17,6 +17,7 @@
from typing import Any, NamedTuple, Tuple
from tensorflow.core.function import trace_type
from tensorflow.core.function.function_type import function_type
from tensorflow.core.function.polymorphism import function_cache
from tensorflow.python.eager import context
from tensorflow.python.framework import device as pydev
......@@ -135,7 +136,15 @@ def make_cache_key(
captures_signature = function_cache.CaptureSnapshot(
captures_dict_tracetype.mapping)
# TODO(fmuham): Use the actual FunctionType
dummy_function_type = function_type.FunctionType([
function_type.Parameter("args_kwargs",
function_type.Parameter.POSITIONAL_ONLY,
False,
args_signature)
])
return function_cache.FunctionCacheKey(
args_signature,
dummy_function_type,
captures_signature,
make_function_context()), signature_context.deletion_observer
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册