Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wux_labs
Tensorflow
提交
8d8abfb5
T
Tensorflow
项目概览
wux_labs
/
Tensorflow
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8d8abfb5
编写于
10月 14, 2022
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
10月 14, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
TypeDispatchTable uses FunctionType
PiperOrigin-RevId: 481275791
上级
2cffb704
变更
7
展开全部
隐藏空白更改
内联
并排
Showing
7 changed file
with
447 addition
and
385 deletion
+447
-385
tensorflow/core/function/polymorphism/BUILD
tensorflow/core/function/polymorphism/BUILD
+2
-2
tensorflow/core/function/polymorphism/function_cache.py
tensorflow/core/function/polymorphism/function_cache.py
+98
-47
tensorflow/core/function/polymorphism/function_cache_test.py
tensorflow/core/function/polymorphism/function_cache_test.py
+180
-134
tensorflow/core/function/polymorphism/type_dispatch.py
tensorflow/core/function/polymorphism/type_dispatch.py
+27
-29
tensorflow/core/function/polymorphism/type_dispatch_test.py
tensorflow/core/function/polymorphism/type_dispatch_test.py
+115
-151
tensorflow/python/eager/polymorphic_function/function_context.py
...low/python/eager/polymorphic_function/function_context.py
+10
-7
tensorflow/python/eager/polymorphic_function/tracing_compiler.py
...low/python/eager/polymorphic_function/tracing_compiler.py
+15
-15
未找到文件。
tensorflow/core/function/polymorphism/BUILD
浏览文件 @
8d8abfb5
...
...
@@ -13,7 +13,7 @@ pytype_strict_library(
srcs_version
=
"PY3"
,
visibility
=
[
"//tensorflow:internal"
],
deps
=
[
"
:function_type
"
,
"
//tensorflow/python/types
"
,
],
)
...
...
@@ -22,7 +22,6 @@ py_strict_test(
srcs
=
[
"type_dispatch_test.py"
],
python_version
=
"PY3"
,
deps
=
[
":function_type"
,
":type_dispatch"
,
"//tensorflow/python/platform:client_testlib"
,
"//tensorflow/python/types"
,
...
...
@@ -56,6 +55,7 @@ py_strict_test(
"//tensorflow/core/function/polymorphism:function_type"
,
"//tensorflow/core/function/trace_type"
,
"//tensorflow/python:array_ops"
,
"//tensorflow/python/eager/polymorphic_function:function_context"
,
"//tensorflow/python/platform:client_testlib"
,
"//tensorflow/python/types"
,
],
...
...
tensorflow/core/function/polymorphism/function_cache.py
浏览文件 @
8d8abfb5
...
...
@@ -15,11 +15,12 @@
"""Cache to manage concrete functions and their signatures."""
import
collections
from
typing
import
Any
,
NamedTuple
,
Optional
from
typing
import
Any
,
NamedTuple
,
Optional
,
Sequence
from
tensorflow.core.function
import
trace_type
from
tensorflow.core.function.polymorphism
import
function_type
as
function_type_lib
from
tensorflow.core.function.polymorphism
import
type_dispatch
from
tensorflow.python.types
import
trace
# TODO(b/182990542): Enable and remove flag when stable.
DELETE_WITH_WEAKREF
=
False
...
...
@@ -30,77 +31,127 @@ class FunctionContext(NamedTuple):
context
:
Any
# TODO(fmuham): Remove inheritance from TraceType.
class
FunctionCacheKey
(
trace
.
TraceType
):
"""The unique key associated with a concrete function.
Attributes:
function_type: A FunctionType corresponding to the function arguments.
call_context: The FunctionContext for when the function was called.
"""
def
__init__
(
self
,
function_type
:
function_type_lib
.
FunctionType
,
call_context
:
FunctionContext
):
self
.
function_type
=
function_type
self
.
call_context
=
call_context
def
is_subtype_of
(
self
,
other
:
trace
.
TraceType
)
->
bool
:
if
not
isinstance
(
other
,
FunctionCacheKey
):
return
False
if
self
.
call_context
!=
other
.
call_context
:
return
False
return
self
.
function_type
.
is_supertype_of
(
other
.
function_type
)
def
most_specific_common_supertype
(
self
,
others
:
Sequence
[
trace
.
TraceType
])
->
Optional
[
"FunctionCacheKey"
]:
if
not
all
(
isinstance
(
other
,
FunctionCacheKey
)
and
self
.
call_context
==
other
.
call_context
for
other
in
others
):
return
None
function_type_common
=
self
.
function_type
.
most_specific_common_subtype
(
[
other
.
function_type
for
other
in
others
])
if
function_type_common
is
None
:
return
None
return
FunctionCacheKey
(
function_type_common
,
self
.
call_context
)
def
_placeholder_value
(
self
)
->
Any
:
"""Value used for tracing a function signature with this TraceType."""
return
self
.
function_type
.
placeholder_arguments
().
args
[
0
]
def
__hash__
(
self
)
->
int
:
return
hash
((
self
.
call_context
,
self
.
function_type
))
def
__eq__
(
self
,
other
)
->
bool
:
if
not
isinstance
(
other
,
trace
.
TraceType
):
return
NotImplemented
if
not
isinstance
(
other
,
FunctionCacheKey
):
return
False
return
(
self
.
call_context
==
other
.
call_context
and
self
.
function_type
==
other
.
function_type
)
def
__repr__
(
self
)
->
str
:
return
(
f
"
{
type
(
self
).
__name__
}
(function_type=
{
repr
(
self
.
function_type
)
}
,"
f
" call_context=
{
repr
(
self
.
call_context
)
}
)"
)
# TODO(fmuham): Rename to FunctionLibrary.
class
FunctionCache
:
"""A container for managing concrete functions."""
__slots__
=
[
"_primary"
,
"_dispatch_
dict
"
,
"_garbage_collectors"
]
__slots__
=
[
"_primary"
,
"_dispatch_
table
"
,
"_garbage_collectors"
]
def
__init__
(
self
):
#
Maps (FunctionContext, FunctionType)
to a concrete function.
#
The primary cache, mapping FunctionCacheKey
to a concrete function.
self
.
_primary
=
collections
.
OrderedDict
()
# Maps FunctionContext to a TypeDispatchTable containing FunctionTypes of
# that particular context.
self
.
_dispatch_dict
=
{}
# Maps a FunctionCacheKey K to a FunctionCacheKey V such that it is safe
# to dispatch K to the concrete function of V that exists in _primary.
# Used to lookup posible concrete functions when K is not in _primary.
self
.
_dispatch_table
=
type_dispatch
.
TypeDispatchTable
()
def
lookup
(
self
,
context
:
FunctionContext
,
function_type
:
function_type_lib
.
FunctionType
)
->
Optional
[
Any
]:
"""Looks up a concrete function based on the context and type."""
if
context
in
self
.
_dispatch_dict
:
dispatch_type
=
self
.
_dispatch_dict
[
context
].
dispatch
(
function_type
)
if
dispatch_type
:
return
self
.
_primary
[(
context
,
dispatch_type
)]
# Note: Instead of returning any viable function, we can return the most
# specfic one by maintaining trees of traces where children are more specific
# traces of their parents.
def
lookup
(
self
,
key
:
FunctionCacheKey
,
use_function_subtyping
:
bool
):
"""Looks up a concrete function based on the key."""
if
not
use_function_subtyping
:
return
self
.
_primary
.
get
(
key
,
None
)
dispatch_key
=
self
.
_dispatch_table
.
dispatch
(
key
)
if
dispatch_key
is
not
None
:
return
self
.
_primary
[
dispatch_key
]
return
None
def
delete
(
self
,
context
:
FunctionContext
,
function_type
:
function_type_lib
.
FunctionType
)
->
bool
:
"""Deletes a concrete function given the context and type."""
if
(
context
,
function_type
)
not
in
self
.
_primary
:
def
delete
(
self
,
key
:
FunctionCacheKey
):
"""Deletes a concrete function given the key it was added with."""
if
key
not
in
self
.
_primary
:
return
False
del
self
.
_primary
[
(
context
,
function_type
)
]
self
.
_dispatch_
dict
[
context
].
delete
(
function_type
)
del
self
.
_primary
[
key
]
self
.
_dispatch_
table
.
delete
(
key
)
return
True
def
add
(
self
,
context
:
FunctionContext
,
function_type
:
function_type_lib
.
FunctionType
,
deletion_observer
:
trace_type
.
WeakrefDeletionObserver
,
concrete_fn
:
Any
):
def
add
(
self
,
key
:
FunctionCacheKey
,
deletion_observer
:
trace_type
.
WeakrefDeletionObserver
,
concrete
:
...):
"""Adds a new concrete function alongside its key.
Args:
context: A FunctionContext representing the current context.
function_type: A FunctionType representing concrete_fn signature.
deletion_observer: A WeakrefDeletionObserver for the concrete_fn validity.
concrete_fn: The concrete function to be added to the cache.
key: A FunctionCacheKey object corresponding to the provided `concrete`.
deletion_observer: A WeakrefDeletionObserver object for the `key`.
concrete: The concrete function to be added to the cache.
"""
self
.
_primary
[(
context
,
function_type
)]
=
concrete_fn
if
context
not
in
self
.
_dispatch_dict
:
self
.
_dispatch_dict
[
context
]
=
type_dispatch
.
TypeDispatchTable
()
self
.
_dispatch_dict
[
context
].
add_target
(
function_type
)
listener_fn
=
(
lambda
:
self
.
delete
(
context
,
function_type
)
)
if
DELETE_WITH_WEAKREF
else
lambda
:
None
deletion_observer
.
add_listener
(
listener_fn
)
def
generalize
(
self
,
context
:
FunctionContext
,
function_type
:
function_type_lib
.
FunctionType
)
->
function_type_lib
.
FunctionType
:
"""Try to generalize a FunctionType within a FunctionContext."""
if
context
in
self
.
_dispatch_dict
:
return
self
.
_dispatch_dict
[
context
].
try_generalizing_function_type
(
function_type
)
else
:
return
function_type
self
.
_primary
[
key
]
=
concrete
self
.
_dispatch_table
.
add_target
(
key
)
deletion_observer
.
add_listener
(
lambda
:
self
.
delete
(
key
)
if
DELETE_WITH_WEAKREF
else
None
)
def
generalize
(
self
,
key
:
FunctionCacheKey
)
->
FunctionCacheKey
:
return
self
.
_dispatch_table
.
try_generalizing_trace_type
(
key
)
# pylint: disable=protected-access
# TODO(b/205971333): Remove this function.
def
clear
(
self
):
"""Removes all concrete functions from the cache."""
self
.
_primary
.
clear
()
self
.
_dispatch_
dict
.
clear
()
self
.
_dispatch_
table
.
clear
()
def
values
(
self
):
"""Returns a list of all `ConcreteFunction` instances held by this cache."""
...
...
tensorflow/core/function/polymorphism/function_cache_test.py
浏览文件 @
8d8abfb5
此差异已折叠。
点击以展开。
tensorflow/core/function/polymorphism/type_dispatch.py
浏览文件 @
8d8abfb5
...
...
@@ -17,7 +17,7 @@
import
collections
from
typing
import
Optional
,
Iterable
from
tensorflow.
core.function.polymorphism
import
function_typ
e
from
tensorflow.
python.types
import
trac
e
# The maximum number of dispatch lookups to cache.
_MAX_DISPATCH_CACHE
=
1024
...
...
@@ -28,9 +28,9 @@ class TypeDispatchTable:
A type dispatch table is a list, L, of target types. Given a request type, R,
the table selects a target type, T, according to the following dispatch rules:
1. R == T or R is su
pertype of T (functions are contravariant on args)
2. There does not exist O in L such that R is su
per
type of O and O is a
su
per
type of T (in other words, T is the closest to R, within list L).
1. R == T or R is su
btype of T
2. There does not exist O in L such that R is su
b
type of O and O is a
su
b
type of T (in other words, T is the closest to R, within list L).
3. If the above two rules are satisfied by multiple targets, the earliest
inserted one is chosen.
"""
...
...
@@ -46,19 +46,19 @@ class TypeDispatchTable:
# Does not contain exact matches, i.e, if cache[a] is b then a is not b.
self
.
_dispatch_cache
=
collections
.
OrderedDict
()
def
add_target
(
self
,
target
:
function_type
.
Function
Type
)
->
None
:
def
add_target
(
self
,
target
:
trace
.
Trace
Type
)
->
None
:
"""Adds a new target type."""
self
.
_dispatch_table
[
target
]
=
None
for
request
in
self
.
_dispatch_cache
:
if
target
.
is_su
per
type_of
(
self
.
_dispatch_cache
[
request
]):
if
target
.
is_su
b
type_of
(
self
.
_dispatch_cache
[
request
]):
self
.
_dispatch_cache
[
request
]
=
target
@
property
def
targets
(
self
)
->
Iterable
[
function_type
.
Function
Type
]:
def
targets
(
self
)
->
Iterable
[
trace
.
Trace
Type
]:
"""Returns an iterable to all targets in the table."""
return
self
.
_dispatch_table
.
keys
()
def
delete
(
self
,
target
:
function_type
.
Function
Type
)
->
None
:
def
delete
(
self
,
target
:
trace
.
Trace
Type
)
->
None
:
"""Deletes a target in the table if it exists."""
if
target
in
self
.
_dispatch_table
:
del
self
.
_dispatch_table
[
target
]
...
...
@@ -72,10 +72,8 @@ class TypeDispatchTable:
self
.
_dispatch_table
.
clear
()
self
.
_dispatch_cache
.
clear
()
def
dispatch
(
self
,
request
:
function_type
.
FunctionType
)
->
Optional
[
function_type
.
FunctionType
]:
"""Returns the most specific supertype target if it exists in the table."""
def
dispatch
(
self
,
request
:
trace
.
TraceType
)
->
Optional
[
trace
.
TraceType
]:
"""Returns the deepest subtype target if it exists in the table."""
# For known exact matches.
if
request
in
self
.
_dispatch_table
:
return
request
...
...
@@ -88,15 +86,15 @@ class TypeDispatchTable:
self
.
_dispatch_cache
[
request
]
=
result
return
result
most_specific_su
per
type
=
None
most_specific_su
b
type
=
None
for
other
in
self
.
_dispatch_table
:
if
request
.
is_su
per
type_of
(
other
):
if
most_specific_su
pertype
is
None
or
other
.
is_super
type_of
(
most_specific_su
per
type
):
most_specific_su
per
type
=
other
if
request
.
is_su
b
type_of
(
other
):
if
most_specific_su
btype
is
None
or
other
.
is_sub
type_of
(
most_specific_su
b
type
):
most_specific_su
b
type
=
other
self
.
_cache_dispatch
(
request
,
most_specific_su
per
type
)
return
most_specific_su
per
type
self
.
_cache_dispatch
(
request
,
most_specific_su
b
type
)
return
most_specific_su
b
type
def
_cache_dispatch
(
self
,
request
,
target
):
"""Caches the dispatch lookup result for a target."""
...
...
@@ -106,26 +104,26 @@ class TypeDispatchTable:
self
.
_dispatch_cache
.
popitem
(
last
=
False
)
self
.
_dispatch_cache
[
request
]
=
target
def
try_generalizing_
function_type
(
self
,
target
:
function_type
.
FunctionType
)
->
function_type
.
Function
Type
:
def
try_generalizing_
trace_type
(
self
,
target
:
trace
.
TraceType
)
->
trace
.
Trace
Type
:
"""Returns a generalized subtype of the one given.
This heuristic aims to reduce the number of future traces by computing a
type that represents more general function inputs.
The original "experimental_relax_shapes" heuristic identified a known type
which shared a common su
b
type with the current unknown type and then
traced with that common su
btype. However, the notion of "common sub
type"
was only limited to shapes. This heuristic extends that to
Function
Type.
which shared a common su
per
type with the current unknown type and then
traced with that common su
pertype. However, the notion of "common super
type"
was only limited to shapes. This heuristic extends that to
Trace
Type.
Returns `target` if a
generalized sub
type can not be found.
Returns `target` if a
common super
type can not be found.
Args:
target: The
Function
Type to generalize
target: The
Trace
Type to generalize
"""
relaxed
=
target
for
other
in
self
.
_dispatch_table
:
su
btype
=
relaxed
.
most_specific_common_sub
type
([
other
])
if
su
b
type
is
not
None
:
relaxed
=
su
b
type
su
pertype
=
relaxed
.
most_specific_common_super
type
([
other
])
if
su
per
type
is
not
None
:
relaxed
=
su
per
type
return
relaxed
tensorflow/core/function/polymorphism/type_dispatch_test.py
浏览文件 @
8d8abfb5
...
...
@@ -16,7 +16,6 @@
from
typing
import
Optional
from
tensorflow.core.function.polymorphism
import
function_type
from
tensorflow.core.function.polymorphism
import
type_dispatch
from
tensorflow.python.platform
import
test
from
tensorflow.python.types
import
trace
...
...
@@ -27,7 +26,7 @@ class MockShape(trace.TraceType):
def
__init__
(
self
,
*
shape
:
Optional
[
int
]):
self
.
shape
=
shape
def
is_subtype_of
(
self
,
other
:
"MockShape"
)
->
bool
:
def
is_subtype_of
(
self
,
other
:
"MockShape"
)
->
bool
:
if
len
(
self
.
shape
)
!=
len
(
other
.
shape
):
return
False
...
...
@@ -57,214 +56,186 @@ class MockShape(trace.TraceType):
return
self
.
shape
==
other
.
shape
def
make_shape_function_type
(
*
shape
):
return
function_type
.
FunctionType
([
function_type
.
Parameter
(
"x"
,
function_type
.
Parameter
.
POSITIONAL_ONLY
,
False
,
MockShape
(
*
shape
))
])
class
TypeDispatchTableTest
(
test
.
TestCase
):
def
testVertical
(
self
):
table
=
type_dispatch
.
TypeDispatchTable
()
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
None
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
None
,
1
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
,
1
))
table
.
add_target
(
make_shape_function_ty
pe
(
1
,
1
,
1
))
table
.
add_target
(
MockSha
pe
(
None
,
None
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
None
,
1
))
table
.
add_target
(
MockSha
pe
(
None
,
1
,
1
))
table
.
add_target
(
MockSha
pe
(
1
,
1
,
1
))
self
.
assertEqual
(
list
(
table
.
targets
),
[
make_shape_function_ty
pe
(
None
,
None
,
None
),
make_shape_function_ty
pe
(
None
,
None
,
1
),
make_shape_function_ty
pe
(
None
,
1
,
1
),
make_shape_function_ty
pe
(
1
,
1
,
1
)
MockSha
pe
(
None
,
None
,
None
),
MockSha
pe
(
None
,
None
,
1
),
MockSha
pe
(
None
,
1
,
1
),
MockSha
pe
(
1
,
1
,
1
)
])
def
testHorizontal
(
self
):
table
=
type_dispatch
.
TypeDispatchTable
()
table
.
add_target
(
make_shape_function_ty
pe
(
1
,))
table
.
add_target
(
make_shape_function_ty
pe
(
1
,
2
))
table
.
add_target
(
make_shape_function_ty
pe
(
1
,
2
,
3
))
table
.
add_target
(
MockSha
pe
(
1
,))
table
.
add_target
(
MockSha
pe
(
1
,
2
))
table
.
add_target
(
MockSha
pe
(
1
,
2
,
3
))
self
.
assertEqual
(
list
(
table
.
targets
),
[
make_shape_function_ty
pe
(
1
,),
make_shape_function_ty
pe
(
1
,
2
),
make_shape_function_ty
pe
(
1
,
2
,
3
)
MockSha
pe
(
1
,),
MockSha
pe
(
1
,
2
),
MockSha
pe
(
1
,
2
,
3
)
])
def
testDuplicateNodes
(
self
):
table
=
type_dispatch
.
TypeDispatchTable
()
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
1
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
2
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
None
))
table
.
add_target
(
MockSha
pe
(
1
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
2
))
table
.
add_target
(
MockSha
pe
(
None
,
None
))
self
.
assertEqual
(
list
(
table
.
targets
),
[
make_shape_function_ty
pe
(
None
,
None
),
make_shape_function_ty
pe
(
1
,
None
),
make_shape_function_ty
pe
(
None
,
2
)
MockSha
pe
(
None
,
None
),
MockSha
pe
(
1
,
None
),
MockSha
pe
(
None
,
2
)
])
def
testDeletion
(
self
):
table
=
type_dispatch
.
TypeDispatchTable
()
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
2
))
table
.
add_target
(
MockSha
pe
(
None
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
1
))
table
.
add_target
(
MockSha
pe
(
None
,
2
))
self
.
assertEqual
(
list
(
table
.
targets
),
[
make_shape_function_ty
pe
(
None
,
None
),
make_shape_function_ty
pe
(
None
,
1
),
make_shape_function_ty
pe
(
None
,
2
)
MockSha
pe
(
None
,
None
),
MockSha
pe
(
None
,
1
),
MockSha
pe
(
None
,
2
)
])
table
.
delete
(
make_shape_function_ty
pe
(
None
,
2
))
# Should remove the target
table
.
delete
(
MockSha
pe
(
None
,
2
))
# Should remove the target
self
.
assertEqual
(
list
(
table
.
targets
),
[
make_shape_function_ty
pe
(
None
,
None
),
make_shape_function_ty
pe
(
None
,
1
),
MockSha
pe
(
None
,
None
),
MockSha
pe
(
None
,
1
),
])
table
.
delete
(
make_shape_function_ty
pe
(
None
,
2
))
# Should have no effect
table
.
delete
(
MockSha
pe
(
None
,
2
))
# Should have no effect
self
.
assertEqual
(
list
(
table
.
targets
),
[
make_shape_function_ty
pe
(
None
,
None
),
make_shape_function_ty
pe
(
None
,
1
),
MockSha
pe
(
None
,
None
),
MockSha
pe
(
None
,
1
),
])
def
testContains
(
self
):
table
=
type_dispatch
.
TypeDispatchTable
()
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
None
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
))
table
.
add_target
(
make_shape_function_ty
pe
(
1
,
1
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
2
,
1
))
table
.
add_target
(
MockSha
pe
(
None
,
None
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
1
))
table
.
add_target
(
MockSha
pe
(
1
,
1
))
table
.
add_target
(
MockSha
pe
(
None
,
2
,
1
))
self
.
assertIn
(
make_shape_function_ty
pe
(
None
,
None
,
None
),
table
.
targets
)
self
.
assertIn
(
make_shape_function_ty
pe
(
None
,
1
),
table
.
targets
)
self
.
assertIn
(
make_shape_function_ty
pe
(
1
,
1
),
table
.
targets
)
self
.
assertIn
(
make_shape_function_ty
pe
(
None
,
2
,
1
),
table
.
targets
)
self
.
assertIn
(
MockSha
pe
(
None
,
None
,
None
),
table
.
targets
)
self
.
assertIn
(
MockSha
pe
(
None
,
1
),
table
.
targets
)
self
.
assertIn
(
MockSha
pe
(
1
,
1
),
table
.
targets
)
self
.
assertIn
(
MockSha
pe
(
None
,
2
,
1
),
table
.
targets
)
self
.
assertNotIn
(
make_shape_function_ty
pe
(
None
,
None
,
1
),
table
.
targets
)
self
.
assertNotIn
(
make_shape_function_ty
pe
(
1
,
None
),
table
.
targets
)
self
.
assertNotIn
(
make_shape_function_ty
pe
(
1
,
2
),
table
.
targets
)
self
.
assertNotIn
(
make_shape_function_ty
pe
(
None
,
2
,
None
),
table
.
targets
)
self
.
assertNotIn
(
MockSha
pe
(
None
,
None
,
1
),
table
.
targets
)
self
.
assertNotIn
(
MockSha
pe
(
1
,
None
),
table
.
targets
)
self
.
assertNotIn
(
MockSha
pe
(
1
,
2
),
table
.
targets
)
self
.
assertNotIn
(
MockSha
pe
(
None
,
2
,
None
),
table
.
targets
)
def
testDispatchExactMatches
(
self
):
table
=
type_dispatch
.
TypeDispatchTable
()
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
None
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
,
2
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
2
,
2
))
table
.
add_target
(
MockSha
pe
(
None
,
None
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
1
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
1
,
2
))
table
.
add_target
(
MockSha
pe
(
None
,
2
,
2
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
None
,
1
,
2
)),
make_shape_function_type
(
None
,
1
,
2
))
table
.
dispatch
(
MockShape
(
None
,
1
,
2
)),
MockShape
(
None
,
1
,
2
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
None
,
1
,
None
)),
make_shape_function_type
(
None
,
1
,
None
))
table
.
dispatch
(
MockShape
(
None
,
1
,
None
)),
MockShape
(
None
,
1
,
None
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_ty
pe
(
None
,
None
,
None
)),
make_shape_function_ty
pe
(
None
,
None
,
None
))
table
.
dispatch
(
MockSha
pe
(
None
,
None
,
None
)),
MockSha
pe
(
None
,
None
,
None
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
None
,
2
,
2
)),
make_shape_function_type
(
None
,
2
,
2
))
table
.
dispatch
(
MockShape
(
None
,
2
,
2
)),
MockShape
(
None
,
2
,
2
))
def
testDispatchMoreSpecific
(
self
):
table
=
type_dispatch
.
TypeDispatchTable
()
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
None
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
,
2
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
2
,
2
))
table
.
add_target
(
MockSha
pe
(
None
,
None
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
1
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
1
,
2
))
table
.
add_target
(
MockSha
pe
(
None
,
2
,
2
))
self
.
assertEqual
(
table
.
dispatch
(
MockShape
(
1
,
1
,
2
)),
MockShape
(
None
,
1
,
2
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
1
,
1
,
2
)),
make_shape_function_type
(
None
,
1
,
2
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
1
,
1
,
3
)),
make_shape_function_type
(
None
,
1
,
None
))
table
.
dispatch
(
MockShape
(
1
,
1
,
3
)),
MockShape
(
None
,
1
,
None
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
1
,
3
,
3
)),
make_shape_function_type
(
None
,
None
,
None
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
1
,
2
,
2
)),
make_shape_function_type
(
None
,
2
,
2
))
table
.
dispatch
(
MockShape
(
1
,
3
,
3
)),
MockShape
(
None
,
None
,
None
))
self
.
assertEqual
(
table
.
dispatch
(
MockShape
(
1
,
2
,
2
)),
MockShape
(
None
,
2
,
2
))
def
testDispatchNoMatches
(
self
):
table
=
type_dispatch
.
TypeDispatchTable
()
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
,
2
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
2
,
2
))
table
.
add_target
(
MockSha
pe
(
None
,
1
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
1
,
2
))
table
.
add_target
(
MockSha
pe
(
None
,
2
,
2
))
self
.
assertIsNone
(
table
.
dispatch
(
make_shape_function_ty
pe
(
1
,
2
)))
self
.
assertIsNone
(
table
.
dispatch
(
make_shape_function_ty
pe
(
1
,
2
,
3
)))
self
.
assertIsNone
(
table
.
dispatch
(
make_shape_function_ty
pe
(
1
,
2
,
3
,
4
)))
self
.
assertIsNone
(
table
.
dispatch
(
MockSha
pe
(
1
,
2
)))
self
.
assertIsNone
(
table
.
dispatch
(
MockSha
pe
(
1
,
2
,
3
)))
self
.
assertIsNone
(
table
.
dispatch
(
MockSha
pe
(
1
,
2
,
3
,
4
)))
def
testDispatchCachedAddUpdates
(
self
):
table
=
type_dispatch
.
TypeDispatchTable
()
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
None
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
None
,
None
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
1
,
1
,
2
)),
make_shape_function_type
(
None
,
None
,
None
))
table
.
dispatch
(
MockShape
(
1
,
1
,
2
)),
MockShape
(
None
,
None
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
1
,
None
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
1
,
1
,
2
)),
make_shape_function_type
(
None
,
1
,
None
))
table
.
dispatch
(
MockShape
(
1
,
1
,
2
)),
MockShape
(
None
,
1
,
None
))
table
.
add_target
(
make_shape_function_type
(
None
,
1
,
2
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
1
,
1
,
2
)),
make_shape_function_type
(
None
,
1
,
2
))
table
.
add_target
(
MockShape
(
None
,
1
,
2
))
self
.
assertEqual
(
table
.
dispatch
(
MockShape
(
1
,
1
,
2
)),
MockShape
(
None
,
1
,
2
))
table
.
add_target
(
make_shape_function_type
(
1
,
1
,
2
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
1
,
1
,
2
)),
make_shape_function_type
(
1
,
1
,
2
))
table
.
add_target
(
MockShape
(
1
,
1
,
2
))
self
.
assertEqual
(
table
.
dispatch
(
MockShape
(
1
,
1
,
2
)),
MockShape
(
1
,
1
,
2
))
def
testDispatchCachedDeleteUpdates
(
self
):
table
=
type_dispatch
.
TypeDispatchTable
()
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
None
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
,
2
))
table
.
add_target
(
make_shape_function_ty
pe
(
1
,
1
,
2
))
table
.
add_target
(
MockSha
pe
(
None
,
None
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
1
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
1
,
2
))
table
.
add_target
(
MockSha
pe
(
1
,
1
,
2
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
1
,
1
,
2
)),
make_shape_function_type
(
1
,
1
,
2
))
self
.
assertEqual
(
table
.
dispatch
(
MockShape
(
1
,
1
,
2
)),
MockShape
(
1
,
1
,
2
))
table
.
delete
(
make_shape_function_type
(
1
,
1
,
2
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
1
,
1
,
2
)),
make_shape_function_type
(
None
,
1
,
2
))
table
.
delete
(
MockShape
(
1
,
1
,
2
))
self
.
assertEqual
(
table
.
dispatch
(
MockShape
(
1
,
1
,
2
)),
MockShape
(
None
,
1
,
2
))
table
.
delete
(
make_shape_function_ty
pe
(
None
,
1
,
2
))
table
.
delete
(
MockSha
pe
(
None
,
1
,
2
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
1
,
1
,
2
)),
make_shape_function_type
(
None
,
1
,
None
))
table
.
dispatch
(
MockShape
(
1
,
1
,
2
)),
MockShape
(
None
,
1
,
None
))
table
.
delete
(
make_shape_function_ty
pe
(
None
,
1
,
None
))
table
.
delete
(
MockSha
pe
(
None
,
1
,
None
))
self
.
assertEqual
(
table
.
dispatch
(
make_shape_function_type
(
1
,
1
,
2
)),
make_shape_function_type
(
None
,
None
,
None
))
table
.
dispatch
(
MockShape
(
1
,
1
,
2
)),
MockShape
(
None
,
None
,
None
))
def
testDispatchCacheOrderingDeterminism
(
self
):
table_1
=
type_dispatch
.
TypeDispatchTable
()
table_1
.
add_target
(
make_shape_function_ty
pe
(
1
,
None
,
None
))
table_1
.
add_target
(
make_shape_function_ty
pe
(
None
,
2
,
None
))
table_1
.
add_target
(
make_shape_function_ty
pe
(
None
,
None
,
3
))
table_1
.
add_target
(
MockSha
pe
(
1
,
None
,
None
))
table_1
.
add_target
(
MockSha
pe
(
None
,
2
,
None
))
table_1
.
add_target
(
MockSha
pe
(
None
,
None
,
3
))
table_2
=
type_dispatch
.
TypeDispatchTable
()
table_2
.
add_target
(
make_shape_function_ty
pe
(
None
,
2
,
None
))
table_2
.
add_target
(
make_shape_function_ty
pe
(
1
,
None
,
None
))
table_2
.
add_target
(
make_shape_function_ty
pe
(
None
,
None
,
3
))
table_2
.
add_target
(
MockSha
pe
(
None
,
2
,
None
))
table_2
.
add_target
(
MockSha
pe
(
1
,
None
,
None
))
table_2
.
add_target
(
MockSha
pe
(
None
,
None
,
3
))
table_3
=
type_dispatch
.
TypeDispatchTable
()
table_3
.
add_target
(
make_shape_function_ty
pe
(
None
,
None
,
3
))
table_3
.
add_target
(
make_shape_function_ty
pe
(
1
,
None
,
None
))
table_3
.
add_target
(
make_shape_function_ty
pe
(
None
,
2
,
None
))
table_3
.
add_target
(
MockSha
pe
(
None
,
None
,
3
))
table_3
.
add_target
(
MockSha
pe
(
1
,
None
,
None
))
table_3
.
add_target
(
MockSha
pe
(
None
,
2
,
None
))
# table_1, table_2, table_3 have the same targets
self
.
assertEqual
(
set
(
table_1
.
targets
),
set
(
table_2
.
targets
))
...
...
@@ -272,43 +243,36 @@ class TypeDispatchTableTest(test.TestCase):
# But they dispatch to the first target they find which does not have any
# more specific viable target.
shape
=
make_shape_function_type
(
1
,
2
,
3
)
self
.
assertEqual
(
table_1
.
dispatch
(
shape
),
make_shape_function_type
(
1
,
None
,
None
))
self
.
assertEqual
(
table_2
.
dispatch
(
shape
),
make_shape_function_type
(
None
,
2
,
None
))
self
.
assertEqual
(
table_3
.
dispatch
(
shape
),
make_shape_function_type
(
None
,
None
,
3
))
shape
=
MockShape
(
1
,
2
,
3
)
self
.
assertEqual
(
table_1
.
dispatch
(
shape
),
MockShape
(
1
,
None
,
None
))
self
.
assertEqual
(
table_2
.
dispatch
(
shape
),
MockShape
(
None
,
2
,
None
))
self
.
assertEqual
(
table_3
.
dispatch
(
shape
),
MockShape
(
None
,
None
,
3
))
def
testGeneralizedExisting
(
self
):
table
=
type_dispatch
.
TypeDispatchTable
()
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
None
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
,
2
))
table
.
add_target
(
MockSha
pe
(
None
,
None
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
1
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
1
,
2
))
self
.
assertEqual
(
table
.
try_generalizing_function_type
(
make_shape_function_type
(
None
,
1
,
3
)),
make_shape_function_type
(
None
,
None
,
None
))
table
.
try_generalizing_trace_type
(
MockShape
(
None
,
1
,
3
)),
MockShape
(
None
,
None
,
None
))
def
testGeneralizedNovel
(
self
):
table
=
type_dispatch
.
TypeDispatchTable
()
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
,
None
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
,
2
))
table
.
add_target
(
MockSha
pe
(
None
,
1
,
None
))
table
.
add_target
(
MockSha
pe
(
None
,
1
,
2
))
self
.
assertEqual
(
table
.
try_generalizing_function_type
(
make_shape_function_type
(
None
,
2
,
3
)),
make_shape_function_type
(
None
,
None
,
None
))
table
.
try_generalizing_trace_type
(
MockShape
(
None
,
2
,
3
)),
MockShape
(
None
,
None
,
None
))
def
testGeneralizedUnknown
(
self
):
table
=
type_dispatch
.
TypeDispatchTable
()
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
1
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
2
))
table
.
add_target
(
make_shape_function_ty
pe
(
None
,
3
))
table
.
add_target
(
MockSha
pe
(
None
,
1
))
table
.
add_target
(
MockSha
pe
(
None
,
2
))
table
.
add_target
(
MockSha
pe
(
None
,
3
))
self
.
assertEqual
(
table
.
try_generalizing_function_type
(
make_shape_function_type
(
None
,
4
,
3
)),
make_shape_function_type
(
None
,
4
,
3
))
table
.
try_generalizing_trace_type
(
MockShape
(
None
,
4
,
3
)),
MockShape
(
None
,
4
,
3
))
if
__name__
==
"__main__"
:
test
.
main
()
tensorflow/python/eager/polymorphic_function/function_context.py
浏览文件 @
8d8abfb5
...
...
@@ -125,21 +125,24 @@ def _enclosing_xla_context():
def
make_cache_key
(
args
:
Any
,
captures
:
Any
=
None
,
)
->
Tuple
[
function_cache
.
FunctionContext
,
function_type
.
FunctionType
,
trace_type
.
WeakrefDeletionObserver
]:
)
->
Tuple
[
function_cache
.
FunctionCacheKey
,
trace_type
.
WeakrefDeletionObserver
]:
"""Computes the cache key given the function arguments."""
if
captures
is
None
:
captures
=
dict
()
signature_context
=
trace_type
.
InternalTracingContext
()
args_signature
=
trace_type
.
from_value
(
args
,
signature_context
)
captures_dict_tracetype
=
trace_type
.
from_value
(
captures
,
signature_context
)
args_signature
=
trace_type
.
from_value
(
args
,
signature_context
)
captures_dict_tracetype
=
trace_type
.
from_value
(
captures
,
signature_context
)
# TODO(fmuham): Use the actual FunctionType
dummy_function_type
=
function_type
.
FunctionType
([
function_type
.
Parameter
(
"args_kwargs"
,
function_type
.
Parameter
.
POSITIONAL_ONLY
,
False
,
function_type
.
Parameter
.
POSITIONAL_ONLY
,
False
,
args_signature
)
],
collections
.
OrderedDict
(
captures_dict_tracetype
.
mapping
))
return
(
make_function_context
(),
dummy_function_type
,
signature_context
.
deletion_observer
)
return
function_cache
.
FunctionCacheKey
(
dummy_function_type
,
make_function_context
()),
signature_context
.
deletion_observer
tensorflow/python/eager/polymorphic_function/tracing_compiler.py
浏览文件 @
8d8abfb5
...
...
@@ -95,8 +95,8 @@ class TracingCompiler:
capture_by_value: Experimental. Whether to capture resource variables by
value or reference. If None, will inherit from a parent context or
default to False.
jit_compile: Force-compile the function with XLA, cf.
tf.function doc on
jit_compile.
jit_compile: Force-compile the function with XLA, cf.
tf.function doc on
jit_compile.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
...
...
@@ -105,7 +105,9 @@ class TracingCompiler:
self
.
_python_function
=
python_function
pure_function
=
attributes
and
monomorphic_function
.
IMPLEMENTS_ATTRIBUTE_NAME
in
attributes
self
.
_function_spec
=
function_spec
.
FunctionSpec
.
from_function_and_signature
(
python_function
,
input_signature
,
is_pure
=
pure_function
)
python_function
,
input_signature
,
is_pure
=
pure_function
)
self
.
_name
=
name
self
.
_autograph
=
autograph
self
.
_autograph_options
=
autograph_options
...
...
@@ -331,10 +333,9 @@ class TracingCompiler:
# cache_key_deletion_observer is useless here. It's based on all captures.
# A new cache key will be built later when saving ConcreteFunction because
# only active captures should be saved.
lookup_func_context
,
lookup_func_type
,
_
=
function_context
.
make_cache_key
(
(
args
,
kwargs
),
captures
)
concrete_function
=
self
.
_function_cache
.
lookup
(
lookup_func_context
,
lookup_func_type
)
lookup_func_key
,
_
=
function_context
.
make_cache_key
((
args
,
kwargs
),
captures
)
concrete_function
=
self
.
_function_cache
.
lookup
(
lookup_func_key
,
True
)
if
concrete_function
is
not
None
:
return
concrete_function
,
filtered_flat_args
...
...
@@ -342,8 +343,7 @@ class TracingCompiler:
with
trace
.
Trace
(
"tf.function-graph_building"
):
logging
.
vlog
(
1
,
"Creating new FuncGraph for Python function %r (key: %r)"
,
self
.
_python_function
,
lookup_func_context
,
lookup_func_type
)
self
.
_python_function
,
lookup_func_key
)
logging
.
vlog
(
2
,
"Python function signature [args: %s] [kwargs: %s]"
,
args
,
kwargs
)
ag_status
=
(
...
...
@@ -352,10 +352,10 @@ class TracingCompiler:
with
ag_ctx
.
ControlStatusCtx
(
status
=
ag_status
,
options
=
self
.
_autograph_options
):
if
self
.
input_signature
is
None
and
self
.
_reduce_retracing
:
general
_func_type
=
self
.
_function_cache
.
generalize
(
lookup_func_
context
,
lookup_func_type
)
placeholder_bound_args
=
general_func_type
.
placeholder_arguments
()
args
,
kwargs
=
placeholder_bound_args
.
args
[
0
]
general
ized_func_key
=
self
.
_function_cache
.
generalize
(
lookup_func_
key
)
# Only get placeholders for arguments, not captures
args
,
kwargs
=
generalized_func_key
.
_placeholder_value
()
# pylint: disable=protected-access
concrete_function
=
self
.
_create_concrete_function
(
args
,
kwargs
)
...
...
@@ -366,10 +366,10 @@ class TracingCompiler:
captures
=
graph_capture_container
.
get_snapshot
()
# Create a cache_key with args and captures
traced_func_
context
,
traced_func_type
,
traced_func_deletion_observer
=
(
traced_func_
key
,
traced_func_deletion_observer
=
(
function_context
.
make_cache_key
((
args
,
kwargs
),
captures
))
self
.
_function_cache
.
add
(
traced_func_
context
,
traced_func_type
,
self
.
_function_cache
.
add
(
traced_func_
key
,
traced_func_deletion_observer
,
concrete_function
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录