Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wux_labs
Tensorflow
提交
b5cc8c77
T
Tensorflow
项目概览
wux_labs
/
Tensorflow
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b5cc8c77
编写于
10月 13, 2022
作者:
F
Faizan Muhammad
提交者:
TensorFlower Gardener
10月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use FunctionType in FunctionCacheKey
PiperOrigin-RevId: 481006100
上级
e4a52e12
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
138 addition
and
115 deletion
+138
-115
tensorflow/core/function/polymorphism/BUILD
tensorflow/core/function/polymorphism/BUILD
+2
-0
tensorflow/core/function/polymorphism/function_cache.py
tensorflow/core/function/polymorphism/function_cache.py
+28
-31
tensorflow/core/function/polymorphism/function_cache_test.py
tensorflow/core/function/polymorphism/function_cache_test.py
+98
-83
tensorflow/python/eager/polymorphic_function/function_context.py
...low/python/eager/polymorphic_function/function_context.py
+10
-1
未找到文件。
tensorflow/core/function/polymorphism/BUILD
浏览文件 @
b5cc8c77
...
...
@@ -36,6 +36,7 @@ pytype_strict_library(
"//tensorflow/python/eager:__subpackages__"
,
],
deps
=
[
"//tensorflow/core/function/function_type"
,
"//tensorflow/core/function/polymorphism:type_dispatch"
,
"//tensorflow/core/function/trace_type"
,
"//tensorflow/python/types"
,
...
...
@@ -51,6 +52,7 @@ py_strict_test(
visibility
=
[
"//learning/brain/contrib/eager/python/examples:__pkg__"
],
deps
=
[
":function_cache"
,
"//tensorflow/core/function/function_type"
,
"//tensorflow/core/function/trace_type"
,
"//tensorflow/python:array_ops"
,
"//tensorflow/python/eager/polymorphic_function:function_context"
,
...
...
tensorflow/core/function/polymorphism/function_cache.py
浏览文件 @
b5cc8c77
...
...
@@ -15,9 +15,10 @@
"""Cache to manage concrete functions and their signatures."""
import
collections
from
typing
import
Optional
,
Hashable
,
Sequence
,
Any
,
NamedTuple
,
Dict
from
typing
import
Any
,
Dict
,
Hashable
,
NamedTuple
,
Optional
,
Sequence
from
tensorflow.core.function
import
trace_type
from
tensorflow.core.function.function_type
import
function_type
as
function_type_lib
from
tensorflow.core.function.polymorphism
import
type_dispatch
from
tensorflow.python.types
import
trace
...
...
@@ -30,6 +31,7 @@ class FunctionContext(NamedTuple):
context
:
Any
# TODO(fmuham): Move into FunctionType.
class
CaptureSnapshot
(
trace
.
TraceType
):
"""Store tf.function captures to accommodate its specific tracing logic.
...
...
@@ -86,8 +88,7 @@ class CaptureSnapshot(trace.TraceType):
for
key
,
item
in
query
.
mapping
.
items
())
def
most_specific_common_supertype
(
self
,
types
:
Sequence
[
trace
.
TraceType
])
->
Optional
[
"CaptureSnapshot"
]:
self
,
types
:
Sequence
[
trace
.
TraceType
])
->
Optional
[
"CaptureSnapshot"
]:
"""See base class."""
common_keys
=
set
(
self
.
mapping
.
keys
())
for
other
in
types
:
...
...
@@ -118,22 +119,21 @@ class CaptureSnapshot(trace.TraceType):
return
hash
(
frozenset
(
self
.
mapping
.
keys
()))
# TODO(
panzf): Rename `FunctionCacheKey` to `FunctionType`
# TODO(
fmuham): Remove inheritance from TraceType.
class
FunctionCacheKey
(
trace
.
TraceType
):
"""The unique key associated with a concrete function.
Attributes:
args_signature: A Trace
Type corresponding to the function arguments.
function_type: A Function
Type corresponding to the function arguments.
captures_signature: A CaptureSnapshot corresponding to the function
captures.
call_context: The FunctionContext for when the args_signature was
generated.
call_context: The FunctionContext for when the function was called.
"""
def
__init__
(
self
,
args_signature
:
trace
.
Trace
Type
,
def
__init__
(
self
,
function_type
:
function_type_lib
.
Function
Type
,
captures_signature
:
CaptureSnapshot
,
call_context
:
FunctionContext
):
self
.
args_signature
=
args_signatur
e
self
.
function_type
=
function_typ
e
self
.
captures_signature
=
captures_signature
self
.
call_context
=
call_context
...
...
@@ -144,8 +144,8 @@ class FunctionCacheKey(trace.TraceType):
if
self
.
call_context
!=
other
.
call_context
:
return
False
return
(
self
.
args_signature
.
is_subtype_of
(
other
.
args_signature
)
and
self
.
captures_signature
.
is_subtype_of
(
other
.
captures_signature
))
return
(
self
.
function_type
.
is_supertype_of
(
other
.
function_type
)
and
self
.
captures_signature
.
is_subtype_of
(
other
.
captures_signature
))
def
most_specific_common_supertype
(
self
,
others
:
Sequence
[
trace
.
TraceType
])
->
Optional
[
"FunctionCacheKey"
]:
...
...
@@ -154,27 +154,28 @@ class FunctionCacheKey(trace.TraceType):
self
.
call_context
==
other
.
call_context
for
other
in
others
):
return
None
# `args` and `captures` are independent when finding common supertypes.
args_common
=
self
.
args_signature
.
most_specific_common_supertype
(
[
other
.
args_signature
for
other
in
others
])
function_type_common
=
self
.
function_type
.
most_specific_common_subtype
(
[
other
.
function_type
for
other
in
others
])
if
args
_common
is
None
:
if
function_type
_common
is
None
:
return
None
captures_common
=
self
.
captures_signature
.
most_specific_common_supertype
(
[
other
.
captures_signature
for
other
in
others
])
return
FunctionCacheKey
(
args_common
,
captures_common
,
self
.
call_context
)
return
FunctionCacheKey
(
function_type_common
,
captures_common
,
self
.
call_context
)
def
_placeholder_value
(
self
)
->
Any
:
"""Value used for tracing a function signature with this TraceType."""
return
{
"args"
:
self
.
args_signature
.
_placeholder_value
(),
# pylint: disable=protected-access
"captures"
:
self
.
captures_signature
.
_placeholder_value
()}
# pylint: disable=protected-access
return
{
"args"
:
self
.
function_type
.
placeholder_arguments
().
args
[
0
],
"captures"
:
self
.
captures_signature
.
_placeholder_value
()
# pylint: disable=protected-access
}
def
__hash__
(
self
)
->
int
:
return
hash
((
self
.
call_context
,
self
.
args_signature
,
self
.
captures_signature
))
return
hash
(
(
self
.
call_context
,
self
.
function_type
,
self
.
captures_signature
))
def
__eq__
(
self
,
other
)
->
bool
:
if
not
isinstance
(
other
,
trace
.
TraceType
):
...
...
@@ -184,23 +185,20 @@ class FunctionCacheKey(trace.TraceType):
return
False
return
(
self
.
call_context
==
other
.
call_context
and
self
.
args_signature
==
other
.
args_signatur
e
and
self
.
function_type
==
other
.
function_typ
e
and
self
.
captures_signature
==
other
.
captures_signature
)
def
__repr__
(
self
)
->
str
:
return
(
f
"
{
type
(
self
).
__name__
}
(args_signature=
{
repr
(
self
.
args_signature
)
}
,"
f
"(captures_signature=
{
repr
(
self
.
captures_signature
)
}
,"
f
" call_context=
{
repr
(
self
.
call_context
)
}
)"
)
return
(
f
"
{
type
(
self
).
__name__
}
(function_type=
{
repr
(
self
.
function_type
)
}
,"
f
"(captures_signature=
{
repr
(
self
.
captures_signature
)
}
,"
f
" call_context=
{
repr
(
self
.
call_context
)
}
)"
)
# TODO(fmuham): Rename to FunctionLibrary.
class
FunctionCache
:
"""A container for managing concrete functions."""
__slots__
=
[
"_primary"
,
"_dispatch_table"
,
"_garbage_collectors"
]
__slots__
=
[
"_primary"
,
"_dispatch_table"
,
"_garbage_collectors"
]
def
__init__
(
self
):
# The primary cache, mapping FunctionCacheKey to a concrete function.
...
...
@@ -236,8 +234,7 @@ class FunctionCache:
return
True
def
add
(
self
,
key
:
FunctionCacheKey
,
deletion_observer
:
trace_type
.
WeakrefDeletionObserver
,
concrete
):
deletion_observer
:
trace_type
.
WeakrefDeletionObserver
,
concrete
:
...):
"""Adds a new concrete function alongside its key.
Args:
...
...
tensorflow/core/function/polymorphism/function_cache_test.py
浏览文件 @
b5cc8c77
...
...
@@ -19,6 +19,7 @@ import timeit
from
typing
import
Optional
from
tensorflow.core.function
import
trace_type
from
tensorflow.core.function.function_type
import
function_type
from
tensorflow.core.function.polymorphism
import
function_cache
from
tensorflow.python.eager.polymorphic_function
import
function_context
from
tensorflow.python.ops
import
array_ops
...
...
@@ -81,7 +82,7 @@ class MockShape(trace.TraceType):
def
__init__
(
self
,
*
shape
:
Optional
[
int
]):
self
.
shape
=
shape
def
is_subtype_of
(
self
,
other
:
"MockShape"
)
->
bool
:
def
is_subtype_of
(
self
,
other
:
"MockShape"
)
->
bool
:
if
len
(
self
.
shape
)
!=
len
(
other
.
shape
):
return
False
...
...
@@ -112,6 +113,13 @@ class MockEmptyCaptureSnapshot(function_cache.CaptureSnapshot):
self
.
mapping
=
{}
def
make_single_param_type
(
type_constraint
):
return
function_type
.
FunctionType
([
function_type
.
Parameter
(
"x"
,
function_type
.
Parameter
.
POSITIONAL_ONLY
,
False
,
type_constraint
)
])
class
FunctionCacheTest
(
test
.
TestCase
):
def
testConcreteFunctionDictRetainsInsertedKeys
(
self
):
...
...
@@ -158,14 +166,15 @@ class FunctionCacheTest(test.TestCase):
cache
.
delete
(
key_1
)
self
.
assertIsNone
(
cache
.
lookup
(
key_1
,
False
))
key_2
=
function_cache
.
FunctionCacheKey
(
MockSubtypeOf2
(
2
),
MockEmptyCaptureSnapshot
(),
None
)
cache
.
add
(
key_2
,
trace_type
.
WeakrefDeletionObserver
(),
"test_2"
)
key_2
=
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockSubtypeOf2
(
2
)),
MockEmptyCaptureSnapshot
(),
None
)
cache
.
add
(
key_2
,
trace_type
.
WeakrefDeletionObserver
(),
"test_2"
)
self
.
assertEqual
(
cache
.
lookup
(
key_2
,
False
),
"test_2"
)
key_3
=
function_cache
.
FunctionCacheKey
(
MockSubtypeOf2
(
3
),
MockEmptyCaptureSnapshot
(),
None
)
key_3
=
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockSubtypeOf2
(
3
)),
MockEmptyCaptureSnapshot
(),
None
)
self
.
assertEqual
(
cache
.
lookup
(
key_3
,
True
),
"test_2"
)
cache
.
delete
(
key_2
)
...
...
@@ -175,12 +184,12 @@ class FunctionCacheTest(test.TestCase):
def
testFunctionCacheKeyRespectsEquality
(
self
):
ctx
=
function_cache
.
FunctionContext
(
0
)
generic
=
MockGenericType
key_a
=
function_cache
.
FunctionCacheKey
(
generic
(
1
),
MockEmptyCaptureSnapshot
(),
ctx
)
key_b
=
function_cache
.
FunctionCacheKey
(
generic
(
2
),
MockEmptyCaptureSnapshot
(),
ctx
)
key_c
=
function_cache
.
FunctionCacheKey
(
generic
(
1
),
MockEmptyCaptureSnapshot
(),
ctx
)
key_a
=
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
generic
(
1
)),
MockEmptyCaptureSnapshot
(),
ctx
)
key_b
=
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
generic
(
2
)),
MockEmptyCaptureSnapshot
(),
ctx
)
key_c
=
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
generic
(
1
)),
MockEmptyCaptureSnapshot
(),
ctx
)
self
.
assertNotEqual
(
key_a
,
key_b
)
self
.
assertEqual
(
key_a
,
key_c
)
...
...
@@ -188,12 +197,15 @@ class FunctionCacheTest(test.TestCase):
def
testFunctionCacheKeyRespectsSubtype
(
self
):
ctx
=
function_cache
.
FunctionContext
(
0
)
key_a
=
function_cache
.
FunctionCacheKey
(
MockSubtypeOf2
(
1
),
MockEmptyCaptureSnapshot
(),
ctx
)
key_b
=
function_cache
.
FunctionCacheKey
(
MockSubtypeOf2
(
2
),
MockEmptyCaptureSnapshot
(),
ctx
)
key_c
=
function_cache
.
FunctionCacheKey
(
MockSubtypeOf2
(
1
),
MockEmptyCaptureSnapshot
(),
ctx
)
key_a
=
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockSubtypeOf2
(
1
)),
MockEmptyCaptureSnapshot
(),
ctx
)
key_b
=
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockSubtypeOf2
(
2
)),
MockEmptyCaptureSnapshot
(),
ctx
)
key_c
=
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockSubtypeOf2
(
1
)),
MockEmptyCaptureSnapshot
(),
ctx
)
self
.
assertTrue
(
key_a
.
is_subtype_of
(
key_b
))
self
.
assertFalse
(
key_b
.
is_subtype_of
(
key_a
))
...
...
@@ -201,68 +213,74 @@ class FunctionCacheTest(test.TestCase):
def
testFunctionCacheKeyRespectsSupertype
(
self
):
ctx
=
function_cache
.
FunctionContext
(
0
)
key_a
=
function_cache
.
FunctionCacheKey
(
MockSupertypes2With3
(
1
),
MockEmptyCaptureSnapshot
(),
ctx
)
key_b
=
function_cache
.
FunctionCacheKey
(
MockSupertypes2With3
(
2
),
MockEmptyCaptureSnapshot
(),
ctx
)
key_a
=
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockSupertypes2With3
(
1
)),
MockEmptyCaptureSnapshot
(),
ctx
)
key_b
=
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockSupertypes2With3
(
2
)),
MockEmptyCaptureSnapshot
(),
ctx
)
self
.
assertEqual
(
key_b
.
most_specific_common_supertype
([
key_a
]),
function_cache
.
FunctionCacheKey
(
MockSupertypes2With3
(
3
),
MockEmptyCaptureSnapshot
(),
ctx
))
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockSupertypes2With3
(
3
)),
MockEmptyCaptureSnapshot
(),
ctx
))
self
.
assertIsNone
(
key_a
.
most_specific_common_supertype
([
key_b
]))
def
testMostSpecificFunctionCacheKeyIsLookedUp
(
self
):
ctx
=
function_cache
.
FunctionContext
(
0
)
cache
=
function_cache
.
FunctionCache
()
cache
.
add
(
function_cache
.
FunctionCacheKey
(
MockShape
(
1
,
2
,
None
),
MockEmptyCaptureSnapshot
(),
ctx
),
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockShape
(
1
,
2
,
None
)),
MockEmptyCaptureSnapshot
(),
ctx
),
trace_type
.
WeakrefDeletionObserver
(),
"a"
)
cache
.
add
(
function_cache
.
FunctionCacheKey
(
MockShape
(
1
,
2
,
3
),
MockEmptyCaptureSnapshot
(),
ctx
),
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockShape
(
1
,
2
,
3
)),
MockEmptyCaptureSnapshot
(),
ctx
),
trace_type
.
WeakrefDeletionObserver
(),
"b"
)
self
.
assertEqual
(
cache
.
lookup
(
function_cache
.
FunctionCacheKey
(
MockShape
(
1
,
2
,
3
),
MockEmptyCaptureSnapshot
(),
ctx
),
True
),
"b"
)
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockShape
(
1
,
2
,
3
)),
MockEmptyCaptureSnapshot
(),
ctx
),
True
),
"b"
)
def
testFirstMostSpecificFunctionCacheKeyIsLookedUp
(
self
):
ctx
=
function_cache
.
FunctionContext
(
0
)
cache
=
function_cache
.
FunctionCache
()
cache
.
add
(
function_cache
.
FunctionCacheKey
(
MockShape
(
1
,
2
,
None
),
MockEmptyCaptureSnapshot
(),
ctx
),
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockShape
(
1
,
2
,
None
)),
MockEmptyCaptureSnapshot
(),
ctx
),
trace_type
.
WeakrefDeletionObserver
(),
"a"
)
cache
.
add
(
function_cache
.
FunctionCacheKey
(
MockShape
(
1
,
None
,
3
),
MockEmptyCaptureSnapshot
(),
ctx
),
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockShape
(
1
,
None
,
3
)),
MockEmptyCaptureSnapshot
(),
ctx
),
trace_type
.
WeakrefDeletionObserver
(),
"b"
)
self
.
assertEqual
(
cache
.
lookup
(
function_cache
.
FunctionCacheKey
(
MockShape
(
1
,
2
,
3
),
MockEmptyCaptureSnapshot
(),
ctx
),
True
),
"a"
)
make_single_param_type
(
MockShape
(
1
,
2
,
3
)
),
MockEmptyCaptureSnapshot
(),
ctx
),
True
),
"a"
)
def
testMostSpecificFunctionCacheKeyIsOrderAgnostic
(
self
):
ctx
=
function_cache
.
FunctionContext
(
0
)
keys
=
[(
function_cache
.
FunctionCacheKey
(
MockShape
(
1
,
1
,
1
),
MockEmptyCaptureSnapshot
(),
ctx
),
"a"
),
(
function_cache
.
FunctionCacheKey
(
MockShape
(
1
,
None
,
1
),
MockEmptyCaptureSnapshot
(
),
ctx
),
"b"
),
(
function_cache
.
FunctionCacheKey
(
MockShape
(
None
,
None
,
1
),
MockEmptyCaptureSnapshot
(
),
ctx
),
"c"
),
(
function_cache
.
FunctionCacheKey
(
MockShape
(
None
,
None
,
None
),
MockEmptyCaptureSnapshot
(
),
ctx
),
"d"
)]
keys
=
[(
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockShape
(
1
,
1
,
1
)),
MockEmptyCaptureSnapshot
(),
ctx
),
"a"
),
(
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockShape
(
1
,
None
,
1
)
),
MockEmptyCaptureSnapshot
(),
ctx
),
"b"
),
(
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockShape
(
None
,
None
,
1
)
),
MockEmptyCaptureSnapshot
(),
ctx
),
"c"
),
(
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockShape
(
None
,
None
,
None
)
),
MockEmptyCaptureSnapshot
(),
ctx
),
"d"
)]
for
permutation
in
itertools
.
permutations
(
keys
):
cache
=
function_cache
.
FunctionCache
()
...
...
@@ -277,28 +295,24 @@ class FunctionCacheTest(test.TestCase):
self
.
assertEqual
(
cache
.
lookup
(
function_cache
.
FunctionCacheKey
(
MockShape
(
1
,
1
,
1
),
MockEmptyCaptureSnapshot
(),
ctx
),
True
),
"a"
)
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockShape
(
1
,
1
,
1
)),
MockEmptyCaptureSnapshot
(),
ctx
),
True
),
"a"
)
self
.
assertEqual
(
cache
.
lookup
(
function_cache
.
FunctionCacheKey
(
MockShape
(
1
,
2
,
1
),
MockEmptyCaptureSnapshot
(),
ctx
),
True
),
"b"
)
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockShape
(
1
,
2
,
1
)),
MockEmptyCaptureSnapshot
(),
ctx
),
True
),
"b"
)
self
.
assertEqual
(
cache
.
lookup
(
function_cache
.
FunctionCacheKey
(
MockShape
(
2
,
2
,
1
),
MockEmptyCaptureSnapshot
(),
ctx
),
True
),
"c"
)
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockShape
(
2
,
2
,
1
)),
MockEmptyCaptureSnapshot
(),
ctx
),
True
),
"c"
)
self
.
assertEqual
(
cache
.
lookup
(
function_cache
.
FunctionCacheKey
(
MockShape
(
2
,
2
,
2
),
MockEmptyCaptureSnapshot
(),
ctx
),
True
),
"d"
)
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockShape
(
2
,
2
,
2
)),
MockEmptyCaptureSnapshot
(),
ctx
),
True
),
"d"
)
def
testWeakRefDeletionAlsoDeletesConcreteFunction
(
self
):
if
not
function_cache
.
DELETE_WITH_WEAKREF
:
...
...
@@ -447,17 +461,19 @@ class FunctionCacheBenchmark(test.Benchmark):
for
key
in
keys
:
cache
.
add
(
*
key
,
"testing"
)
cache
.
add
(
function_cache
.
FunctionCacheKey
(
MockSubtypeOf2
(
2
),
MockEmptyCaptureSnapshot
(),
None
),
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockSubtypeOf2
(
2
)),
MockEmptyCaptureSnapshot
(),
None
),
trace_type
.
WeakrefDeletionObserver
(),
"testing"
)
cache
.
lookup
(
function_cache
.
FunctionCacheKey
(
MockSubtypeOf2
(
3
),
MockEmptyCaptureSnapshot
(),
None
),
True
)
cache
.
lookup
(
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockSubtypeOf2
(
3
)),
MockEmptyCaptureSnapshot
(),
None
),
True
)
iterations
=
10000
lookup_key
=
function_cache
.
FunctionCacheKey
(
MockSubtypeOf2
(
2
),
MockEmptyCaptureSnapshot
(),
None
)
lookup_key
=
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockSubtypeOf2
(
2
)),
MockEmptyCaptureSnapshot
(),
None
)
subtyping_time
=
timeit
.
timeit
(
lambda
:
cache
.
lookup
(
lookup_key
,
True
),
number
=
iterations
)
...
...
@@ -490,14 +506,15 @@ class FunctionCacheBenchmark(test.Benchmark):
for
key
in
keys
:
cache
.
add
(
*
key
,
"testing"
)
cache
.
add
(
function_cache
.
FunctionCacheKey
(
MockSubtypeOf2
(
3
),
MockEmptyCaptureSnapshot
(),
None
),
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockSubtypeOf2
(
3
)),
MockEmptyCaptureSnapshot
(),
None
),
trace_type
.
WeakrefDeletionObserver
(),
"testing"
)
iterations
=
10000
lookup_key
=
function_cache
.
FunctionCacheKey
(
MockSubtypeOf2
(
2
),
MockEmptyCaptureSnapshot
(),
None
)
lookup_key
=
function_cache
.
FunctionCacheKey
(
make_single_param_type
(
MockSubtypeOf2
(
2
)),
MockEmptyCaptureSnapshot
(),
None
)
subtyping_time
=
sum
(
timeit
.
repeat
(
stmt
=
lambda
:
cache
.
lookup
(
lookup_key
,
True
),
...
...
@@ -539,9 +556,7 @@ class CaptureSnapshotTest(test.TestCase):
"b"
:
MockIntGenericType
(
1
),
"c"
:
MockIntGenericType
(
2
)
})
self
.
snapshot_e
=
snapshot_type
({
"d"
:
MockIntGenericType
(
1
)
})
self
.
snapshot_e
=
snapshot_type
({
"d"
:
MockIntGenericType
(
1
)})
def
testCaptureSnapshotSubtype
(
self
):
self
.
assertFalse
(
self
.
snapshot_a
.
is_subtype_of
(
self
.
snapshot_b
))
...
...
tensorflow/python/eager/polymorphic_function/function_context.py
浏览文件 @
b5cc8c77
...
...
@@ -17,6 +17,7 @@
from
typing
import
Any
,
NamedTuple
,
Tuple
from
tensorflow.core.function
import
trace_type
from
tensorflow.core.function.function_type
import
function_type
from
tensorflow.core.function.polymorphism
import
function_cache
from
tensorflow.python.eager
import
context
from
tensorflow.python.framework
import
device
as
pydev
...
...
@@ -135,7 +136,15 @@ def make_cache_key(
captures_signature
=
function_cache
.
CaptureSnapshot
(
captures_dict_tracetype
.
mapping
)
# TODO(fmuham): Use the actual FunctionType
dummy_function_type
=
function_type
.
FunctionType
([
function_type
.
Parameter
(
"args_kwargs"
,
function_type
.
Parameter
.
POSITIONAL_ONLY
,
False
,
args_signature
)
])
return
function_cache
.
FunctionCacheKey
(
args_signatur
e
,
dummy_function_typ
e
,
captures_signature
,
make_function_context
()),
signature_context
.
deletion_observer
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录