From b5cc8c77cf12449d431cad48d83ec97003f894e7 Mon Sep 17 00:00:00 2001 From: Faizan Muhammad Date: Thu, 13 Oct 2022 16:02:42 -0700 Subject: [PATCH] Use FunctionType in FunctionCacheKey PiperOrigin-RevId: 481006100 --- tensorflow/core/function/polymorphism/BUILD | 2 + .../function/polymorphism/function_cache.py | 59 +++--- .../polymorphism/function_cache_test.py | 181 ++++++++++-------- .../polymorphic_function/function_context.py | 11 +- 4 files changed, 138 insertions(+), 115 deletions(-) diff --git a/tensorflow/core/function/polymorphism/BUILD b/tensorflow/core/function/polymorphism/BUILD index b1a73f0399c..d8fd1405ce5 100644 --- a/tensorflow/core/function/polymorphism/BUILD +++ b/tensorflow/core/function/polymorphism/BUILD @@ -36,6 +36,7 @@ pytype_strict_library( "//tensorflow/python/eager:__subpackages__", ], deps = [ + "//tensorflow/core/function/function_type", "//tensorflow/core/function/polymorphism:type_dispatch", "//tensorflow/core/function/trace_type", "//tensorflow/python/types", @@ -51,6 +52,7 @@ py_strict_test( visibility = ["//learning/brain/contrib/eager/python/examples:__pkg__"], deps = [ ":function_cache", + "//tensorflow/core/function/function_type", "//tensorflow/core/function/trace_type", "//tensorflow/python:array_ops", "//tensorflow/python/eager/polymorphic_function:function_context", diff --git a/tensorflow/core/function/polymorphism/function_cache.py b/tensorflow/core/function/polymorphism/function_cache.py index a03058aeff9..031b224d24b 100644 --- a/tensorflow/core/function/polymorphism/function_cache.py +++ b/tensorflow/core/function/polymorphism/function_cache.py @@ -15,9 +15,10 @@ """Cache to manage concrete functions and their signatures.""" import collections -from typing import Optional, Hashable, Sequence, Any, NamedTuple, Dict +from typing import Any, Dict, Hashable, NamedTuple, Optional, Sequence from tensorflow.core.function import trace_type +from tensorflow.core.function.function_type import function_type as function_type_lib from tensorflow.core.function.polymorphism import type_dispatch from tensorflow.python.types import trace @@ -30,6 +31,7 @@ class FunctionContext(NamedTuple): context: Any +# TODO(fmuham): Move into FunctionType. class CaptureSnapshot(trace.TraceType): """Store tf.function captures to accommodate its specific tracing logic. @@ -86,8 +88,7 @@ class CaptureSnapshot(trace.TraceType): for key, item in query.mapping.items()) def most_specific_common_supertype( - self, - types: Sequence[trace.TraceType]) -> Optional["CaptureSnapshot"]: + self, types: Sequence[trace.TraceType]) -> Optional["CaptureSnapshot"]: """See base class.""" common_keys = set(self.mapping.keys()) for other in types: @@ -118,22 +119,21 @@ class CaptureSnapshot(trace.TraceType): return hash(frozenset(self.mapping.keys())) -# TODO(panzf): Rename `FunctionCacheKey` to `FunctionType` +# TODO(fmuham): Remove inheritance from TraceType. class FunctionCacheKey(trace.TraceType): """The unique key associated with a concrete function. Attributes: - args_signature: A TraceType corresponding to the function arguments. + function_type: A FunctionType corresponding to the function arguments. captures_signature: A CaptureSnapshot corresponding to the function captures. - call_context: The FunctionContext for when the args_signature was - generated. + call_context: The FunctionContext for when the function was called. """ - def __init__(self, args_signature: trace.TraceType, + def __init__(self, function_type: function_type_lib.FunctionType, captures_signature: CaptureSnapshot, call_context: FunctionContext): - self.args_signature = args_signature + self.function_type = function_type self.captures_signature = captures_signature self.call_context = call_context @@ -144,8 +144,8 @@ class FunctionCacheKey(trace.TraceType): if self.call_context != other.call_context: return False - return (self.args_signature.is_subtype_of(other.args_signature) - and self.captures_signature.is_subtype_of(other.captures_signature)) + return (self.function_type.is_supertype_of(other.function_type) and + self.captures_signature.is_subtype_of(other.captures_signature)) def most_specific_common_supertype( self, others: Sequence[trace.TraceType]) -> Optional["FunctionCacheKey"]: @@ -154,27 +154,28 @@ class FunctionCacheKey(trace.TraceType): self.call_context == other.call_context for other in others): return None - # `args` and `captures` are independent when finding common supertypes. - args_common = self.args_signature.most_specific_common_supertype( - [other.args_signature for other in others]) + function_type_common = self.function_type.most_specific_common_subtype( + [other.function_type for other in others]) - if args_common is None: + if function_type_common is None: return None captures_common = self.captures_signature.most_specific_common_supertype( [other.captures_signature for other in others]) - return FunctionCacheKey(args_common, captures_common, self.call_context) + return FunctionCacheKey(function_type_common, captures_common, + self.call_context) def _placeholder_value(self) -> Any: """Value used for tracing a function signature with this TraceType.""" - return {"args": self.args_signature._placeholder_value(), # pylint: disable=protected-access - "captures": self.captures_signature._placeholder_value()} # pylint: disable=protected-access + return { + "args": self.function_type.placeholder_arguments().args[0], + "captures": self.captures_signature._placeholder_value() # pylint: disable=protected-access + } def __hash__(self) -> int: - return hash((self.call_context, - self.args_signature, - self.captures_signature)) + return hash( + (self.call_context, self.function_type, self.captures_signature)) def __eq__(self, other) -> bool: if not isinstance(other, trace.TraceType): @@ -184,23 +185,20 @@ class FunctionCacheKey(trace.TraceType): return False return (self.call_context == other.call_context and - self.args_signature == other.args_signature and + self.function_type == other.function_type and self.captures_signature == other.captures_signature) def __repr__(self) -> str: - return ( - f"{type(self).__name__}(args_signature={repr(self.args_signature)}," - f"(captures_signature={repr(self.captures_signature)}," - f" call_context={repr(self.call_context)})") + return (f"{type(self).__name__}(function_type={repr(self.function_type)}," + f"(captures_signature={repr(self.captures_signature)}," + f" call_context={repr(self.call_context)})") # TODO(fmuham): Rename to FunctionLibrary. class FunctionCache: """A container for managing concrete functions.""" - __slots__ = [ - "_primary", "_dispatch_table", "_garbage_collectors" - ] + __slots__ = ["_primary", "_dispatch_table", "_garbage_collectors"] def __init__(self): # The primary cache, mapping FunctionCacheKey to a concrete function. @@ -236,8 +234,7 @@ class FunctionCache: return True def add(self, key: FunctionCacheKey, - deletion_observer: trace_type.WeakrefDeletionObserver, - concrete): + deletion_observer: trace_type.WeakrefDeletionObserver, concrete: ...): """Adds a new concrete function alongside its key. Args: diff --git a/tensorflow/core/function/polymorphism/function_cache_test.py b/tensorflow/core/function/polymorphism/function_cache_test.py index f81bfee4c78..6d99748d367 100644 --- a/tensorflow/core/function/polymorphism/function_cache_test.py +++ b/tensorflow/core/function/polymorphism/function_cache_test.py @@ -19,6 +19,7 @@ import timeit from typing import Optional from tensorflow.core.function import trace_type +from tensorflow.core.function.function_type import function_type from tensorflow.core.function.polymorphism import function_cache from tensorflow.python.eager.polymorphic_function import function_context from tensorflow.python.ops import array_ops @@ -81,7 +82,7 @@ class MockShape(trace.TraceType): def __init__(self, *shape: Optional[int]): self.shape = shape - def is_subtype_of(self, other: "MockShape") ->bool: + def is_subtype_of(self, other: "MockShape") -> bool: if len(self.shape) != len(other.shape): return False @@ -112,6 +113,13 @@ class MockEmptyCaptureSnapshot(function_cache.CaptureSnapshot): self.mapping = {} +def make_single_param_type(type_constraint): + return function_type.FunctionType([ + function_type.Parameter("x", function_type.Parameter.POSITIONAL_ONLY, + False, type_constraint) + ]) + + class FunctionCacheTest(test.TestCase): def testConcreteFunctionDictRetainsInsertedKeys(self): @@ -158,14 +166,15 @@ class FunctionCacheTest(test.TestCase): cache.delete(key_1) self.assertIsNone(cache.lookup(key_1, False)) - key_2 = function_cache.FunctionCacheKey(MockSubtypeOf2(2), - MockEmptyCaptureSnapshot(), None) - cache.add(key_2, trace_type.WeakrefDeletionObserver(), - "test_2") + key_2 = function_cache.FunctionCacheKey( + make_single_param_type(MockSubtypeOf2(2)), MockEmptyCaptureSnapshot(), + None) + cache.add(key_2, trace_type.WeakrefDeletionObserver(), "test_2") self.assertEqual(cache.lookup(key_2, False), "test_2") - key_3 = function_cache.FunctionCacheKey(MockSubtypeOf2(3), - MockEmptyCaptureSnapshot(), None) + key_3 = function_cache.FunctionCacheKey( + make_single_param_type(MockSubtypeOf2(3)), MockEmptyCaptureSnapshot(), + None) self.assertEqual(cache.lookup(key_3, True), "test_2") cache.delete(key_2) @@ -175,12 +184,12 @@ class FunctionCacheTest(test.TestCase): def testFunctionCacheKeyRespectsEquality(self): ctx = function_cache.FunctionContext(0) generic = MockGenericType - key_a = function_cache.FunctionCacheKey(generic(1), - MockEmptyCaptureSnapshot(), ctx) - key_b = function_cache.FunctionCacheKey(generic(2), - MockEmptyCaptureSnapshot(), ctx) - key_c = function_cache.FunctionCacheKey(generic(1), - MockEmptyCaptureSnapshot(), ctx) + key_a = function_cache.FunctionCacheKey( + make_single_param_type(generic(1)), MockEmptyCaptureSnapshot(), ctx) + key_b = function_cache.FunctionCacheKey( + make_single_param_type(generic(2)), MockEmptyCaptureSnapshot(), ctx) + key_c = function_cache.FunctionCacheKey( + make_single_param_type(generic(1)), MockEmptyCaptureSnapshot(), ctx) self.assertNotEqual(key_a, key_b) self.assertEqual(key_a, key_c) @@ -188,12 +197,15 @@ class FunctionCacheTest(test.TestCase): def testFunctionCacheKeyRespectsSubtype(self): ctx = function_cache.FunctionContext(0) - key_a = function_cache.FunctionCacheKey(MockSubtypeOf2(1), - MockEmptyCaptureSnapshot(), ctx) - key_b = function_cache.FunctionCacheKey(MockSubtypeOf2(2), - MockEmptyCaptureSnapshot(), ctx) - key_c = function_cache.FunctionCacheKey(MockSubtypeOf2(1), - MockEmptyCaptureSnapshot(), ctx) + key_a = function_cache.FunctionCacheKey( + make_single_param_type(MockSubtypeOf2(1)), MockEmptyCaptureSnapshot(), + ctx) + key_b = function_cache.FunctionCacheKey( + make_single_param_type(MockSubtypeOf2(2)), MockEmptyCaptureSnapshot(), + ctx) + key_c = function_cache.FunctionCacheKey( + make_single_param_type(MockSubtypeOf2(1)), MockEmptyCaptureSnapshot(), + ctx) self.assertTrue(key_a.is_subtype_of(key_b)) self.assertFalse(key_b.is_subtype_of(key_a)) @@ -201,68 +213,74 @@ class FunctionCacheTest(test.TestCase): def testFunctionCacheKeyRespectsSupertype(self): ctx = function_cache.FunctionContext(0) - key_a = function_cache.FunctionCacheKey(MockSupertypes2With3(1), - MockEmptyCaptureSnapshot(), ctx) - key_b = function_cache.FunctionCacheKey(MockSupertypes2With3(2), - MockEmptyCaptureSnapshot(), ctx) + key_a = function_cache.FunctionCacheKey( + make_single_param_type(MockSupertypes2With3(1)), + MockEmptyCaptureSnapshot(), ctx) + key_b = function_cache.FunctionCacheKey( + make_single_param_type(MockSupertypes2With3(2)), + MockEmptyCaptureSnapshot(), ctx) self.assertEqual( key_b.most_specific_common_supertype([key_a]), - function_cache.FunctionCacheKey(MockSupertypes2With3(3), - MockEmptyCaptureSnapshot(), ctx)) + function_cache.FunctionCacheKey( + make_single_param_type(MockSupertypes2With3(3)), + MockEmptyCaptureSnapshot(), ctx)) self.assertIsNone(key_a.most_specific_common_supertype([key_b])) def testMostSpecificFunctionCacheKeyIsLookedUp(self): ctx = function_cache.FunctionContext(0) cache = function_cache.FunctionCache() cache.add( - function_cache.FunctionCacheKey(MockShape(1, 2, None), - MockEmptyCaptureSnapshot(), ctx), + function_cache.FunctionCacheKey( + make_single_param_type(MockShape(1, 2, None)), + MockEmptyCaptureSnapshot(), ctx), trace_type.WeakrefDeletionObserver(), "a") cache.add( - function_cache.FunctionCacheKey(MockShape(1, 2, 3), - MockEmptyCaptureSnapshot(), ctx), + function_cache.FunctionCacheKey( + make_single_param_type(MockShape(1, 2, 3)), + MockEmptyCaptureSnapshot(), ctx), trace_type.WeakrefDeletionObserver(), "b") self.assertEqual( cache.lookup( - function_cache.FunctionCacheKey(MockShape(1, 2, 3), - MockEmptyCaptureSnapshot(), - ctx), True), - "b") + function_cache.FunctionCacheKey( + make_single_param_type(MockShape(1, 2, 3)), + MockEmptyCaptureSnapshot(), ctx), True), "b") def testFirstMostSpecificFunctionCacheKeyIsLookedUp(self): ctx = function_cache.FunctionContext(0) cache = function_cache.FunctionCache() cache.add( - function_cache.FunctionCacheKey(MockShape(1, 2, None), - MockEmptyCaptureSnapshot(), ctx), + function_cache.FunctionCacheKey( + make_single_param_type(MockShape(1, 2, None)), + MockEmptyCaptureSnapshot(), ctx), trace_type.WeakrefDeletionObserver(), "a") cache.add( - function_cache.FunctionCacheKey(MockShape(1, None, 3), - MockEmptyCaptureSnapshot(), ctx), + function_cache.FunctionCacheKey( + make_single_param_type(MockShape(1, None, 3)), + MockEmptyCaptureSnapshot(), ctx), trace_type.WeakrefDeletionObserver(), "b") self.assertEqual( cache.lookup( function_cache.FunctionCacheKey( - MockShape(1, 2, 3), MockEmptyCaptureSnapshot(), ctx), True), - "a") + make_single_param_type(MockShape(1, 2, 3)), + MockEmptyCaptureSnapshot(), ctx), True), "a") def testMostSpecificFunctionCacheKeyIsOrderAgnostic(self): ctx = function_cache.FunctionContext(0) - keys = [(function_cache.FunctionCacheKey(MockShape(1, 1, 1), - MockEmptyCaptureSnapshot(), - ctx), "a"), - (function_cache.FunctionCacheKey(MockShape(1, None, 1), - MockEmptyCaptureSnapshot(), - ctx), "b"), - (function_cache.FunctionCacheKey(MockShape(None, None, 1), - MockEmptyCaptureSnapshot(), - ctx), "c"), - (function_cache.FunctionCacheKey(MockShape(None, None, None), - MockEmptyCaptureSnapshot(), - ctx), "d")] + keys = [(function_cache.FunctionCacheKey( + make_single_param_type(MockShape(1, 1, 1)), MockEmptyCaptureSnapshot(), + ctx), "a"), + (function_cache.FunctionCacheKey( + make_single_param_type(MockShape(1, None, 1)), + MockEmptyCaptureSnapshot(), ctx), "b"), + (function_cache.FunctionCacheKey( + make_single_param_type(MockShape(None, None, 1)), + MockEmptyCaptureSnapshot(), ctx), "c"), + (function_cache.FunctionCacheKey( + make_single_param_type(MockShape(None, None, None)), + MockEmptyCaptureSnapshot(), ctx), "d")] for permutation in itertools.permutations(keys): cache = function_cache.FunctionCache() @@ -277,28 +295,24 @@ class FunctionCacheTest(test.TestCase): self.assertEqual( cache.lookup( - function_cache.FunctionCacheKey(MockShape(1, 1, 1), - MockEmptyCaptureSnapshot(), - ctx), True), - "a") + function_cache.FunctionCacheKey( + make_single_param_type(MockShape(1, 1, 1)), + MockEmptyCaptureSnapshot(), ctx), True), "a") self.assertEqual( cache.lookup( - function_cache.FunctionCacheKey(MockShape(1, 2, 1), - MockEmptyCaptureSnapshot(), - ctx), True), - "b") + function_cache.FunctionCacheKey( + make_single_param_type(MockShape(1, 2, 1)), + MockEmptyCaptureSnapshot(), ctx), True), "b") self.assertEqual( cache.lookup( - function_cache.FunctionCacheKey(MockShape(2, 2, 1), - MockEmptyCaptureSnapshot(), - ctx), True), - "c") + function_cache.FunctionCacheKey( + make_single_param_type(MockShape(2, 2, 1)), + MockEmptyCaptureSnapshot(), ctx), True), "c") self.assertEqual( cache.lookup( - function_cache.FunctionCacheKey(MockShape(2, 2, 2), - MockEmptyCaptureSnapshot(), - ctx), True), - "d") + function_cache.FunctionCacheKey( + make_single_param_type(MockShape(2, 2, 2)), + MockEmptyCaptureSnapshot(), ctx), True), "d") def testWeakRefDeletionAlsoDeletesConcreteFunction(self): if not function_cache.DELETE_WITH_WEAKREF: @@ -447,17 +461,19 @@ class FunctionCacheBenchmark(test.Benchmark): for key in keys: cache.add(*key, "testing") cache.add( - function_cache.FunctionCacheKey(MockSubtypeOf2(2), - MockEmptyCaptureSnapshot(), None), + function_cache.FunctionCacheKey( + make_single_param_type(MockSubtypeOf2(2)), + MockEmptyCaptureSnapshot(), None), trace_type.WeakrefDeletionObserver(), "testing") - cache.lookup(function_cache.FunctionCacheKey(MockSubtypeOf2(3), - MockEmptyCaptureSnapshot(), - None), True) + cache.lookup( + function_cache.FunctionCacheKey( + make_single_param_type(MockSubtypeOf2(3)), + MockEmptyCaptureSnapshot(), None), True) iterations = 10000 - lookup_key = function_cache.FunctionCacheKey(MockSubtypeOf2(2), - MockEmptyCaptureSnapshot(), - None) + lookup_key = function_cache.FunctionCacheKey( + make_single_param_type(MockSubtypeOf2(2)), MockEmptyCaptureSnapshot(), + None) subtyping_time = timeit.timeit( lambda: cache.lookup(lookup_key, True), number=iterations) @@ -490,14 +506,15 @@ class FunctionCacheBenchmark(test.Benchmark): for key in keys: cache.add(*key, "testing") cache.add( - function_cache.FunctionCacheKey(MockSubtypeOf2(3), - MockEmptyCaptureSnapshot(), None), + function_cache.FunctionCacheKey( + make_single_param_type(MockSubtypeOf2(3)), + MockEmptyCaptureSnapshot(), None), trace_type.WeakrefDeletionObserver(), "testing") iterations = 10000 - lookup_key = function_cache.FunctionCacheKey(MockSubtypeOf2(2), - MockEmptyCaptureSnapshot(), - None) + lookup_key = function_cache.FunctionCacheKey( + make_single_param_type(MockSubtypeOf2(2)), MockEmptyCaptureSnapshot(), + None) subtyping_time = sum( timeit.repeat( stmt=lambda: cache.lookup(lookup_key, True), @@ -539,9 +556,7 @@ class CaptureSnapshotTest(test.TestCase): "b": MockIntGenericType(1), "c": MockIntGenericType(2) }) - self.snapshot_e = snapshot_type({ - "d": MockIntGenericType(1) - }) + self.snapshot_e = snapshot_type({"d": MockIntGenericType(1)}) def testCaptureSnapshotSubtype(self): self.assertFalse(self.snapshot_a.is_subtype_of(self.snapshot_b)) diff --git a/tensorflow/python/eager/polymorphic_function/function_context.py b/tensorflow/python/eager/polymorphic_function/function_context.py index a3a8a1cc2ad..7189c362ca7 100644 --- a/tensorflow/python/eager/polymorphic_function/function_context.py +++ b/tensorflow/python/eager/polymorphic_function/function_context.py @@ -17,6 +17,7 @@ from typing import Any, NamedTuple, Tuple from tensorflow.core.function import trace_type +from tensorflow.core.function.function_type import function_type from tensorflow.core.function.polymorphism import function_cache from tensorflow.python.eager import context from tensorflow.python.framework import device as pydev @@ -135,7 +136,15 @@ def make_cache_key( captures_signature = function_cache.CaptureSnapshot( captures_dict_tracetype.mapping) + # TODO(fmuham): Use the actual FunctionType + dummy_function_type = function_type.FunctionType([ + function_type.Parameter("args_kwargs", + function_type.Parameter.POSITIONAL_ONLY, + False, + args_signature) + ]) + return function_cache.FunctionCacheKey( - args_signature, + dummy_function_type, captures_signature, make_function_context()), signature_context.deletion_observer -- GitLab