Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
e1799db4
M
Models
项目概览
曾经的那一瞬间
/
Models
11 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
e1799db4
编写于
2月 27, 2020
作者:
Y
Yeqing Li
提交者:
A. Unique TensorFlower
2月 27, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
base_config.py in modeling/
PiperOrigin-RevId: 297757527
上级
e6a8db37
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
471 addition
and
11 deletion
+471
-11
official/modeling/hyperparams/base_config.py
official/modeling/hyperparams/base_config.py
+172
-11
official/modeling/hyperparams/base_config_test.py
official/modeling/hyperparams/base_config_test.py
+299
-0
未找到文件。
official/modeling/hyperparams/base_config.py
浏览文件 @
e1799db4
...
...
@@ -21,8 +21,10 @@ from __future__ import division
from
__future__
import
print_function
import
copy
from
typing
import
Any
,
List
,
Mapping
,
Optional
import
functools
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Type
from
absl
import
logging
import
dataclasses
import
tensorflow
as
tf
import
yaml
...
...
@@ -32,31 +34,190 @@ from official.modeling.hyperparams import params_dict
@
dataclasses
.
dataclass
class
Config
(
params_dict
.
ParamsDict
):
"""The base configuration class that supports YAML/JSON based overrides."""
default_params
:
dataclasses
.
InitVar
[
Mapping
[
str
,
Any
]]
=
None
restrictions
:
dataclasses
.
InitVar
[
List
[
str
]]
=
None
"""The base configuration class that supports YAML/JSON based overrides.
* It recursively enforces a whitelist of basic types and container types, so
it avoids surprises with copy and reuse caused by unanticipated types.
* It converts dict to Config even within sequences,
e.g. for config = Config({'key': [([{'a': 42}],)]),
type(config.key[0][0][0]) is Config rather than dict.
"""
# It's safe to add bytes and other immutable types here.
IMMUTABLE_TYPES
=
(
str
,
int
,
float
,
bool
,
type
(
None
))
# It's safe to add set, frozenset and other collections here.
SEQUENCE_TYPES
=
(
list
,
tuple
)
default_params
:
dataclasses
.
InitVar
[
Optional
[
Mapping
[
str
,
Any
]]]
=
None
restrictions
:
dataclasses
.
InitVar
[
Optional
[
List
[
str
]]]
=
None
@
classmethod
def
_isvalidsequence
(
cls
,
v
):
"""Check if the input values are valid sequences.
Args:
v: Input sequence.
Returns:
True if the sequence is valid. Valid sequence includes the sequence
type in cls.SEQUENCE_TYPES and element type is in cls.IMMUTABLE_TYPES or
is dict or ParamsDict.
"""
if
not
isinstance
(
v
,
cls
.
SEQUENCE_TYPES
):
return
False
return
(
all
(
isinstance
(
e
,
cls
.
IMMUTABLE_TYPES
)
for
e
in
v
)
or
all
(
isinstance
(
e
,
dict
)
for
e
in
v
)
or
all
(
isinstance
(
e
,
params_dict
.
ParamsDict
)
for
e
in
v
))
@
classmethod
def
_import_config
(
cls
,
v
,
subconfig_type
):
"""Returns v with dicts converted to Configs, recursively."""
if
not
issubclass
(
subconfig_type
,
params_dict
.
ParamsDict
):
raise
TypeError
(
'Subconfig_type should be subclass of ParamsDict, found %r'
,
subconfig_type
)
if
isinstance
(
v
,
cls
.
IMMUTABLE_TYPES
):
return
v
elif
isinstance
(
v
,
cls
.
SEQUENCE_TYPES
):
# Only support one layer of sequence.
if
not
cls
.
_isvalidsequence
(
v
):
raise
TypeError
(
'Invalid sequence: only supports single level {!r} of {!r} or '
'dict or ParamsDict found: {!r}'
.
format
(
cls
.
SEQUENCE_TYPES
,
cls
.
IMMUTABLE_TYPES
,
v
))
import_fn
=
functools
.
partial
(
cls
.
_import_config
,
subconfig_type
=
subconfig_type
)
return
type
(
v
)(
map
(
import_fn
,
v
))
elif
isinstance
(
v
,
params_dict
.
ParamsDict
):
# Deepcopy here is a temporary solution for preserving type in nested
# Config object.
return
copy
.
deepcopy
(
v
)
elif
isinstance
(
v
,
dict
):
return
subconfig_type
(
v
)
else
:
raise
TypeError
(
'Unknown type: %r'
%
type
(
v
))
@
classmethod
def
_export_config
(
cls
,
v
):
"""Returns v with Configs converted to dicts, recursively."""
if
isinstance
(
v
,
cls
.
IMMUTABLE_TYPES
):
return
v
elif
isinstance
(
v
,
cls
.
SEQUENCE_TYPES
):
return
type
(
v
)(
map
(
cls
.
_export_config
,
v
))
elif
isinstance
(
v
,
params_dict
.
ParamsDict
):
return
v
.
as_dict
()
elif
isinstance
(
v
,
dict
):
raise
TypeError
(
'dict value not supported in converting.'
)
else
:
raise
TypeError
(
'Unknown type: {!r}'
.
format
(
type
(
v
)))
@
classmethod
def
_get_subconfig_type
(
cls
,
k
)
->
Type
[
params_dict
.
ParamsDict
]:
"""Get element type by the field name.
Args:
k: the key/name of the field.
Returns:
Config as default. If a type annotation is found for `k`,
1) returns the type of the annotation if it is subtype of ParamsDict;
2) returns the element type if the annotation of `k` is List[SubType]
or Tuple[SubType].
"""
subconfig_type
=
Config
if
k
in
cls
.
__annotations__
:
# Directly Config subtype.
type_annotation
=
cls
.
__annotations__
[
k
]
if
(
isinstance
(
type_annotation
,
type
)
and
issubclass
(
type_annotation
,
Config
)):
subconfig_type
=
cls
.
__annotations__
[
k
]
else
:
# Check if the field is a sequence of subtypes.
field_type
=
getattr
(
type_annotation
,
'__origin__'
,
type
(
None
))
if
(
isinstance
(
field_type
,
type
)
and
issubclass
(
field_type
,
cls
.
SEQUENCE_TYPES
)):
element_type
=
getattr
(
type_annotation
,
'__args__'
,
[
type
(
None
)])[
0
]
subconfig_type
=
(
element_type
if
issubclass
(
element_type
,
params_dict
.
ParamsDict
)
else
subconfig_type
)
return
subconfig_type
def
__post_init__
(
self
,
default_params
,
restrictions
,
*
args
,
**
kwargs
):
logging
.
error
(
'DEBUG before init %r'
,
type
(
self
))
super
().
__init__
(
default_params
=
default_params
,
restrictions
=
restrictions
,
*
args
,
**
kwargs
)
logging
.
error
(
'DEBUG after init %r'
,
type
(
self
))
def
_set
(
self
,
k
,
v
):
"""Overrides same method in ParamsDict.
Also called by ParamsDict methods.
Args:
k: key to set.
v: value.
Raises:
RuntimeError
"""
subconfig_type
=
self
.
_get_subconfig_type
(
k
)
if
isinstance
(
v
,
dict
):
if
k
not
in
self
.
__dict__
:
self
.
__dict__
[
k
]
=
params_dict
.
ParamsDict
(
v
,
[]
)
self
.
__dict__
[
k
]
=
subconfig_type
(
v
)
else
:
self
.
__dict__
[
k
].
override
(
v
)
else
:
self
.
__dict__
[
k
]
=
copy
.
deepcopy
(
v
)
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
def
__setattr__
(
self
,
k
,
v
):
if
k
in
params_dict
.
ParamsDict
.
RESERVED_ATTR
:
# Set the essential private ParamsDict attributes.
self
.
__dict__
[
k
]
=
copy
.
deepcopy
(
v
)
else
:
self
.
_set
(
k
,
v
)
if
k
not
in
self
.
RESERVED_ATTR
:
if
getattr
(
self
,
'_locked'
,
False
):
raise
ValueError
(
'The Config has been locked. '
'No change is allowed.'
)
self
.
_set
(
k
,
v
)
def
_override
(
self
,
override_dict
,
is_strict
=
True
):
"""Overrides same method in ParamsDict.
Also called by ParamsDict methods.
Args:
override_dict: dictionary to write to .
is_strict: If True, not allows to add new keys.
Raises:
KeyError: overriding reserved keys or keys not exist (is_strict=True).
"""
for
k
,
v
in
sorted
(
override_dict
.
items
()):
if
k
in
self
.
RESERVED_ATTR
:
raise
KeyError
(
'The key {!r} is internally reserved. '
'Can not be overridden.'
.
format
(
k
))
if
k
not
in
self
.
__dict__
:
if
is_strict
:
raise
KeyError
(
'The key {!r} does not exist. '
'To extend the existing keys, use '
'`override` with `is_strict` = False.'
.
format
(
k
))
else
:
self
.
_set
(
k
,
v
)
else
:
if
isinstance
(
v
,
dict
):
self
.
__dict__
[
k
].
_override
(
v
,
is_strict
)
# pylint: disable=protected-access
elif
isinstance
(
v
,
params_dict
.
ParamsDict
):
self
.
__dict__
[
k
].
_override
(
v
.
as_dict
(),
is_strict
)
# pylint: disable=protected-access
else
:
self
.
_set
(
k
,
v
)
def
as_dict
(
self
):
"""Returns a dict representation of params_dict.ParamsDict.
For the nested params_dict.ParamsDict, a nested dict will be returned.
"""
return
{
k
:
self
.
_export_config
(
v
)
for
k
,
v
in
self
.
__dict__
.
items
()
if
k
not
in
self
.
RESERVED_ATTR
}
def
replace
(
self
,
**
kwargs
):
"""Like `override`, but returns a copy with the current config unchanged."""
...
...
official/modeling/hyperparams/base_config_test.py
0 → 100644
浏览文件 @
e1799db4
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
pprint
from
typing
import
List
,
Tuple
from
absl.testing
import
parameterized
import
dataclasses
import
tensorflow
as
tf
from
official.modeling.hyperparams
import
base_config
@
dataclasses
.
dataclass
class
DumpConfig1
(
base_config
.
Config
):
a
:
int
=
1
b
:
str
=
'text'
@
dataclasses
.
dataclass
class
DumpConfig2
(
base_config
.
Config
):
c
:
int
=
2
d
:
str
=
'text'
e
:
DumpConfig1
=
DumpConfig1
()
@
dataclasses
.
dataclass
class
DumpConfig3
(
DumpConfig2
):
f
:
int
=
2
g
:
str
=
'text'
h
:
List
[
DumpConfig1
]
=
dataclasses
.
field
(
default_factory
=
lambda
:
[
DumpConfig1
(),
DumpConfig1
()])
g
:
Tuple
[
DumpConfig1
,
...]
=
(
DumpConfig1
(),)
class
BaseConfigTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
def
assertHasSameTypes
(
self
,
c
,
d
,
msg
=
''
):
"""Checks if a Config has the same structure as a given dict.
Args:
c: the Config object to be check.
d: the reference dict object.
msg: The error message to show when type mismatched.
"""
# Make sure d is not a Config. Assume d is either
# dictionary or primitive type and c is the Config or primitive types.
self
.
assertNotIsInstance
(
d
,
base_config
.
Config
)
if
isinstance
(
d
,
base_config
.
Config
.
IMMUTABLE_TYPES
):
self
.
assertEqual
(
pprint
.
pformat
(
c
),
pprint
.
pformat
(
d
),
msg
=
msg
)
elif
isinstance
(
d
,
base_config
.
Config
.
SEQUENCE_TYPES
):
self
.
assertEqual
(
type
(
c
),
type
(
d
),
msg
=
msg
)
for
i
,
v
in
enumerate
(
d
):
self
.
assertHasSameTypes
(
c
[
i
],
v
,
msg
=
'{}[{!r}]'
.
format
(
msg
,
i
))
elif
isinstance
(
d
,
dict
):
self
.
assertIsInstance
(
c
,
base_config
.
Config
,
msg
=
msg
)
for
k
,
v
in
sorted
(
d
.
items
()):
self
.
assertHasSameTypes
(
getattr
(
c
,
k
),
v
,
msg
=
'{}[{!r}]'
.
format
(
msg
,
k
))
else
:
raise
TypeError
(
'Unknown type: %r'
%
type
(
d
))
def
assertImportExport
(
self
,
v
):
config
=
base_config
.
Config
({
'key'
:
v
})
back
=
config
.
as_dict
()[
'key'
]
self
.
assertEqual
(
pprint
.
pformat
(
back
),
pprint
.
pformat
(
v
))
self
.
assertHasSameTypes
(
config
.
key
,
v
,
msg
=
'=%s v'
%
pprint
.
pformat
(
v
))
def
test_invalid_keys
(
self
):
params
=
base_config
.
Config
()
with
self
.
assertRaises
(
AttributeError
):
_
=
params
.
a
def
test_nested_config_types
(
self
):
config
=
DumpConfig3
()
self
.
assertIsInstance
(
config
.
e
,
DumpConfig1
)
self
.
assertIsInstance
(
config
.
h
[
0
],
DumpConfig1
)
self
.
assertIsInstance
(
config
.
h
[
1
],
DumpConfig1
)
self
.
assertIsInstance
(
config
.
g
[
0
],
DumpConfig1
)
config
.
override
({
'e'
:
{
'a'
:
2
,
'b'
:
'new text'
}})
self
.
assertIsInstance
(
config
.
e
,
DumpConfig1
)
self
.
assertEqual
(
config
.
e
.
a
,
2
)
self
.
assertEqual
(
config
.
e
.
b
,
'new text'
)
config
.
override
({
'h'
:
[{
'a'
:
3
,
'b'
:
'new text 2'
}]})
self
.
assertIsInstance
(
config
.
h
[
0
],
DumpConfig1
)
self
.
assertLen
(
config
.
h
,
1
)
self
.
assertEqual
(
config
.
h
[
0
].
a
,
3
)
self
.
assertEqual
(
config
.
h
[
0
].
b
,
'new text 2'
)
config
.
override
({
'g'
:
[{
'a'
:
4
,
'b'
:
'new text 3'
}]})
self
.
assertIsInstance
(
config
.
g
[
0
],
DumpConfig1
)
self
.
assertLen
(
config
.
g
,
1
)
self
.
assertEqual
(
config
.
g
[
0
].
a
,
4
)
self
.
assertEqual
(
config
.
g
[
0
].
b
,
'new text 3'
)
@
parameterized
.
parameters
(
(
'_locked'
,
"The key '_locked' is internally reserved."
),
(
'_restrictions'
,
"The key '_restrictions' is internally reserved."
),
(
'aa'
,
"The key 'aa' does not exist."
),
)
def
test_key_error
(
self
,
key
,
msg
):
params
=
base_config
.
Config
()
with
self
.
assertRaisesRegex
(
KeyError
,
msg
):
params
.
override
({
key
:
True
})
@
parameterized
.
parameters
(
(
'str data'
,),
(
123
,),
(
1.23
,),
(
None
,),
([
'str'
,
1
,
2.3
,
None
],),
((
'str'
,
1
,
2.3
,
None
),),
)
def
test_import_export_immutable_types
(
self
,
v
):
self
.
assertImportExport
(
v
)
out
=
base_config
.
Config
({
'key'
:
v
})
self
.
assertEqual
(
pprint
.
pformat
(
v
),
pprint
.
pformat
(
out
.
key
))
def
test_override_is_strict_true
(
self
):
params
=
base_config
.
Config
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
'cc'
,
'c2'
:
20
}
})
params
.
override
({
'a'
:
2
,
'c'
:
{
'c1'
:
'ccc'
}},
is_strict
=
True
)
self
.
assertEqual
(
params
.
a
,
2
)
self
.
assertEqual
(
params
.
c
.
c1
,
'ccc'
)
with
self
.
assertRaises
(
KeyError
):
params
.
override
({
'd'
:
'ddd'
},
is_strict
=
True
)
with
self
.
assertRaises
(
KeyError
):
params
.
override
({
'c'
:
{
'c3'
:
30
}},
is_strict
=
True
)
config
=
base_config
.
Config
({
'key'
:
[{
'a'
:
42
}]})
config
.
override
({
'key'
:
[{
'b'
:
43
}]})
self
.
assertEqual
(
config
.
key
[
0
].
b
,
43
)
with
self
.
assertRaisesRegex
(
AttributeError
,
'The key `a` does not exist'
):
_
=
config
.
key
[
0
].
a
@
parameterized
.
parameters
(
(
lambda
x
:
x
,
'Unknown type'
),
(
object
(),
'Unknown type'
),
(
set
(),
'Unknown type'
),
(
frozenset
(),
'Unknown type'
),
)
def
test_import_unsupport_types
(
self
,
v
,
msg
):
with
self
.
assertRaisesRegex
(
TypeError
,
msg
):
_
=
base_config
.
Config
({
'key'
:
v
})
@
parameterized
.
parameters
(
({
'a'
:
[{
'b'
:
2
,
},
{
'c'
:
3
,
}]
},),
({
'c'
:
[{
'f'
:
1.1
,
},
{
'h'
:
[
1
,
2
],
}]
},),
(({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
,
}
},),),
)
def
test_import_export_nested_structure
(
self
,
d
):
self
.
assertImportExport
(
d
)
@
parameterized
.
parameters
(
([{
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
}],),
(({
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
},),),
)
def
test_import_export_nested_sequences
(
self
,
v
):
self
.
assertImportExport
(
v
)
@
parameterized
.
parameters
(
([([{}],)],),
([[
'str'
,
1
,
2.3
,
None
]],),
(((
'str'
,
1
,
2.3
,
None
),),),
([
(
'str'
,
1
,
2.3
,
None
),
],),
([
(
'str'
,
1
,
2.3
,
None
),
],),
([[{
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
}]],),
([[[{
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
}]]],),
((({
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
},),),),
(((({
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
},),),),),
([({
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
},)],),
(([{
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
}],),),
)
def
test_import_export_unsupport_sequence
(
self
,
v
):
with
self
.
assertRaisesRegex
(
TypeError
,
'Invalid sequence: only supports single level'
):
_
=
base_config
.
Config
({
'key'
:
v
})
def
test_construct_subtype
(
self
):
pass
def
test_import_config
(
self
):
params
=
base_config
.
Config
({
'a'
:
[{
'b'
:
2
},
{
'c'
:
{
'd'
:
3
}}]})
self
.
assertLen
(
params
.
a
,
2
)
self
.
assertEqual
(
params
.
a
[
0
].
b
,
2
)
self
.
assertEqual
(
type
(
params
.
a
[
0
]),
base_config
.
Config
)
self
.
assertEqual
(
pprint
.
pformat
(
params
.
a
[
0
].
b
),
'2'
)
self
.
assertEqual
(
type
(
params
.
a
[
1
]),
base_config
.
Config
)
self
.
assertEqual
(
type
(
params
.
a
[
1
].
c
),
base_config
.
Config
)
self
.
assertEqual
(
pprint
.
pformat
(
params
.
a
[
1
].
c
.
d
),
'3'
)
def
test_override
(
self
):
params
=
base_config
.
Config
({
'a'
:
[{
'b'
:
2
},
{
'c'
:
{
'd'
:
3
}}]})
params
.
override
({
'a'
:
[{
'b'
:
4
},
{
'c'
:
{
'd'
:
5
}}]},
is_strict
=
False
)
self
.
assertEqual
(
type
(
params
.
a
),
list
)
self
.
assertEqual
(
type
(
params
.
a
[
0
]),
base_config
.
Config
)
self
.
assertEqual
(
pprint
.
pformat
(
params
.
a
[
0
].
b
),
'4'
)
self
.
assertEqual
(
type
(
params
.
a
[
1
]),
base_config
.
Config
)
self
.
assertEqual
(
type
(
params
.
a
[
1
].
c
),
base_config
.
Config
)
self
.
assertEqual
(
pprint
.
pformat
(
params
.
a
[
1
].
c
.
d
),
'5'
)
@
parameterized
.
parameters
(
([{}],),
(({},),),
)
def
test_config_vs_params_dict
(
self
,
v
):
d
=
{
'key'
:
v
}
self
.
assertEqual
(
type
(
base_config
.
Config
(
d
).
key
[
0
]),
base_config
.
Config
)
self
.
assertEqual
(
type
(
base_config
.
params_dict
.
ParamsDict
(
d
).
key
[
0
]),
dict
)
def
test_ppformat
(
self
):
self
.
assertEqual
(
pprint
.
pformat
([
's'
,
1
,
1.0
,
True
,
None
,
{},
[],
(),
{
(
2
,):
(
3
,
[
4
],
{
6
:
7
,
}),
8
:
9
,
}
]),
"['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]"
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录