提交 8d8abfb5 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

TypeDispatchTable uses FunctionType

PiperOrigin-RevId: 481275791
上级 2cffb704
......@@ -13,7 +13,7 @@ pytype_strict_library(
srcs_version = "PY3",
visibility = ["//tensorflow:internal"],
deps = [
":function_type",
"//tensorflow/python/types",
],
)
......@@ -22,7 +22,6 @@ py_strict_test(
srcs = ["type_dispatch_test.py"],
python_version = "PY3",
deps = [
":function_type",
":type_dispatch",
"//tensorflow/python/platform:client_testlib",
"//tensorflow/python/types",
......@@ -56,6 +55,7 @@ py_strict_test(
"//tensorflow/core/function/polymorphism:function_type",
"//tensorflow/core/function/trace_type",
"//tensorflow/python:array_ops",
"//tensorflow/python/eager/polymorphic_function:function_context",
"//tensorflow/python/platform:client_testlib",
"//tensorflow/python/types",
],
......
......@@ -15,11 +15,12 @@
"""Cache to manage concrete functions and their signatures."""
import collections
from typing import Any, NamedTuple, Optional
from typing import Any, NamedTuple, Optional, Sequence
from tensorflow.core.function import trace_type
from tensorflow.core.function.polymorphism import function_type as function_type_lib
from tensorflow.core.function.polymorphism import type_dispatch
from tensorflow.python.types import trace
# TODO(b/182990542): Enable and remove flag when stable.
DELETE_WITH_WEAKREF = False
......@@ -30,77 +31,127 @@ class FunctionContext(NamedTuple):
context: Any
# TODO(fmuham): Remove inheritance from TraceType.
class FunctionCacheKey(trace.TraceType):
"""The unique key associated with a concrete function.
Attributes:
function_type: A FunctionType corresponding to the function arguments.
call_context: The FunctionContext for when the function was called.
"""
def __init__(self, function_type: function_type_lib.FunctionType,
call_context: FunctionContext):
self.function_type = function_type
self.call_context = call_context
def is_subtype_of(self, other: trace.TraceType) -> bool:
if not isinstance(other, FunctionCacheKey):
return False
if self.call_context != other.call_context:
return False
return self.function_type.is_supertype_of(other.function_type)
def most_specific_common_supertype(
self, others: Sequence[trace.TraceType]) -> Optional["FunctionCacheKey"]:
if not all(
isinstance(other, FunctionCacheKey) and
self.call_context == other.call_context for other in others):
return None
function_type_common = self.function_type.most_specific_common_subtype(
[other.function_type for other in others])
if function_type_common is None:
return None
return FunctionCacheKey(function_type_common, self.call_context)
def _placeholder_value(self) -> Any:
"""Value used for tracing a function signature with this TraceType."""
return self.function_type.placeholder_arguments().args[0]
def __hash__(self) -> int:
return hash((self.call_context, self.function_type))
def __eq__(self, other) -> bool:
if not isinstance(other, trace.TraceType):
return NotImplemented
if not isinstance(other, FunctionCacheKey):
return False
return (self.call_context == other.call_context and
self.function_type == other.function_type)
def __repr__(self) -> str:
return (f"{type(self).__name__}(function_type={repr(self.function_type)},"
f" call_context={repr(self.call_context)})")
# TODO(fmuham): Rename to FunctionLibrary.
class FunctionCache:
"""A container for managing concrete functions."""
__slots__ = ["_primary", "_dispatch_dict", "_garbage_collectors"]
__slots__ = ["_primary", "_dispatch_table", "_garbage_collectors"]
def __init__(self):
# Maps (FunctionContext, FunctionType) to a concrete function.
# The primary cache, mapping FunctionCacheKey to a concrete function.
self._primary = collections.OrderedDict()
# Maps FunctionContext to a TypeDispatchTable containing FunctionTypes of
# that particular context.
self._dispatch_dict = {}
# Maps a FunctionCacheKey K to a FunctionCacheKey V such that it is safe
# to dispatch K to the concrete function of V that exists in _primary.
# Used to lookup posible concrete functions when K is not in _primary.
self._dispatch_table = type_dispatch.TypeDispatchTable()
def lookup(self, context: FunctionContext,
function_type: function_type_lib.FunctionType) -> Optional[Any]:
"""Looks up a concrete function based on the context and type."""
if context in self._dispatch_dict:
dispatch_type = self._dispatch_dict[context].dispatch(function_type)
if dispatch_type:
return self._primary[(context, dispatch_type)]
# Note: Instead of returning any viable function, we can return the most
# specfic one by maintaining trees of traces where children are more specific
# traces of their parents.
def lookup(self, key: FunctionCacheKey, use_function_subtyping: bool):
"""Looks up a concrete function based on the key."""
if not use_function_subtyping:
return self._primary.get(key, None)
dispatch_key = self._dispatch_table.dispatch(key)
if dispatch_key is not None:
return self._primary[dispatch_key]
return None
def delete(self, context: FunctionContext,
function_type: function_type_lib.FunctionType) -> bool:
"""Deletes a concrete function given the context and type."""
if (context, function_type) not in self._primary:
def delete(self, key: FunctionCacheKey):
"""Deletes a concrete function given the key it was added with."""
if key not in self._primary:
return False
del self._primary[(context, function_type)]
self._dispatch_dict[context].delete(function_type)
del self._primary[key]
self._dispatch_table.delete(key)
return True
def add(self, context: FunctionContext,
function_type: function_type_lib.FunctionType,
deletion_observer: trace_type.WeakrefDeletionObserver,
concrete_fn: Any):
def add(self, key: FunctionCacheKey,
deletion_observer: trace_type.WeakrefDeletionObserver, concrete: ...):
"""Adds a new concrete function alongside its key.
Args:
context: A FunctionContext representing the current context.
function_type: A FunctionType representing concrete_fn signature.
deletion_observer: A WeakrefDeletionObserver for the concrete_fn validity.
concrete_fn: The concrete function to be added to the cache.
key: A FunctionCacheKey object corresponding to the provided `concrete`.
deletion_observer: A WeakrefDeletionObserver object for the `key`.
concrete: The concrete function to be added to the cache.
"""
self._primary[(context, function_type)] = concrete_fn
if context not in self._dispatch_dict:
self._dispatch_dict[context] = type_dispatch.TypeDispatchTable()
self._dispatch_dict[context].add_target(function_type)
listener_fn = (lambda: self.delete(context, function_type)
) if DELETE_WITH_WEAKREF else lambda: None
deletion_observer.add_listener(listener_fn)
def generalize(
self, context: FunctionContext,
function_type: function_type_lib.FunctionType
) -> function_type_lib.FunctionType:
"""Try to generalize a FunctionType within a FunctionContext."""
if context in self._dispatch_dict:
return self._dispatch_dict[context].try_generalizing_function_type(
function_type)
else:
return function_type
self._primary[key] = concrete
self._dispatch_table.add_target(key)
deletion_observer.add_listener(
lambda: self.delete(key) if DELETE_WITH_WEAKREF else None)
def generalize(self, key: FunctionCacheKey) -> FunctionCacheKey:
return self._dispatch_table.try_generalizing_trace_type(key) # pylint: disable=protected-access
# TODO(b/205971333): Remove this function.
def clear(self):
"""Removes all concrete functions from the cache."""
self._primary.clear()
self._dispatch_dict.clear()
self._dispatch_table.clear()
def values(self):
"""Returns a list of all `ConcreteFunction` instances held by this cache."""
......
......@@ -17,7 +17,7 @@
import collections
from typing import Optional, Iterable
from tensorflow.core.function.polymorphism import function_type
from tensorflow.python.types import trace
# The maximum number of dispatch lookups to cache.
_MAX_DISPATCH_CACHE = 1024
......@@ -28,9 +28,9 @@ class TypeDispatchTable:
A type dispatch table is a list, L, of target types. Given a request type, R,
the table selects a target type, T, according to the following dispatch rules:
1. R == T or R is supertype of T (functions are contravariant on args)
2. There does not exist O in L such that R is supertype of O and O is a
supertype of T (in other words, T is the closest to R, within list L).
1. R == T or R is subtype of T
2. There does not exist O in L such that R is subtype of O and O is a
subtype of T (in other words, T is the closest to R, within list L).
3. If the above two rules are satisfied by multiple targets, the earliest
inserted one is chosen.
"""
......@@ -46,19 +46,19 @@ class TypeDispatchTable:
# Does not contain exact matches, i.e, if cache[a] is b then a is not b.
self._dispatch_cache = collections.OrderedDict()
def add_target(self, target: function_type.FunctionType) -> None:
def add_target(self, target: trace.TraceType) -> None:
"""Adds a new target type."""
self._dispatch_table[target] = None
for request in self._dispatch_cache:
if target.is_supertype_of(self._dispatch_cache[request]):
if target.is_subtype_of(self._dispatch_cache[request]):
self._dispatch_cache[request] = target
@property
def targets(self) -> Iterable[function_type.FunctionType]:
def targets(self) -> Iterable[trace.TraceType]:
"""Returns an iterable to all targets in the table."""
return self._dispatch_table.keys()
def delete(self, target: function_type.FunctionType) -> None:
def delete(self, target: trace.TraceType) -> None:
"""Deletes a target in the table if it exists."""
if target in self._dispatch_table:
del self._dispatch_table[target]
......@@ -72,10 +72,8 @@ class TypeDispatchTable:
self._dispatch_table.clear()
self._dispatch_cache.clear()
def dispatch(
self, request: function_type.FunctionType
) -> Optional[function_type.FunctionType]:
"""Returns the most specific supertype target if it exists in the table."""
def dispatch(self, request: trace.TraceType) -> Optional[trace.TraceType]:
"""Returns the deepest subtype target if it exists in the table."""
# For known exact matches.
if request in self._dispatch_table:
return request
......@@ -88,15 +86,15 @@ class TypeDispatchTable:
self._dispatch_cache[request] = result
return result
most_specific_supertype = None
most_specific_subtype = None
for other in self._dispatch_table:
if request.is_supertype_of(other):
if most_specific_supertype is None or other.is_supertype_of(
most_specific_supertype):
most_specific_supertype = other
if request.is_subtype_of(other):
if most_specific_subtype is None or other.is_subtype_of(
most_specific_subtype):
most_specific_subtype = other
self._cache_dispatch(request, most_specific_supertype)
return most_specific_supertype
self._cache_dispatch(request, most_specific_subtype)
return most_specific_subtype
def _cache_dispatch(self, request, target):
"""Caches the dispatch lookup result for a target."""
......@@ -106,26 +104,26 @@ class TypeDispatchTable:
self._dispatch_cache.popitem(last=False)
self._dispatch_cache[request] = target
def try_generalizing_function_type(
self, target: function_type.FunctionType) -> function_type.FunctionType:
def try_generalizing_trace_type(self,
target: trace.TraceType) -> trace.TraceType:
"""Returns a generalized subtype of the one given.
This heuristic aims to reduce the number of future traces by computing a
type that represents more general function inputs.
The original "experimental_relax_shapes" heuristic identified a known type
which shared a common subtype with the current unknown type and then
traced with that common subtype. However, the notion of "common subtype"
was only limited to shapes. This heuristic extends that to FunctionType.
which shared a common supertype with the current unknown type and then
traced with that common supertype. However, the notion of "common supertype"
was only limited to shapes. This heuristic extends that to TraceType.
Returns `target` if a generalized subtype can not be found.
Returns `target` if a common supertype can not be found.
Args:
target: The FunctionType to generalize
target: The TraceType to generalize
"""
relaxed = target
for other in self._dispatch_table:
subtype = relaxed.most_specific_common_subtype([other])
if subtype is not None:
relaxed = subtype
supertype = relaxed.most_specific_common_supertype([other])
if supertype is not None:
relaxed = supertype
return relaxed
......@@ -16,7 +16,6 @@
from typing import Optional
from tensorflow.core.function.polymorphism import function_type
from tensorflow.core.function.polymorphism import type_dispatch
from tensorflow.python.platform import test
from tensorflow.python.types import trace
......@@ -27,7 +26,7 @@ class MockShape(trace.TraceType):
def __init__(self, *shape: Optional[int]):
self.shape = shape
def is_subtype_of(self, other: "MockShape") -> bool:
def is_subtype_of(self, other: "MockShape") ->bool:
if len(self.shape) != len(other.shape):
return False
......@@ -57,214 +56,186 @@ class MockShape(trace.TraceType):
return self.shape == other.shape
def make_shape_function_type(*shape):
return function_type.FunctionType([
function_type.Parameter("x", function_type.Parameter.POSITIONAL_ONLY,
False, MockShape(*shape))
])
class TypeDispatchTableTest(test.TestCase):
def testVertical(self):
table = type_dispatch.TypeDispatchTable()
table.add_target(make_shape_function_type(None, None, None))
table.add_target(make_shape_function_type(None, None, 1))
table.add_target(make_shape_function_type(None, 1, 1))
table.add_target(make_shape_function_type(1, 1, 1))
table.add_target(MockShape(None, None, None))
table.add_target(MockShape(None, None, 1))
table.add_target(MockShape(None, 1, 1))
table.add_target(MockShape(1, 1, 1))
self.assertEqual(
list(table.targets), [
make_shape_function_type(None, None, None),
make_shape_function_type(None, None, 1),
make_shape_function_type(None, 1, 1),
make_shape_function_type(1, 1, 1)
MockShape(None, None, None),
MockShape(None, None, 1),
MockShape(None, 1, 1),
MockShape(1, 1, 1)
])
def testHorizontal(self):
table = type_dispatch.TypeDispatchTable()
table.add_target(make_shape_function_type(1,))
table.add_target(make_shape_function_type(1, 2))
table.add_target(make_shape_function_type(1, 2, 3))
table.add_target(MockShape(1,))
table.add_target(MockShape(1, 2))
table.add_target(MockShape(1, 2, 3))
self.assertEqual(
list(table.targets), [
make_shape_function_type(1,),
make_shape_function_type(1, 2),
make_shape_function_type(1, 2, 3)
MockShape(1,),
MockShape(1, 2),
MockShape(1, 2, 3)
])
def testDuplicateNodes(self):
table = type_dispatch.TypeDispatchTable()
table.add_target(make_shape_function_type(None, None))
table.add_target(make_shape_function_type(1, None))
table.add_target(make_shape_function_type(None, 2))
table.add_target(make_shape_function_type(None, None))
table.add_target(MockShape(None, None))
table.add_target(MockShape(1, None))
table.add_target(MockShape(None, 2))
table.add_target(MockShape(None, None))
self.assertEqual(
list(table.targets), [
make_shape_function_type(None, None),
make_shape_function_type(1, None),
make_shape_function_type(None, 2)
MockShape(None, None),
MockShape(1, None),
MockShape(None, 2)
])
def testDeletion(self):
table = type_dispatch.TypeDispatchTable()
table.add_target(make_shape_function_type(None, None))
table.add_target(make_shape_function_type(None, 1))
table.add_target(make_shape_function_type(None, 2))
table.add_target(MockShape(None, None))
table.add_target(MockShape(None, 1))
table.add_target(MockShape(None, 2))
self.assertEqual(
list(table.targets), [
make_shape_function_type(None, None),
make_shape_function_type(None, 1),
make_shape_function_type(None, 2)
MockShape(None, None),
MockShape(None, 1),
MockShape(None, 2)
])
table.delete(make_shape_function_type(None, 2)) # Should remove the target
table.delete(MockShape(None, 2)) # Should remove the target
self.assertEqual(
list(table.targets), [
make_shape_function_type(None, None),
make_shape_function_type(None, 1),
MockShape(None, None),
MockShape(None, 1),
])
table.delete(make_shape_function_type(None, 2)) # Should have no effect
table.delete(MockShape(None, 2)) # Should have no effect
self.assertEqual(
list(table.targets), [
make_shape_function_type(None, None),
make_shape_function_type(None, 1),
MockShape(None, None),
MockShape(None, 1),
])
def testContains(self):
table = type_dispatch.TypeDispatchTable()
table.add_target(make_shape_function_type(None, None, None))
table.add_target(make_shape_function_type(None, 1))
table.add_target(make_shape_function_type(1, 1))
table.add_target(make_shape_function_type(None, 2, 1))
table.add_target(MockShape(None, None, None))
table.add_target(MockShape(None, 1))
table.add_target(MockShape(1, 1))
table.add_target(MockShape(None, 2, 1))
self.assertIn(make_shape_function_type(None, None, None), table.targets)
self.assertIn(make_shape_function_type(None, 1), table.targets)
self.assertIn(make_shape_function_type(1, 1), table.targets)
self.assertIn(make_shape_function_type(None, 2, 1), table.targets)
self.assertIn(MockShape(None, None, None), table.targets)
self.assertIn(MockShape(None, 1), table.targets)
self.assertIn(MockShape(1, 1), table.targets)
self.assertIn(MockShape(None, 2, 1), table.targets)
self.assertNotIn(make_shape_function_type(None, None, 1), table.targets)
self.assertNotIn(make_shape_function_type(1, None), table.targets)
self.assertNotIn(make_shape_function_type(1, 2), table.targets)
self.assertNotIn(make_shape_function_type(None, 2, None), table.targets)
self.assertNotIn(MockShape(None, None, 1), table.targets)
self.assertNotIn(MockShape(1, None), table.targets)
self.assertNotIn(MockShape(1, 2), table.targets)
self.assertNotIn(MockShape(None, 2, None), table.targets)
def testDispatchExactMatches(self):
table = type_dispatch.TypeDispatchTable()
table.add_target(make_shape_function_type(None, None, None))
table.add_target(make_shape_function_type(None, 1, None))
table.add_target(make_shape_function_type(None, 1, 2))
table.add_target(make_shape_function_type(None, 2, 2))
table.add_target(MockShape(None, None, None))
table.add_target(MockShape(None, 1, None))
table.add_target(MockShape(None, 1, 2))
table.add_target(MockShape(None, 2, 2))
self.assertEqual(
table.dispatch(make_shape_function_type(None, 1, 2)),
make_shape_function_type(None, 1, 2))
table.dispatch(MockShape(None, 1, 2)), MockShape(None, 1, 2))
self.assertEqual(
table.dispatch(make_shape_function_type(None, 1, None)),
make_shape_function_type(None, 1, None))
table.dispatch(MockShape(None, 1, None)), MockShape(None, 1, None))
self.assertEqual(
table.dispatch(make_shape_function_type(None, None, None)),
make_shape_function_type(None, None, None))
table.dispatch(MockShape(None, None, None)),
MockShape(None, None, None))
self.assertEqual(
table.dispatch(make_shape_function_type(None, 2, 2)),
make_shape_function_type(None, 2, 2))
table.dispatch(MockShape(None, 2, 2)), MockShape(None, 2, 2))
def testDispatchMoreSpecific(self):
table = type_dispatch.TypeDispatchTable()
table.add_target(make_shape_function_type(None, None, None))
table.add_target(make_shape_function_type(None, 1, None))
table.add_target(make_shape_function_type(None, 1, 2))
table.add_target(make_shape_function_type(None, 2, 2))
table.add_target(MockShape(None, None, None))
table.add_target(MockShape(None, 1, None))
table.add_target(MockShape(None, 1, 2))
table.add_target(MockShape(None, 2, 2))
self.assertEqual(table.dispatch(MockShape(1, 1, 2)), MockShape(None, 1, 2))
self.assertEqual(
table.dispatch(make_shape_function_type(1, 1, 2)),
make_shape_function_type(None, 1, 2))
self.assertEqual(
table.dispatch(make_shape_function_type(1, 1, 3)),
make_shape_function_type(None, 1, None))
table.dispatch(MockShape(1, 1, 3)), MockShape(None, 1, None))
self.assertEqual(
table.dispatch(make_shape_function_type(1, 3, 3)),
make_shape_function_type(None, None, None))
self.assertEqual(
table.dispatch(make_shape_function_type(1, 2, 2)),
make_shape_function_type(None, 2, 2))
table.dispatch(MockShape(1, 3, 3)), MockShape(None, None, None))
self.assertEqual(table.dispatch(MockShape(1, 2, 2)), MockShape(None, 2, 2))
def testDispatchNoMatches(self):
table = type_dispatch.TypeDispatchTable()
table.add_target(make_shape_function_type(None, 1, None))
table.add_target(make_shape_function_type(None, 1, 2))
table.add_target(make_shape_function_type(None, 2, 2))
table.add_target(MockShape(None, 1, None))
table.add_target(MockShape(None, 1, 2))
table.add_target(MockShape(None, 2, 2))
self.assertIsNone(table.dispatch(make_shape_function_type(1, 2)))
self.assertIsNone(table.dispatch(make_shape_function_type(1, 2, 3)))
self.assertIsNone(table.dispatch(make_shape_function_type(1, 2, 3, 4)))
self.assertIsNone(table.dispatch(MockShape(1, 2)))
self.assertIsNone(table.dispatch(MockShape(1, 2, 3)))
self.assertIsNone(table.dispatch(MockShape(1, 2, 3, 4)))
def testDispatchCachedAddUpdates(self):
table = type_dispatch.TypeDispatchTable()
table.add_target(make_shape_function_type(None, None, None))
table.add_target(MockShape(None, None, None))
self.assertEqual(
table.dispatch(make_shape_function_type(1, 1, 2)),
make_shape_function_type(None, None, None))
table.dispatch(MockShape(1, 1, 2)), MockShape(None, None, None))
table.add_target(make_shape_function_type(None, 1, None))
table.add_target(MockShape(None, 1, None))
self.assertEqual(
table.dispatch(make_shape_function_type(1, 1, 2)),
make_shape_function_type(None, 1, None))
table.dispatch(MockShape(1, 1, 2)), MockShape(None, 1, None))
table.add_target(make_shape_function_type(None, 1, 2))
self.assertEqual(
table.dispatch(make_shape_function_type(1, 1, 2)),
make_shape_function_type(None, 1, 2))
table.add_target(MockShape(None, 1, 2))
self.assertEqual(table.dispatch(MockShape(1, 1, 2)), MockShape(None, 1, 2))
table.add_target(make_shape_function_type(1, 1, 2))
self.assertEqual(
table.dispatch(make_shape_function_type(1, 1, 2)),
make_shape_function_type(1, 1, 2))
table.add_target(MockShape(1, 1, 2))
self.assertEqual(table.dispatch(MockShape(1, 1, 2)), MockShape(1, 1, 2))
def testDispatchCachedDeleteUpdates(self):
table = type_dispatch.TypeDispatchTable()
table.add_target(make_shape_function_type(None, None, None))
table.add_target(make_shape_function_type(None, 1, None))
table.add_target(make_shape_function_type(None, 1, 2))
table.add_target(make_shape_function_type(1, 1, 2))
table.add_target(MockShape(None, None, None))
table.add_target(MockShape(None, 1, None))
table.add_target(MockShape(None, 1, 2))
table.add_target(MockShape(1, 1, 2))
self.assertEqual(
table.dispatch(make_shape_function_type(1, 1, 2)),
make_shape_function_type(1, 1, 2))
self.assertEqual(table.dispatch(MockShape(1, 1, 2)), MockShape(1, 1, 2))
table.delete(make_shape_function_type(1, 1, 2))
self.assertEqual(
table.dispatch(make_shape_function_type(1, 1, 2)),
make_shape_function_type(None, 1, 2))
table.delete(MockShape(1, 1, 2))
self.assertEqual(table.dispatch(MockShape(1, 1, 2)), MockShape(None, 1, 2))
table.delete(make_shape_function_type(None, 1, 2))
table.delete(MockShape(None, 1, 2))
self.assertEqual(
table.dispatch(make_shape_function_type(1, 1, 2)),
make_shape_function_type(None, 1, None))
table.dispatch(MockShape(1, 1, 2)), MockShape(None, 1, None))
table.delete(make_shape_function_type(None, 1, None))
table.delete(MockShape(None, 1, None))
self.assertEqual(
table.dispatch(make_shape_function_type(1, 1, 2)),
make_shape_function_type(None, None, None))
table.dispatch(MockShape(1, 1, 2)), MockShape(None, None, None))
def testDispatchCacheOrderingDeterminism(self):
table_1 = type_dispatch.TypeDispatchTable()
table_1.add_target(make_shape_function_type(1, None, None))
table_1.add_target(make_shape_function_type(None, 2, None))
table_1.add_target(make_shape_function_type(None, None, 3))
table_1.add_target(MockShape(1, None, None))
table_1.add_target(MockShape(None, 2, None))
table_1.add_target(MockShape(None, None, 3))
table_2 = type_dispatch.TypeDispatchTable()
table_2.add_target(make_shape_function_type(None, 2, None))
table_2.add_target(make_shape_function_type(1, None, None))
table_2.add_target(make_shape_function_type(None, None, 3))
table_2.add_target(MockShape(None, 2, None))
table_2.add_target(MockShape(1, None, None))
table_2.add_target(MockShape(None, None, 3))
table_3 = type_dispatch.TypeDispatchTable()
table_3.add_target(make_shape_function_type(None, None, 3))
table_3.add_target(make_shape_function_type(1, None, None))
table_3.add_target(make_shape_function_type(None, 2, None))
table_3.add_target(MockShape(None, None, 3))
table_3.add_target(MockShape(1, None, None))
table_3.add_target(MockShape(None, 2, None))
# table_1, table_2, table_3 have the same targets
self.assertEqual(set(table_1.targets), set(table_2.targets))
......@@ -272,43 +243,36 @@ class TypeDispatchTableTest(test.TestCase):
# But they dispatch to the first target they find which does not have any
# more specific viable target.
shape = make_shape_function_type(1, 2, 3)
self.assertEqual(
table_1.dispatch(shape), make_shape_function_type(1, None, None))
self.assertEqual(
table_2.dispatch(shape), make_shape_function_type(None, 2, None))
self.assertEqual(
table_3.dispatch(shape), make_shape_function_type(None, None, 3))
shape = MockShape(1, 2, 3)
self.assertEqual(table_1.dispatch(shape), MockShape(1, None, None))
self.assertEqual(table_2.dispatch(shape), MockShape(None, 2, None))
self.assertEqual(table_3.dispatch(shape), MockShape(None, None, 3))
def testGeneralizedExisting(self):
table = type_dispatch.TypeDispatchTable()
table.add_target(make_shape_function_type(None, None, None))
table.add_target(make_shape_function_type(None, 1, None))
table.add_target(make_shape_function_type(None, 1, 2))
table.add_target(MockShape(None, None, None))
table.add_target(MockShape(None, 1, None))
table.add_target(MockShape(None, 1, 2))
self.assertEqual(
table.try_generalizing_function_type(
make_shape_function_type(None, 1, 3)),
make_shape_function_type(None, None, None))
table.try_generalizing_trace_type(MockShape(None, 1, 3)),
MockShape(None, None, None))
def testGeneralizedNovel(self):
table = type_dispatch.TypeDispatchTable()
table.add_target(make_shape_function_type(None, 1, None))
table.add_target(make_shape_function_type(None, 1, 2))
table.add_target(MockShape(None, 1, None))
table.add_target(MockShape(None, 1, 2))
self.assertEqual(
table.try_generalizing_function_type(
make_shape_function_type(None, 2, 3)),
make_shape_function_type(None, None, None))
table.try_generalizing_trace_type(MockShape(None, 2, 3)),
MockShape(None, None, None))
def testGeneralizedUnknown(self):
table = type_dispatch.TypeDispatchTable()
table.add_target(make_shape_function_type(None, 1))
table.add_target(make_shape_function_type(None, 2))
table.add_target(make_shape_function_type(None, 3))
table.add_target(MockShape(None, 1))
table.add_target(MockShape(None, 2))
table.add_target(MockShape(None, 3))
self.assertEqual(
table.try_generalizing_function_type(
make_shape_function_type(None, 4, 3)),
make_shape_function_type(None, 4, 3))
table.try_generalizing_trace_type(MockShape(None, 4, 3)),
MockShape(None, 4, 3))
if __name__ == "__main__":
test.main()
......@@ -125,21 +125,24 @@ def _enclosing_xla_context():
def make_cache_key(
args: Any,
captures: Any = None,
) -> Tuple[function_cache.FunctionContext, function_type.FunctionType,
trace_type.WeakrefDeletionObserver]:
) -> Tuple[function_cache.FunctionCacheKey, trace_type.WeakrefDeletionObserver]:
"""Computes the cache key given the function arguments."""
if captures is None:
captures = dict()
signature_context = trace_type.InternalTracingContext()
args_signature = trace_type.from_value(args, signature_context)
captures_dict_tracetype = trace_type.from_value(captures, signature_context)
args_signature = trace_type.from_value(
args, signature_context)
captures_dict_tracetype = trace_type.from_value(
captures, signature_context)
# TODO(fmuham): Use the actual FunctionType
dummy_function_type = function_type.FunctionType([
function_type.Parameter("args_kwargs",
function_type.Parameter.POSITIONAL_ONLY, False,
function_type.Parameter.POSITIONAL_ONLY,
False,
args_signature)
], collections.OrderedDict(captures_dict_tracetype.mapping))
return (make_function_context(), dummy_function_type,
signature_context.deletion_observer)
return function_cache.FunctionCacheKey(
dummy_function_type,
make_function_context()), signature_context.deletion_observer
......@@ -95,8 +95,8 @@ class TracingCompiler:
capture_by_value: Experimental. Whether to capture resource variables by
value or reference. If None, will inherit from a parent context or
default to False.
jit_compile: Force-compile the function with XLA, cf. tf.function doc on
jit_compile.
jit_compile: Force-compile the function with XLA, cf.
tf.function doc on jit_compile.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
......@@ -105,7 +105,9 @@ class TracingCompiler:
self._python_function = python_function
pure_function = attributes and monomorphic_function.IMPLEMENTS_ATTRIBUTE_NAME in attributes
self._function_spec = function_spec.FunctionSpec.from_function_and_signature(
python_function, input_signature, is_pure=pure_function)
python_function,
input_signature,
is_pure=pure_function)
self._name = name
self._autograph = autograph
self._autograph_options = autograph_options
......@@ -331,10 +333,9 @@ class TracingCompiler:
# cache_key_deletion_observer is useless here. It's based on all captures.
# A new cache key will be built later when saving ConcreteFunction because
# only active captures should be saved.
lookup_func_context, lookup_func_type, _ = function_context.make_cache_key(
(args, kwargs), captures)
concrete_function = self._function_cache.lookup(lookup_func_context,
lookup_func_type)
lookup_func_key, _ = function_context.make_cache_key((args, kwargs),
captures)
concrete_function = self._function_cache.lookup(lookup_func_key, True)
if concrete_function is not None:
return concrete_function, filtered_flat_args
......@@ -342,8 +343,7 @@ class TracingCompiler:
with trace.Trace("tf.function-graph_building"):
logging.vlog(1,
"Creating new FuncGraph for Python function %r (key: %r)",
self._python_function, lookup_func_context,
lookup_func_type)
self._python_function, lookup_func_key)
logging.vlog(2, "Python function signature [args: %s] [kwargs: %s]",
args, kwargs)
ag_status = (
......@@ -352,10 +352,10 @@ class TracingCompiler:
with ag_ctx.ControlStatusCtx(
status=ag_status, options=self._autograph_options):
if self.input_signature is None and self._reduce_retracing:
general_func_type = self._function_cache.generalize(
lookup_func_context, lookup_func_type)
placeholder_bound_args = general_func_type.placeholder_arguments()
args, kwargs = placeholder_bound_args.args[0]
generalized_func_key = self._function_cache.generalize(
lookup_func_key)
# Only get placeholders for arguments, not captures
args, kwargs = generalized_func_key._placeholder_value() # pylint: disable=protected-access
concrete_function = self._create_concrete_function(args, kwargs)
......@@ -366,10 +366,10 @@ class TracingCompiler:
captures = graph_capture_container.get_snapshot()
# Create a cache_key with args and captures
traced_func_context, traced_func_type, traced_func_deletion_observer = (
traced_func_key, traced_func_deletion_observer = (
function_context.make_cache_key((args, kwargs), captures))
self._function_cache.add(traced_func_context, traced_func_type,
self._function_cache.add(traced_func_key,
traced_func_deletion_observer,
concrete_function)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册