Commit 24593b1c
Authored on Oct 10, 2018 by A. Unique TensorFlower
Committed by TensorFlower Gardener on Oct 10, 2018
Adds `get_config` and `from_config` to Optimizers V2.
PiperOrigin-RevId: 216546565
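
The pattern is the same across all the optimizers below: `get_config` flattens the constructor arguments into a plain dictionary, and `from_config` rebuilds an equivalent instance (without any slot or iteration state) by passing that dictionary back to the constructor. A minimal sketch of the round trip, assuming the `adam` module from this commit is importable:

    from tensorflow.python.keras.optimizer_v2 import adam

    # Serialize the optimizer's constructor arguments to a plain dict.
    opt = adam.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
    config = opt.get_config()

    # Rebuild an equivalent optimizer; saved state (slots, iterations) is not
    # part of the config, only the constructor arguments.
    opt2 = adam.Adam.from_config(config)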
Parent: f0225119
Showing 13 changed files with 178 additions and 8 deletions.
tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py       +6  -4
tensorflow/python/keras/optimizer_v2/adadelta.py                   +10 -0
tensorflow/python/keras/optimizer_v2/adadelta_test.py              +17 -0
tensorflow/python/keras/optimizer_v2/adagrad.py                    +8  -0
tensorflow/python/keras/optimizer_v2/adagrad_test.py               +13 -0
tensorflow/python/keras/optimizer_v2/adam.py                       +10 -0
tensorflow/python/keras/optimizer_v2/adam_test.py                  +11 -0
tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py  +6  -4
tensorflow/python/keras/optimizer_v2/optimizer_v2.py               +36 -0
tensorflow/python/keras/optimizer_v2/rmsprop.py                    +11 -0
tensorflow/python/keras/optimizer_v2/rmsprop_test.py               +22 -0
tensorflow/python/keras/optimizer_v2/sgd.py                        +14 -0
tensorflow/python/keras/optimizer_v2/sgd_test.py                   +14 -0
tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -143,10 +143,12 @@ class CheckpointingTests(test.TestCase):
     suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
     expected_checkpoint_names = [
         name + suffix for name in expected_checkpoint_names]
-    # The Dense layers also save get_config() JSON
-    expected_checkpoint_names.extend(
-        ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
-         "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
+    # The optimizer and Dense layers also save get_config() JSON
+    expected_checkpoint_names.extend([
+        "optimizer/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+        "model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+        "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"
+    ])
     named_variables = {v.name: v for v in named_variables}
     six.assertCountEqual(self, expected_checkpoint_names,
                          named_variables.keys())
tensorflow/python/keras/optimizer_v2/adadelta.py
@@ -37,6 +37,7 @@ class Adadelta(optimizer_v2.OptimizerV2):
   Tensor or a Python value.

   Arguments:
     learning_rate: float hyperparameter >= 0. Learning rate. It is recommended
+      to leave it at the default value.
     rho: float hyperparameter >= 0. The decay rate.
@@ -114,3 +115,12 @@ class Adadelta(optimizer_v2.OptimizerV2):
         grad, indices, use_locking=self._use_locking)
+
+  def get_config(self):
+    config = super(Adadelta, self).get_config()
+    config.update({
+        "learning_rate": self._serialize_hyperparameter("learning_rate"),
+        "rho": self._serialize_hyperparameter("rho"),
+        "epsilon": self._serialize_hyperparameter("epsilon")
+    })
+    return config
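
For a concrete sense of the output, a hedged illustration (the exact dictionary is not shown in this commit; the `name` entry comes from the base-class `get_config`):

    opt = adadelta.Adadelta(learning_rate=1.0, rho=0.95, epsilon=1e-8)
    config = opt.get_config()
    # Expected to resemble something like:
    # {"name": "Adadelta", "learning_rate": 1.0, "rho": 0.95, "epsilon": 1e-08}
    # With float hyperparameters, _serialize_hyperparameter returns the raw
    # floats stored in _hyper; callables and tensors are returned as-is.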
tensorflow/python/keras/optimizer_v2/adadelta_test.py
@@ -22,6 +22,7 @@ import numpy as np
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.keras.optimizer_v2 import adadelta
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
@@ -161,6 +162,22 @@ class AdadeltaOptimizerTest(test.TestCase):
         self.assertAllCloseAccordingToType([[-111, -138]], var0.eval())
+
+  def testConfig(self):
+    def rho():
+      return ops.convert_to_tensor(1.0)
+
+    epsilon = ops.convert_to_tensor(1.0)
+    opt = adadelta.Adadelta(learning_rate=1.0, rho=rho, epsilon=epsilon)
+    config = opt.get_config()
+    opt2 = adadelta.Adadelta.from_config(config)
+    self.assertEqual(opt._hyper["learning_rate"][1],
+                     opt2._hyper["learning_rate"][1])
+    self.assertEqual(opt._hyper["rho"][1].__name__,
+                     opt2._hyper["rho"][1].__name__)
+    self.assertEqual(opt._hyper["epsilon"][1],
+                     opt2._hyper["epsilon"][1])

 if __name__ == "__main__":
   test.main()
tensorflow/python/keras/optimizer_v2/adagrad.py
@@ -117,3 +117,11 @@ class Adagrad(optimizer_v2.OptimizerV2):
         grad, indices, use_locking=self._use_locking)
+
+  def get_config(self):
+    config = super(Adagrad, self).get_config()
+    config.update({
+        "learning_rate": self._serialize_hyperparameter("learning_rate"),
+        "initial_accumulator_value": self._initial_accumulator_value
+    })
+    return config
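
Note that `initial_accumulator_value` goes into the config directly rather than through `_serialize_hyperparameter`: it is a plain attribute on the optimizer, not a registered `_hyper` entry. A short sketch of the round trip, mirroring the test below:

    opt = adagrad.Adagrad(learning_rate=0.1, initial_accumulator_value=2.0)
    opt2 = adagrad.Adagrad.from_config(opt.get_config())
    # The accumulator setting survives as an ordinary Python float.
    assert opt2._initial_accumulator_value == 2.0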
tensorflow/python/keras/optimizer_v2/adagrad_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import types as python_types
+
 import numpy as np

 from tensorflow.python.framework import constant_op
@@ -271,6 +273,17 @@ class AdagradOptimizerTest(test.TestCase):
       # Creating optimizer should cause no exception.
       adagrad.Adagrad(3.0, initial_accumulator_value=0.1)

+  def testConfig(self):
+    opt = adagrad.Adagrad(
+        learning_rate=lambda: ops.convert_to_tensor(1.0),
+        initial_accumulator_value=2.0)
+    config = opt.get_config()
+    opt2 = adagrad.Adagrad.from_config(config)
+    self.assertIsInstance(opt2._hyper["learning_rate"][1],
+                          python_types.LambdaType)
+    self.assertEqual(opt._initial_accumulator_value,
+                     opt2._initial_accumulator_value)

 if __name__ == "__main__":
   test.main()
tensorflow/python/keras/optimizer_v2/adam.py
@@ -201,3 +201,13 @@ class Adam(optimizer_v2.OptimizerV2):
       update_beta_2 = beta_2_power.assign(
           beta_2_power * state.get_hyper("beta_2"),
           use_locking=self._use_locking)
       return control_flow_ops.group(update_beta_1, update_beta_2)
+
+  def get_config(self):
+    config = super(Adam, self).get_config()
+    config.update({
+        "learning_rate": self._serialize_hyperparameter("learning_rate"),
+        "beta_1": self._serialize_hyperparameter("beta_1"),
+        "beta_2": self._serialize_hyperparameter("beta_2"),
+        "epsilon": self._serialize_hyperparameter("epsilon")
+    })
+    return config
tensorflow/python/keras/optimizer_v2/adam_test.py
@@ -329,5 +329,16 @@ class AdamOptimizerTest(test.TestCase):
       # for v1 and v2 respectively.
       self.assertEqual(6, len(set(opt.variables())))

+  def testConfig(self):
+    opt = adam.Adam(
+        learning_rate=1.0, beta_1=2.0, beta_2=3.0, epsilon=4.0)
+    config = opt.get_config()
+    opt2 = adam.Adam.from_config(config)
+    self.assertEqual(opt._hyper["learning_rate"][1],
+                     opt2._hyper["learning_rate"][1])
+    self.assertEqual(opt._hyper["beta_1"][1], opt2._hyper["beta_1"][1])
+    self.assertEqual(opt._hyper["beta_2"][1], opt2._hyper["beta_2"][1])
+    self.assertEqual(opt._hyper["epsilon"][1], opt2._hyper["epsilon"][1])

 if __name__ == "__main__":
   test.main()
tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py
@@ -143,10 +143,12 @@ class CheckpointingTests(test.TestCase):
     suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
     expected_checkpoint_names = [
         name + suffix for name in expected_checkpoint_names]
-    # The Dense layers also save get_config() JSON
-    expected_checkpoint_names.extend(
-        ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
-         "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
+    # The optimizer and Dense layers also save get_config() JSON
+    expected_checkpoint_names.extend([
+        "optimizer/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+        "model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+        "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"
+    ])
     named_variables = {v.name: v for v in named_variables}
     six.assertCountEqual(self, expected_checkpoint_names,
                          named_variables.keys())
tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -1319,6 +1319,42 @@ class OptimizerV2(optimizer_v1.Optimizer):
         variable=variable, optional_op_name=self._name)

+  def get_config(self):
+    """Returns the config of the optimizer.
+
+    An optimizer config is a Python dictionary (serializable)
+    containing the configuration of an optimizer.
+    The same optimizer can be reinstantiated later
+    (without any saved state) from this configuration.
+
+    Returns:
+        Python dictionary.
+    """
+    return {"name": self._name}
+
+  @classmethod
+  def from_config(cls, config, custom_objects=None):
+    """Creates an optimizer from its config.
+
+    This method is the reverse of `get_config`,
+    capable of instantiating the same optimizer from the config
+    dictionary.
+
+    Arguments:
+        config: A Python dictionary, typically the output of `get_config`.
+        custom_objects: A Python dictionary mapping names to additional Python
+          objects used to create this optimizer, such as a function used for a
+          hyperparameter.
+
+    Returns:
+        An optimizer instance.
+    """
+    return cls(**config)
+
+  def _serialize_hyperparameter(self, hyperparameter_name):
+    """Serialize a hyperparameter that can be a float, callable, or Tensor."""
+    return self._hyper[hyperparameter_name][1]
+
   # --------------
   # Unsupported parent methods
   # --------------
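
Two details worth noting here. First, `_serialize_hyperparameter` returns whatever was stored for the hyperparameter, so a float round-trips as a float, while a callable or Tensor round-trips as the very same object rather than its current value. Second, `from_config` in this base class is simply `cls(**config)`; `custom_objects` is accepted for API compatibility but not yet consulted. A minimal sketch under that reading, assuming the `sgd` module from this commit:

    def lr():
      # A callable hyperparameter; the optimizer evaluates it lazily.
      return 0.001

    opt = sgd.SGD(learning_rate=lr)
    config = opt.get_config()
    assert config["learning_rate"] is lr  # the callable itself, not its value
    opt2 = sgd.SGD.from_config(config)    # passes the callable back to __init__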
tensorflow/python/keras/optimizer_v2/rmsprop.py
@@ -237,3 +237,14 @@ class RMSProp(optimizer_v2.OptimizerV2):
         grad, indices, use_locking=self._use_locking)
+
+  def get_config(self):
+    config = super(RMSProp, self).get_config()
+    config.update({
+        "learning_rate": self._serialize_hyperparameter("learning_rate"),
+        "rho": self._serialize_hyperparameter("rho"),
+        "momentum": self._serialize_hyperparameter("momentum"),
+        "epsilon": self._serialize_hyperparameter("epsilon"),
+        "centered": self._centered
+    })
+    return config
tensorflow/python/keras/optimizer_v2/rmsprop_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 import copy
 import math
+import types as python_types

 from absl.testing import parameterized
 import numpy as np
@@ -439,6 +440,27 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
             (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)))
     ]), var1.eval())

+  def testConfig(self):
+    def momentum():
+      return ops.convert_to_tensor(3.0)
+
+    opt = rmsprop.RMSProp(
+        learning_rate=1.0,
+        rho=2.0,
+        momentum=momentum,
+        epsilon=lambda: ops.convert_to_tensor(4.0),
+        centered=True)
+    config = opt.get_config()
+    opt2 = rmsprop.RMSProp.from_config(config)
+    self.assertEqual(opt._hyper["learning_rate"][1],
+                     opt2._hyper["learning_rate"][1])
+    self.assertEqual(opt._hyper["rho"][1], opt2._hyper["rho"][1])
+    self.assertEqual(opt._hyper["momentum"][1].__name__,
+                     opt2._hyper["momentum"][1].__name__)
+    self.assertIsInstance(opt2._hyper["epsilon"][1],
+                          python_types.LambdaType)
+    self.assertEqual(True, opt2._centered)

 if __name__ == "__main__":
   test.main()
tensorflow/python/keras/optimizer_v2/sgd.py
@@ -168,3 +168,17 @@ class SGD(optimizer_v2.OptimizerV2):
         grad.values * state.get_hyper(
             "learning_rate", var.dtype.base_dtype),
         grad.indices, grad.dense_shape)
     return var.scatter_sub(delta, use_locking=self._use_locking)
+
+  def get_config(self):
+    config = super(SGD, self).get_config()
+    # Control whether momentum variables are created.
+    if not self._use_momentum:
+      momentum = None
+    else:
+      momentum = self._serialize_hyperparameter("momentum")
+    config.update({
+        "learning_rate": self._serialize_hyperparameter("learning_rate"),
+        "momentum": momentum,
+        "nesterov": self._use_nesterov
+    })
+    return config
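
The `None` check exists because `SGD(momentum=None)` means "create no momentum variables at all", and the config must reproduce that choice rather than try to serialize an unregistered hyperparameter. A hedged sketch of both paths, mirroring the test below:

    # With momentum: the hyperparameter is serialized normally.
    opt = sgd.SGD(learning_rate=1.0, momentum=0.9, nesterov=True)
    assert opt.get_config()["momentum"] == 0.9

    # Without momentum: the config records None, and the rebuilt optimizer
    # again skips creating momentum slots.
    opt2 = sgd.SGD.from_config(sgd.SGD(momentum=None).get_config())
    assert opt2._use_momentum is False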
tensorflow/python/keras/optimizer_v2/sgd_test.py
@@ -754,6 +754,20 @@ class MomentumOptimizerTest(test.TestCase):
             (0.9 * 0.01 + 0.01) * 2.0)
     ]), var1.eval())

+  def testConfig(self):
+    opt = sgd.SGD(learning_rate=1.0, momentum=2.0, nesterov=True)
+    config = opt.get_config()
+    opt2 = sgd.SGD.from_config(config)
+    self.assertEqual(opt._hyper["learning_rate"][1],
+                     opt2._hyper["learning_rate"][1])
+    self.assertEqual(opt._hyper["momentum"][1], opt2._hyper["momentum"][1])
+    self.assertEqual(opt2._use_nesterov, True)
+
+    opt = sgd.SGD(momentum=None)
+    config = opt.get_config()
+    opt2 = sgd.SGD.from_config(config)
+    self.assertEqual(False, opt2._use_momentum)

 if __name__ == "__main__":
   test.main()