Commit 0bdd941c
Authored Dec 11, 2018 by Zhenyu Tan
Committed by TensorFlower Gardener, Dec 11, 2018
expose v2 api for optimizers and migrate away from keras v1 optimizers.
PiperOrigin-RevId: 225102983
Parent: dcd966ea
Showing 39 changed files with 942 additions and 336 deletions (+942 −336)
tensorflow/compiler/tf2xla/kernels/training_ops.cc  (+59 −0)
tensorflow/compiler/tf2xla/resource_operation_table.cc  (+1 −0)
tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py  (+3 −109)
tensorflow/contrib/tpu/python/tpu/keras_support.py  (+4 −0)
tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py  (+1 −0)
tensorflow/python/keras/engine/training.py  (+4 −2)
tensorflow/python/keras/optimizer_v2/adadelta.py  (+3 −1)
tensorflow/python/keras/optimizer_v2/adadelta_test.py  (+13 −2)
tensorflow/python/keras/optimizer_v2/adagrad.py  (+3 −1)
tensorflow/python/keras/optimizer_v2/adagrad_test.py  (+17 −6)
tensorflow/python/keras/optimizer_v2/adam.py  (+8 −6)
tensorflow/python/keras/optimizer_v2/adam_test.py  (+10 −2)
tensorflow/python/keras/optimizer_v2/adamax.py  (+2 −0)
tensorflow/python/keras/optimizer_v2/adamax_test.py  (+10 −2)
tensorflow/python/keras/optimizer_v2/ftrl.py  (+2 −0)
tensorflow/python/keras/optimizer_v2/ftrl_test.py  (+5 −2)
tensorflow/python/keras/optimizer_v2/gradient_descent.py  (+9 −7)
tensorflow/python/keras/optimizer_v2/gradient_descent_test.py  (+26 −8)
tensorflow/python/keras/optimizer_v2/nadam.py  (+3 −0)
tensorflow/python/keras/optimizer_v2/nadam_test.py  (+12 −0)
tensorflow/python/keras/optimizer_v2/optimizer_v2.py  (+178 −75)
tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py  (+24 −42)
tensorflow/python/keras/optimizer_v2/rmsprop.py  (+8 −6)
tensorflow/python/keras/optimizer_v2/rmsprop_test.py  (+19 −4)
tensorflow/python/keras/optimizers.py  (+0 −7)
tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt  (+37 −4)
tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt  (+37 −4)
tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt  (+37 −4)
tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt  (+38 −4)
tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt  (+36 −3)
tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt  (+37 −4)
tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt  (+37 −4)
tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt  (+37 −4)
tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt  (+37 −4)
tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt  (+37 −4)
tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt  (+38 −4)
tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt  (+36 −3)
tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt  (+37 −4)
tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt  (+37 −4)
tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -172,6 +172,65 @@ class ResourceApplyMomentum : public XlaOpKernel {
 REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes),
                 ResourceApplyMomentum);

+class ResourceApplyKerasMomentum : public XlaOpKernel {
+ public:
+  explicit ResourceApplyKerasMomentum(OpKernelConstruction* ctx)
+      : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    DataType type = ctx->input_type(2);
+
+    TensorShape var_shape, accum_shape;
+    xla::XlaOp var, accum;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
+
+    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+                errors::InvalidArgument(
+                    "var and accum do not have the same shape",
+                    var_shape.DebugString(), " ", accum_shape.DebugString()));
+
+    TensorShape lr_shape = ctx->InputShape(2);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr_shape.DebugString()));
+
+    TensorShape grad_shape = ctx->InputShape(3);
+    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+                errors::InvalidArgument(
+                    "var and grad do not have the same shape",
+                    var_shape.DebugString(), " ", grad_shape.DebugString()));
+
+    TensorShape momentum_shape = ctx->InputShape(4);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
+                errors::InvalidArgument("momentum is not a scalar: ",
+                                        momentum_shape.DebugString()));
+
+    xla::XlaOp lr = ctx->Input(2);
+    xla::XlaOp grad = ctx->Input(3);
+    xla::XlaOp momentum = ctx->Input(4);
+
+    accum = accum * momentum - grad * lr;
+    if (use_nesterov_) {
+      // See https://github.com/tensorflow/tensorflow/pull/2798 for an
+      // explanation of the reparameterization used here.
+      var = var + accum * momentum - grad * lr;
+    } else {
+      var = var + accum;
+    }
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
+  }
+
+ private:
+  bool use_nesterov_;
+};
+REGISTER_XLA_OP(
+    Name("ResourceApplyKerasMomentum").TypeConstraint("T", kFloatTypes),
+    ResourceApplyKerasMomentum);
+
 class ResourceApplyAdagrad : public XlaOpKernel {
  public:
   explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
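For the update rule itself, stripped of the XLA plumbing, here is a NumPy sketch (my illustration, not part of the commit) of exactly what the Compile body above emits:

import numpy as np

def keras_momentum_step(var, accum, grad, lr, momentum, use_nesterov):
    # Mirrors the kernel: accum <- accum * momentum - grad * lr.
    accum = accum * momentum - grad * lr
    if use_nesterov:
        # Same reparameterization as the kernel (see tensorflow PR #2798).
        var = var + accum * momentum - grad * lr
    else:
        var = var + accum
    return var, accum

var, accum = keras_momentum_step(
    np.array([2.0]), np.array([0.0]), np.array([1.0]),
    lr=0.1, momentum=0.9, use_nesterov=False)
# var == [1.9], accum == [-0.1]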
tensorflow/compiler/tf2xla/resource_operation_table.cc
@@ -65,6 +65,7 @@ CreateResourceOpInfoMap() {
   add("ResourceApplyFtrlV2", kReadWrite, kVariable);
   add("ResourceApplyGradientDescent", kReadWrite, kVariable);
   add("ResourceApplyMomentum", kReadWrite, kVariable);
+  add("ResourceApplyKerasMomentum", kReadWrite, kVariable);
   add("ResourceApplyPowerSign", kReadWrite, kVariable);
   add("ResourceApplyProximalAdagrad", kReadWrite, kVariable);
   add("ResourceApplyProximalGradientDescent", kReadWrite, kVariable);
tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py
@@ -18,24 +18,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import shutil
-import tempfile
-
 from absl.testing import parameterized
 import numpy as np
-import six

 from tensorflow.contrib.distribute.python import combinations
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
-from tensorflow.python.estimator import run_config
-from tensorflow.python.estimator import training
-from tensorflow.python.estimator.canned import dnn_linear_combined
-from tensorflow.python.estimator.canned import prediction_keys
-from tensorflow.python.estimator.export import export
-from tensorflow.python.estimator.inputs import numpy_io
-from tensorflow.python.feature_column import feature_column_lib as feature_column
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -44,103 +32,7 @@ from tensorflow.python.keras.optimizer_v2 import gradient_descent
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
-from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
-from tensorflow.python.summary.writer import writer_cache
-
-
-class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase):
-
-  def setUp(self):
-    self._model_dir = tempfile.mkdtemp()
-
-  def dataset_input_fn(self, x, y, batch_size):
-
-    def input_fn():
-      dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
-      dataset = dataset.repeat(1).batch(batch_size)
-      return dataset
-
-    return input_fn
-
-  @combinations.generate(
-      combinations.combine(
-          mode=['graph'],
-          distribution=[
-              combinations.one_device_strategy,
-              combinations.mirrored_strategy_with_gpu_and_cpu,
-              combinations.mirrored_strategy_with_two_gpus,
-              combinations.core_mirrored_strategy_with_gpu_and_cpu,
-              combinations.core_mirrored_strategy_with_two_gpus
-          ],
-          use_train_and_evaluate=[True, False]))
-  def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate):
-    label_dimension = 2
-    input_dimension = label_dimension
-    batch_size = 10
-    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
-    data = data.reshape(batch_size, label_dimension)
-    train_input_fn = self.dataset_input_fn(
-        x={'x': data},
-        y=data,
-        batch_size=batch_size // distribution.num_replicas_in_sync)
-    eval_input_fn = self.dataset_input_fn(
-        x={'x': data},
-        y=data,
-        batch_size=batch_size // distribution.num_replicas_in_sync)
-    predict_input_fn = numpy_io.numpy_input_fn(
-        x={'x': data}, batch_size=batch_size, shuffle=False)
-
-    linear_feature_columns = [
-        feature_column.numeric_column('x', shape=(input_dimension,))
-    ]
-    dnn_feature_columns = [
-        feature_column.numeric_column('x', shape=(input_dimension,))
-    ]
-    feature_columns = linear_feature_columns + dnn_feature_columns
-    session_config = config_pb2.ConfigProto(
-        log_device_placement=True, allow_soft_placement=True)
-    estimator = dnn_linear_combined.DNNLinearCombinedRegressor(
-        linear_feature_columns=linear_feature_columns,
-        dnn_hidden_units=(2, 2),
-        dnn_feature_columns=dnn_feature_columns,
-        label_dimension=label_dimension,
-        model_dir=self._model_dir,
-        dnn_optimizer=adam.Adam(0.001),
-        linear_optimizer=adam.Adam(0.001),
-        config=run_config.RunConfig(
-            train_distribute=distribution,
-            eval_distribute=distribution,
-            session_config=session_config))
-
-    num_steps = 2
-    if use_train_and_evaluate:
-      scores, _ = training.train_and_evaluate(
-          estimator, training.TrainSpec(train_input_fn, max_steps=num_steps),
-          training.EvalSpec(eval_input_fn))
-    else:
-      estimator.train(train_input_fn, steps=num_steps)
-      scores = estimator.evaluate(eval_input_fn)
-    self.assertIn('loss', six.iterkeys(scores))
-
-    predictions = np.array([
-        x[prediction_keys.PredictionKeys.PREDICTIONS]
-        for x in estimator.predict(predict_input_fn)
-    ])
-    self.assertAllEqual((batch_size, label_dimension), predictions.shape)
-
-    feature_spec = feature_column.make_parse_example_spec(feature_columns)
-    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
-        feature_spec)
-    export_dir = estimator.export_savedmodel(tempfile.mkdtemp(),
-                                             serving_input_receiver_fn)
-    self.assertTrue(gfile.Exists(export_dir))
-
-  def tearDown(self):
-    if self._model_dir:
-      writer_cache.FileWriterCache.clear()
-      shutil.rmtree(self._model_dir)


 def get_model():
@@ -162,7 +54,9 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase):
       var = variables.Variable(
           2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM)
       # grad for cpu is 1, grad for gpu is 2, avg grad is 1.5.
-      loss = math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var
+      def loss():
+        return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var
+
       optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2)
       train_op = optimizer.minimize(loss, var_list=[var])
       m = optimizer.get_slot(var, 'm')
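The loss rewrite in the final hunk is the recurring motif of this commit: the v2 minimize() accepts the loss only as a zero-argument callable, never as a tensor. A short usage sketch (values illustrative):

import tensorflow as tf

var = tf.Variable(2.0, name='var')

def loss():
  # Evaluated under a GradientTape inside minimize(); must be callable now.
  return 3.0 * var

opt = tf.keras.optimizers.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2)
train_op = opt.minimize(loss, var_list=[var])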
tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -2069,6 +2069,8 @@ class KerasTPUModel(models.Model):
       # tpu_model may not be compiled, e.g., loading weights and then predict.
       return
     for k, v in six.iteritems(cpu_optimizer_config):
+      if k == 'name':
+        continue
       opt_var = getattr(self._tpu_model.optimizer, k)
       if isinstance(opt_var, variables.Variable):
         logging.info('CPU -> TPU %s: %s {%s}', k, v, K.get_value(opt_var))
@@ -2097,6 +2099,8 @@ class KerasTPUModel(models.Model):
     self._cpu_model.set_weights(tpu_weights)
     for k, v in six.iteritems(tpu_optimizer_config):
       logging.info('TPU -> CPU %s: %s', k, v)
+      if k == 'name':
+        continue
       opt_var = getattr(self.cpu_optimizer, k)
       if isinstance(opt_var, variables.Variable):
         K.get_session().run(opt_var.assign(v))
tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -69,6 +69,7 @@ class ReplicatedVariable(object):
   def __init__(self, name, variables):
     self._name = name
     self._primary_var = variables[0]
     self._common_name = self._primary_var.name.split(":")[0]
     self._vars = variables
+    self._cached_value = None
     self._dtype = variables[0].dtype
tensorflow/python/keras/engine/training.py
@@ -40,6 +40,7 @@ from tensorflow.python.keras.engine import training_eager
 from tensorflow.python.keras.engine import training_generator
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.engine.network import Network
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.keras.utils.generic_utils import slice_arrays
 from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
@@ -195,8 +196,9 @@ class Model(Network):
     # Validate that arguments passed by the user to `compile` are supported by
     # DistributionStrategy.
     if distribute:
-      if not isinstance(optimizer,
-                        (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
+      if not isinstance(optimizer,
+                        (tf_optimizer_module.Optimizer, optimizers.TFOptimizer,
+                         optimizer_v2.OptimizerV2)):
         raise NotImplementedError('optimizer must be an instance of '
                                   'tf.train.Optimizer, not a %s' % type(optimizer))
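With the isinstance check widened, Model.compile under DistributionStrategy accepts a v2 Keras optimizer in addition to tf.train.Optimizer instances. A hedged sketch (the distribute= keyword this hunk guards is the contrib-era argument and is omitted here):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(2,))])
# A v2 optimizer now passes the compile-time validation this hunk relaxes.
model.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss='mse')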
tensorflow/python/keras/optimizer_v2/adadelta.py
@@ -22,8 +22,10 @@ import numpy as np
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export


+@tf_export('keras.optimizers.Adadelta')
 class Adadelta(optimizer_v2.OptimizerV2):
   r"""Optimizer that implements the Adadelta algorithm.
@@ -85,7 +87,7 @@ class Adadelta(optimizer_v2.OptimizerV2):
     @end_compatibility
     """
     super(Adadelta, self).__init__(name, **kwargs)
-    self._set_hyper('learning_rate', learning_rate)
+    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
     self._set_hyper('decay', self._initial_decay)
     self._set_hyper('rho', rho)
     self._set_hyper('epsilon', epsilon)
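The one-line constructor change gives each v2 optimizer the legacy Keras lr alias; when both names are passed, lr wins. A sketch of the resulting behavior, matching the tests added in the next file:

import tensorflow as tf

opt = tf.keras.optimizers.Adadelta(lr=1.0, rho=0.9, epsilon=1.)
opt_2 = tf.keras.optimizers.Adadelta(learning_rate=0.1, rho=0.9, epsilon=1., lr=1.0)
print(opt.lr, opt_2.lr)  # 1.0 1.0 -- the legacy alias takes precedence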
tensorflow/python/keras/optimizer_v2/adadelta_test.py
@@ -153,8 +153,11 @@ class AdadeltaOptimizerTest(test.TestCase):
     with self.cached_session():
       var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
       x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
-      pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
-      loss = pred * pred
+
+      def loss():
+        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
+        return pred * pred
+
       sgd_op = adadelta.Adadelta(1.0, 1.0, 1.0).minimize(loss, var_list=[var0])
       variables.global_variables_initializer().run()
@@ -165,6 +168,14 @@ class AdadeltaOptimizerTest(test.TestCase):
       # Validate updated params
       self.assertAllCloseAccordingToType([[-111, -138]], self.evaluate(var0))

+  def testConstructAdadeltaWithLR(self):
+    opt = adadelta.Adadelta(lr=1.0, rho=0.9, epsilon=1.)
+    self.assertEqual(opt.lr, 1.0)
+    opt_2 = adadelta.Adadelta(learning_rate=0.1, rho=0.9, epsilon=1., lr=1.0)
+    self.assertEqual(opt_2.lr, 1.0)
+    opt_3 = adadelta.Adadelta(learning_rate=0.1, rho=0.9, epsilon=1.)
+    self.assertEqual(opt_3.lr, 0.1)
+

 if __name__ == "__main__":
   test.main()
tensorflow/python/keras/optimizer_v2/adagrad.py
@@ -27,8 +27,10 @@ from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
+from tensorflow.python.util.tf_export import tf_export


+@tf_export('keras.optimizers.Adagrad')
 class Adagrad(optimizer_v2.OptimizerV2):
   r"""Optimizer that implements the Adagrad algorithm.
@@ -86,7 +88,7 @@ class Adagrad(optimizer_v2.OptimizerV2):
     if epsilon < 1e-7:
       raise ValueError('epsilon must be larger than 1e-7: %s' % epsilon)
     super(Adagrad, self).__init__(name, **kwargs)
-    self._set_hyper('learning_rate', learning_rate)
+    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
     self._set_hyper('decay', self._initial_decay)
     self._initial_accumulator_value = initial_accumulator_value
     self._set_hyper('epsilon', epsilon)
tensorflow/python/keras/optimizer_v2/adagrad_test.py
@@ -167,8 +167,11 @@ class AdagradOptimizerTest(test.TestCase):
       var0 = resource_variable_ops.ResourceVariable(
           [[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
       x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
-      pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
-      loss = pred * pred
+
+      def loss():
+        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
+        return pred * pred
+
       sgd_op = adagrad.Adagrad(1.0).minimize(loss, var_list=[var0])
       variables.global_variables_initializer().run()
       # Fetch params to validate initial values
@@ -297,12 +300,12 @@ class AdagradOptimizerTest(test.TestCase):
     with self.cached_session():
       var_repeated = resource_variable_ops.ResourceVariable(
          [1.0, 2.0], dtype=dtype)
-      loss_repeated = math_ops.reduce_sum(
-          embedding_ops.embedding_lookup(var_repeated, [0, 0]))
+      loss_repeated = lambda: math_ops.reduce_sum(  # pylint: disable=g-long-lambda
+          embedding_ops.embedding_lookup(var_repeated, [0, 0]))  # pylint: disable=cell-var-from-loop
       var_aggregated = resource_variable_ops.ResourceVariable(
          [1.0, 2.0], dtype=dtype)
-      loss_aggregated = 2 * math_ops.reduce_sum(
-          embedding_ops.embedding_lookup(var_aggregated, [0]))
+      loss_aggregated = lambda: 2 * math_ops.reduce_sum(  # pylint: disable=g-long-lambda
+          embedding_ops.embedding_lookup(var_aggregated, [0]))  # pylint: disable=cell-var-from-loop
       update_op_repeated = adagrad.Adagrad(2.0).minimize(
           loss_repeated, var_list=[var_repeated])
       update_op_aggregated = adagrad.Adagrad(2.0).minimize(
@@ -395,6 +398,14 @@ class AdagradOptimizerTest(test.TestCase):
       self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
       self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))

+  def testConstructAdagradWithLR(self):
+    opt = adagrad.Adagrad(lr=1.0)
+    self.assertEqual(opt.lr, 1.0)
+    opt_2 = adagrad.Adagrad(learning_rate=0.1, lr=1.0)
+    self.assertEqual(opt_2.lr, 1.0)
+    opt_3 = adagrad.Adagrad(learning_rate=0.1)
+    self.assertEqual(opt_3.lr, 0.1)
+

 if __name__ == "__main__":
   test.main()
tensorflow/python/keras/optimizer_v2/adam.py
@@ -24,8 +24,10 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export


+@tf_export('keras.optimizers.Adam')
 class Adam(optimizer_v2.OptimizerV2):
   """Optimizer that implements the Adam algorithm.
@@ -127,12 +129,12 @@ class Adam(optimizer_v2.OptimizerV2):
     """
     super(Adam, self).__init__(name, **kwargs)
-    self._set_hyper('learning_rate', learning_rate)
+    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
     self._set_hyper('decay', self._initial_decay)
     self._set_hyper('beta_1', beta_1)
     self._set_hyper('beta_2', beta_2)
     self._set_hyper('epsilon', epsilon)
-    self._amsgrad = amsgrad
+    self.amsgrad = amsgrad

   def _create_slots(self, var_list):
     # Create slots for the first and second moments.
@@ -141,7 +143,7 @@ class Adam(optimizer_v2.OptimizerV2):
       self.add_slot(var, 'm')
     for var in var_list:
       self.add_slot(var, 'v')
-    if self._amsgrad:
+    if self.amsgrad:
       for var in var_list:
         self.add_slot(var, 'vhat')
@@ -166,7 +168,7 @@ class Adam(optimizer_v2.OptimizerV2):
     local_step = math_ops.cast(self.iterations + 1, var_dtype)
     beta_1_power = math_ops.pow(beta_1_t, local_step)
     beta_2_power = math_ops.pow(beta_2_t, local_step)
-    if not self._amsgrad:
+    if not self.amsgrad:
       return training_ops.resource_apply_adam(
           var.handle,
           m.handle,
@@ -220,7 +222,7 @@ class Adam(optimizer_v2.OptimizerV2):
     with ops.control_dependencies([v_t]):
       v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
-    if not self._amsgrad:
+    if not self.amsgrad:
       v_sqrt = math_ops.sqrt(v_t)
       var_update = state_ops.assign_sub(
           var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
@@ -251,6 +253,6 @@ class Adam(optimizer_v2.OptimizerV2):
         'beta_1': self._serialize_hyperparameter('beta_1'),
         'beta_2': self._serialize_hyperparameter('beta_2'),
         'epsilon': self._serialize_hyperparameter('epsilon'),
-        'amsgrad': self._amsgrad,
+        'amsgrad': self.amsgrad,
     })
     return config
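Renaming _amsgrad to amsgrad makes the flag part of the optimizer's public, serialized state rather than a private field. A small sketch:

import tensorflow as tf

opt = tf.keras.optimizers.Adam(learning_rate=0.001, amsgrad=True)
print(opt.amsgrad)                  # True -- readable without the underscore
print(opt.get_config()['amsgrad'])  # True -- round-trips through get_config()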
tensorflow/python/keras/optimizer_v2/adam_test.py
@@ -162,9 +162,9 @@ class AdamOptimizerTest(test.TestCase):
       # it (i.e. they have GPU kernels).
       var = variables.Variable([[1.0], [2.0]])
       indices = constant_op.constant([0, 1], dtype=index_dtype)
-      gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
+      g_sum = lambda: math_ops.reduce_sum(array_ops.gather(var, indices))  # pylint: disable=cell-var-from-loop
       optimizer = adam.Adam(3.0)
-      minimize_op = optimizer.minimize(gathered_sum, var_list=[var])
+      minimize_op = optimizer.minimize(g_sum, var_list=[var])
       variables.global_variables_initializer().run()
       minimize_op.run()
@@ -503,6 +503,14 @@ class AdamOptimizerTest(test.TestCase):
     self.assertEqual(
         self.evaluate(keras_v1_iteration), self.evaluate(keras_v2_iteration))

+  def testConstructAdamWithLR(self):
+    opt = adam.Adam(lr=1.0)
+    self.assertEqual(opt.lr, 1.0)
+    opt_2 = adam.Adam(learning_rate=0.1, lr=1.0)
+    self.assertEqual(opt_2.lr, 1.0)
+    opt_3 = adam.Adam(learning_rate=0.1)
+    self.assertEqual(opt_3.lr, 0.1)
+

 if __name__ == "__main__":
   test.main()
tensorflow/python/keras/optimizer_v2/adamax.py
@@ -25,8 +25,10 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export


+@tf_export('keras.optimizers.Adamax')
 class Adamax(adam.Adam):
   """Optimizer that implements the Adamax algorithm.
tensorflow/python/keras/optimizer_v2/adamax_test.py
@@ -136,9 +136,9 @@ class AdamaxOptimizerTest(test.TestCase):
       # it (i.e. they have GPU kernels).
       var = variables.Variable([[1.0], [2.0]])
       indices = constant_op.constant([0, 1], dtype=index_dtype)
-      gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
+      g_sum = lambda: math_ops.reduce_sum(array_ops.gather(var, indices))  # pylint: disable=cell-var-from-loop
       optimizer = adamax.Adamax(3.0)
-      minimize_op = optimizer.minimize(gathered_sum, var_list=[var])
+      minimize_op = optimizer.minimize(g_sum, var_list=[var])
       variables.global_variables_initializer().run()
       minimize_op.run()
@@ -362,6 +362,14 @@ class AdamaxOptimizerTest(test.TestCase):
     # There should be iteration, and two unique slot variables for v1 and v2.
     self.assertEqual(5, len(set(opt.variables())))

+  def testConstructAdamaxWithLR(self):
+    opt = adamax.Adamax(lr=1.0)
+    self.assertEqual(opt.lr, 1.0)
+    opt_2 = adamax.Adamax(learning_rate=0.1, lr=1.0)
+    self.assertEqual(opt_2.lr, 1.0)
+    opt_3 = adamax.Adamax(learning_rate=0.1)
+    self.assertEqual(opt_3.lr, 0.1)
+

 if __name__ == "__main__":
   test.main()
tensorflow/python/keras/optimizer_v2/ftrl.py
@@ -21,8 +21,10 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export


+@tf_export('keras.optimizers.Ftrl')
 class Ftrl(optimizer_v2.OptimizerV2):
   """Optimizer that implements the FTRL algorithm.
tensorflow/python/keras/optimizer_v2/ftrl_test.py
@@ -113,8 +113,11 @@ class FtrlOptimizerTest(test.TestCase):
     with self.cached_session():
       var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
       x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
-      pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
-      loss = pred * pred
+
+      def loss():
+        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
+        return pred * pred
+
       sgd_op = ftrl.Ftrl(1.0).minimize(loss, var_list=[var0])
       variables.global_variables_initializer().run()
       # Fetch params to validate initial values
tensorflow/python/keras/optimizer_v2/gradient_descent.py
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,8 +21,10 @@ from tensorflow.python.framework import ops
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export


+@tf_export("keras.optimizers.SGD")
 class SGD(optimizer_v2.OptimizerV2):
   """Stochastic gradient descent and momentum optimizer.
@@ -32,7 +34,7 @@ class SGD(optimizer_v2.OptimizerV2):
   gradient is evaluated at theta(t).
   ```
-  or Computes (if `use_nesterov = False`):
+  or Computes (if `nesterov = False`):
   ```
   v(t+1) = momentum * v(t) - learning_rate * gradient
   theta(t+1) = theta(t) + v(t+1)
@@ -75,7 +77,7 @@ class SGD(optimizer_v2.OptimizerV2):
       **kwargs: keyword arguments. Allowed to be {`decay`}
     """
     super(SGD, self).__init__(name, **kwargs)
-    self._set_hyper("learning_rate", learning_rate)
+    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
     self._set_hyper("decay", self._initial_decay)
     self._momentum = False
@@ -85,7 +87,7 @@ class SGD(optimizer_v2.OptimizerV2):
       raise ValueError("`momentum` must be between [0, 1].")
     self._set_hyper("momentum", momentum)
-    self._nesterov = nesterov
+    self.nesterov = nesterov

   def _create_slots(self, var_list):
     if self._momentum:
@@ -104,7 +106,7 @@ class SGD(optimizer_v2.OptimizerV2):
           grad,
           self._get_hyper("momentum", var_dtype),
           use_locking=self._use_locking,
-          use_nesterov=self._nesterov)
+          use_nesterov=self.nesterov)
     else:
       return training_ops.resource_apply_gradient_descent(
           var.handle, lr_t, grad, use_locking=self._use_locking)
@@ -132,7 +134,7 @@ class SGD(optimizer_v2.OptimizerV2):
           indices,
           self._get_hyper("momentum", var_dtype),
           use_locking=self._use_locking,
-          use_nesterov=self._nesterov)
+          use_nesterov=self.nesterov)

   def get_config(self):
     config = super(SGD, self).get_config()
@@ -140,6 +142,6 @@ class SGD(optimizer_v2.OptimizerV2):
         "learning_rate": self._serialize_hyperparameter("learning_rate"),
         "decay": self._serialize_hyperparameter("decay"),
         "momentum": self._serialize_hyperparameter("momentum"),
-        "nesterov": self._nesterov,
+        "nesterov": self.nesterov,
     })
     return config
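The docstring equations above describe the two momentum modes; nesterov, like amsgrad in Adam, is now a public attribute. A sketch:

import tensorflow as tf

plain = tf.keras.optimizers.SGD(learning_rate=2.0, momentum=0.9)
nag = tf.keras.optimizers.SGD(learning_rate=2.0, momentum=0.9, nesterov=True)
print(plain.nesterov, nag.nesterov)  # False True
# A momentum outside [0, 1] raises ValueError("`momentum` must be between [0, 1].").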
tensorflow/python/keras/optimizer_v2/gradient_descent_test.py
@@ -122,8 +122,6 @@ class GradientDescentOptimizerTest(test.TestCase):
       var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
       x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
       loss = lambda: math_ops.matmul(var0, x) + var1  # pylint: disable=cell-var-from-loop
-      if not context.executing_eagerly():
-        loss = loss()
       sgd = gradient_descent.SGD(1.0)
       sgd_op = sgd.minimize(loss, [var0, var1])
       self.evaluate(variables.global_variables_initializer())
@@ -141,9 +139,12 @@ class GradientDescentOptimizerTest(test.TestCase):
       var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
       var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
       x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
-      pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
-      pred += var1
-      loss = pred * pred
+
+      def loss():
+        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
+        pred += var1  # pylint: disable=cell-var-from-loop
+        return pred * pred
+
       sgd_op = gradient_descent.SGD(1.0).minimize(loss, [var0, var1])
       self.evaluate(variables.global_variables_initializer())
       # Run 1 step of sgd
@@ -181,7 +182,8 @@ class GradientDescentOptimizerTest(test.TestCase):
       opt = gradient_descent.SGD(3.0)
       values = [1.0, 3.0]
       vars_ = [variables.Variable([v], dtype=dtype) for v in values]
-      grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_)
+      loss = lambda: vars_[0] + vars_[1]  # pylint: disable=cell-var-from-loop
+      grads_and_vars = opt._compute_gradients(loss, vars_)
       self.evaluate(variables.global_variables_initializer())
       for grad, _ in grads_and_vars:
         self.assertAllCloseAccordingToType([1.0], self.evaluate(grad))
@@ -259,6 +261,14 @@ class GradientDescentOptimizerTest(test.TestCase):
     # be an EagerTensor once again, not a graph Tensor.
     self.assertEqual(float(step()), -1.0)

+  def testConstructSGDWithLR(self):
+    opt = gradient_descent.SGD(lr=1.0)
+    self.assertEqual(opt.lr, 1.0)
+    opt_2 = gradient_descent.SGD(learning_rate=0.1, lr=1.0)
+    self.assertEqual(opt_2.lr, 1.0)
+    opt_3 = gradient_descent.SGD(learning_rate=0.1)
+    self.assertEqual(opt_3.lr, 0.1)
+

 class MomentumOptimizerTest(test.TestCase):
@@ -346,7 +356,7 @@ class MomentumOptimizerTest(test.TestCase):
         var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
         accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
         accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
-        loss = 5 * var0 * var0 + 3 * var1
+        loss = lambda: 5 * var0 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
         mom_op = gradient_descent.SGD(
             learning_rate=2.0, momentum=0.9, nesterov=True)
         opt_op = mom_op.minimize(loss, [var0, var1])
@@ -677,12 +687,20 @@ class MomentumOptimizerTest(test.TestCase):
         opt3._get_hyper("momentum"))
     # self.assertEqual(
     #     self.evaluate(opt._get_hyper("decay")), opt3._get_hyper("decay"))
-    self.assertTrue(opt3._nesterov)
+    self.assertTrue(opt3.nesterov)

   def testNesterovWithoutMomentum(self):
     with self.assertRaisesRegexp(ValueError, "must be between"):
       gradient_descent.SGD(learning_rate=1.0, momentum=2.0)

+  def testConstructMomentumWithLR(self):
+    opt = gradient_descent.SGD(lr=1.0, momentum=0.9)
+    self.assertEqual(opt.lr, 1.0)
+    opt_2 = gradient_descent.SGD(learning_rate=0.1, momentum=0.9, lr=1.0)
+    self.assertEqual(opt_2.lr, 1.0)
+    opt_3 = gradient_descent.SGD(learning_rate=0.1, momentum=0.9)
+    self.assertEqual(opt_3.lr, 0.1)
+

 if __name__ == "__main__":
   test.main()
tensorflow/python/keras/optimizer_v2/nadam.py
@@ -74,6 +74,9 @@ class Nadam(adam.Adam):
       **kwargs: keyword arguments. Allowed to be {`decay`}
     """
+    # Backwards compatiblity with keras NAdam optimizer.
+    if 'schedule_decay' in kwargs:
+      kwargs['decay'] = kwargs.pop('schedule_decay')
     # pylint: disable=useless-super-delegation
     super(Nadam, self).__init__(
         learning_rate=learning_rate,
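The remap happens before the base constructor runs, so the legacy Keras argument lands in the v2 decay hyperparameter. A sketch:

import tensorflow as tf

opt = tf.keras.optimizers.Nadam(schedule_decay=0.2)
print(opt.decay)  # 0.2 -- the legacy name was translated to `decay`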
tensorflow/python/keras/optimizer_v2/nadam_test.py
@@ -208,6 +208,18 @@ class NadamOptimizerTest(test.TestCase):
         self.assertAllCloseAccordingToType(var0_np, var0.eval())
         self.assertAllCloseAccordingToType(var1_np, var1.eval())

+  def testConstructNAdamWithLR(self):
+    opt = nadam.Nadam(lr=1.0)
+    self.assertEqual(opt.lr, 1.0)
+    opt_2 = nadam.Nadam(learning_rate=0.1, lr=1.0)
+    self.assertEqual(opt_2.lr, 1.0)
+    opt_3 = nadam.Nadam(learning_rate=0.1)
+    self.assertEqual(opt_3.lr, 0.1)
+
+  def testConstructNAdamWithScheduleDecay(self):
+    opt = nadam.Nadam(schedule_decay=0.2)
+    self.assertEqual(opt.decay, 0.2)
+

 if __name__ == "__main__":
   test.main()
tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -28,22 +28,45 @@ from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
 from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import gradients
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import optimizer as optimizer_v1
+from tensorflow.python.training.checkpointable import base as checkpointable
 from tensorflow.python.util import nest
+from tensorflow.python.util.tf_export import tf_export
+
+
+def _deduplicate_indexed_slices(values, indices):
+  """Sums `values` associated with any non-unique `indices`.
+
+  Args:
+    values: A `Tensor` with rank >= 1.
+    indices: A one-dimensional integer `Tensor`, indexing into the first
+      dimension of `values` (as in an IndexedSlices object).
+
+  Returns:
+    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
+    de-duplicated version of `indices` and `summed_values` contains the sum of
+    `values` slices associated with each unique index.
+  """
+  unique_indices, new_index_positions = array_ops.unique(indices)
+  summed_values = math_ops.unsorted_segment_sum(
+      values, new_index_positions,
+      array_ops.shape(unique_indices)[0])
+  return (summed_values, unique_indices)


 @six.add_metaclass(abc.ABCMeta)
-class OptimizerV2(optimizer_v1.Optimizer):
+@tf_export("keras.optimizers.Optimizer")
+class OptimizerV2(checkpointable.CheckpointableBase):
   """Updated base class for optimizers.

   This class defines the API to add Ops to train a model. You never use this
@@ -138,7 +161,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
       _create_vars.
     """
     self._use_locking = True
-    super(OptimizerV2, self).__init__(self._use_locking, name)
+    self._name = name
     self._hyper = {}
     # dict: {variable name : {slot name : variable}}
     self._slots = {}
@@ -148,16 +171,11 @@ class OptimizerV2(optimizer_v1.Optimizer):
     if decay < 0.:
       raise ValueError("decay cannot be less than 0: {}".format(decay))
     self._initial_decay = decay
+    self.__dict__.update(kwargs)

     self._prepared = False

-  def minimize(self, loss, var_list,
-               aggregation_method=None,
-               colocate_gradients_with_ops=False,
-               name=None, grad_loss=None):
+  def minimize(self, loss, var_list, grad_loss=None, name=None):
     """Add operations to minimize `loss` by updating `var_list`.

     This method simply combines calls `compute_gradients()` and
@@ -166,15 +184,11 @@ class OptimizerV2(optimizer_v1.Optimizer):
     of using this function.

     Args:
-      loss: A `Tensor` containing the value to minimize.
+      loss: A callable taking no arguments which returns the value to minimize.
       var_list: list or tuple of `Variable` objects to update to minimize
         `loss`.
-      aggregation_method: Specifies the method used to combine gradient terms.
-        Valid values are defined in the class `AggregationMethod`.
-      colocate_gradients_with_ops: If True, try colocating gradients with the
-        corresponding op.
-      name: Optional name for the returned operation.
       grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
+      name: Optional name for the returned operation.

     Returns:
       An Operation that updates the variables in `var_list`. If `global_step`
@@ -186,29 +200,16 @@ class OptimizerV2(optimizer_v1.Optimizer):
     @compatibility(eager)
     When eager execution is enabled, `loss` should be a Python function that
     takes no arguments and computes the value to be minimized. Minimization (and
-    gradient computation) is done with respect to the elements of `var_list` if
-    not None, else with respect to any trainable variables created during the
-    execution of the `loss` function. `gate_gradients`, `aggregation_method`,
-    `colocate_gradients_with_ops` and `grad_loss` are ignored when eager
-    execution is enabled.
+    gradient computation) is done with respect to the elements of `var_list`.
+    `grad_loss` is ignored when eager execution is enabled.
     @end_compatibility
     """
-    grads_and_vars = self.compute_gradients(
-        loss, var_list=var_list,
-        aggregation_method=aggregation_method,
-        colocate_gradients_with_ops=colocate_gradients_with_ops,
-        grad_loss=grad_loss)
+    grads_and_vars = self._compute_gradients(
+        loss, var_list=var_list, grad_loss=grad_loss)

     return self.apply_gradients(grads_and_vars, name=name)

-  def compute_gradients(self, loss, var_list,
-                        aggregation_method=None,
-                        colocate_gradients_with_ops=False,
-                        grad_loss=None, stop_gradients=None):
+  def _compute_gradients(self, loss, var_list, grad_loss=None):
     """Compute gradients of `loss` for the variables in `var_list`.

     This is the first part of `minimize()`. It returns a list
@@ -218,19 +219,11 @@ class OptimizerV2(optimizer_v1.Optimizer):
     given variable.

     Args:
-      loss: A Tensor containing the value to minimize or a callable taking no
-        arguments which returns the value to minimize. When eager execution is
-        enabled it must be a callable.
-      var_list: Optional list or tuple of `tf.Variable` to update to minimize
+      loss: A callable taking no arguments which returns the value to minimize.
+      var_list: List or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph under
        the key `GraphKeys.TRAINABLE_VARIABLES`.
-      aggregation_method: Specifies the method used to combine gradient terms.
-        Valid values are defined in the class `AggregationMethod`.
-      colocate_gradients_with_ops: If True, try colocating gradients with the
-        corresponding op.
       grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
-      stop_gradients: Optional. A Tensor or list of tensors not to differentiate
-        through.

     Returns:
       A list of (gradient, variable) pairs. Variable is always present, but
@@ -239,38 +232,22 @@ class OptimizerV2(optimizer_v1.Optimizer):
     Raises:
       TypeError: If `var_list` contains anything else than `Variable` objects.
       ValueError: If some arguments are invalid, or var_list is None.
-      RuntimeError: If called with eager execution enabled and `loss` is
-        not callable.
-
-    @compatibility(eager)
-    When eager execution is enabled, `aggregation_method`, and
-    `colocate_gradients_with_ops` are ignored.
-    @end_compatibility
     """
     var_list = nest.flatten(var_list)
     # TODO(josh11b): Test that we handle weight decay in a reasonable way.
-    if callable(loss):
-      with backprop.GradientTape() as tape:
-        tape.watch(var_list)
-        loss_value = loss()
-        loss_value = self._scale_loss(loss_value)
-      grads = tape.gradient(loss_value, var_list, grad_loss)
-    else:
-      if context.executing_eagerly():
-        raise RuntimeError("`loss` passed to Optimizer.compute_gradients "
-                           "should be a function when eager execution is "
-                           "enabled.")
-      loss = self._scale_loss(loss)
-      self._assert_valid_dtypes([loss])
-      if grad_loss is not None:
-        self._assert_valid_dtypes([grad_loss])
-      grads = gradients.gradients(
-          loss, var_list, grad_ys=grad_loss,
-          aggregation_method=aggregation_method,
-          colocate_gradients_with_ops=colocate_gradients_with_ops,
-          stop_gradients=stop_gradients)
+    with backprop.GradientTape() as tape:
+      tape.watch(var_list)
+      loss_value = loss()
+      loss_value = self._scale_loss(loss_value)
+    grads = tape.gradient(loss_value, var_list, grad_loss)
+
+    if hasattr(self, "clipnorm"):
+      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
+    if hasattr(self, "clipvalue"):
+      grads = [
+          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
+          for g in grads
+      ]

     grads_and_vars = list(zip(grads, var_list))
     self._assert_valid_dtypes([
@@ -289,6 +266,37 @@ class OptimizerV2(optimizer_v1.Optimizer):
       loss_value *= (1. / num_replicas)
     return loss_value

+  def get_gradients(self, loss, params):
+    """Returns gradients of `loss` with respect to `params`.
+
+    Arguments:
+      loss: Loss tensor.
+      params: List of variables.
+
+    Returns:
+      List of gradient tensors.
+
+    Raises:
+      ValueError: In case any gradient cannot be computed (e.g. if gradient
+        function not implemented).
+    """
+    loss = self._scale_loss(loss)
+    grads = gradients.gradients(loss, params)
+    if None in grads:
+      raise ValueError("An operation has `None` for gradient. "
+                       "Please make sure that all of your ops have a "
+                       "gradient defined (i.e. are differentiable). "
+                       "Common ops without gradient: "
+                       "K.argmax, K.round, K.eval.")
+    if hasattr(self, "clipnorm"):
+      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
+    if hasattr(self, "clipvalue"):
+      grads = [
+          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
+          for g in grads
+      ]
+    return grads
+
   def apply_gradients(self, grads_and_vars, name=None):
     """Apply gradients to variables.
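The rewritten _compute_gradients drops the graph-mode branch entirely: every loss is traced with a GradientTape, and clipnorm/clipvalue (absorbed from **kwargs by __init__'s self.__dict__.update(kwargs)) are applied uniformly afterwards. An end-to-end sketch mirroring the testGradClipValue test added later in this commit:

import tensorflow as tf

var = tf.Variable([1.0, 2.0])
loss = lambda: 3.0 * var  # raw gradient is [3.0, 3.0]

# clipvalue is picked up by hasattr(self, "clipvalue") in _compute_gradients.
opt = tf.keras.optimizers.SGD(learning_rate=1.0, clipvalue=1.0)
opt.minimize(loss, var_list=[var])
# Each gradient component is clipped to 1.0, so var becomes [0.0, 1.0].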
@@ -351,7 +359,13 @@ class OptimizerV2(optimizer_v1.Optimizer):
     return apply_updates

   def get_updates(self, loss, params):
-    return [self.minimize(loss, params)]
+    grads = self.get_gradients(loss, params)
+    grads_and_vars = list(zip(grads, params))
+    self._assert_valid_dtypes([
+        v for g, v in grads_and_vars
+        if g is not None and v.dtype != dtypes.resource
+    ])
+    return [self.apply_gradients(grads_and_vars)]

   def _set_hyper(self, name, value):
     """set hyper `name` to value. value can be callable, tensor, numeric."""
@@ -575,6 +589,95 @@ class OptimizerV2(optimizer_v1.Optimizer):
     return variable

+  def _assert_valid_dtypes(self, tensors):
+    """Asserts tensors are all valid types (see `_valid_dtypes`).
+
+    Args:
+      tensors: Tensors to check.
+
+    Raises:
+      ValueError: If any tensor is not a valid type.
+    """
+    valid_dtypes = self._valid_dtypes()
+    for t in tensors:
+      dtype = t.dtype.base_dtype
+      if dtype not in valid_dtypes:
+        raise ValueError("Invalid type %r for %s, expected: %s." %
+                         (dtype, t.name, [v for v in valid_dtypes]))
+
+  def _valid_dtypes(self):
+    """Valid types for loss, variables and gradients.
+
+    Subclasses should override to allow other float types.
+
+    Returns:
+      Valid types for loss, variables and gradients.
+    """
+    return set(
+        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])
+
+  def _call_if_callable(self, param):
+    """Call the function if param is callable."""
+    return param() if callable(param) else param
+
+  def _resource_apply_dense(self, grad, handle):
+    """Add ops to apply dense gradients to the variable `handle`.
+
+    Args:
+      grad: a `Tensor` representing the gradient.
+      handle: a `Tensor` of dtype `resource` which points to the variable to be
+        updated.
+
+    Returns:
+      An `Operation` which updates the value of the variable.
+    """
+    raise NotImplementedError()
+
+  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
+    """Add ops to apply sparse gradients to `handle`, with repeated indices.
+
+    Optimizers which override this method must deal with repeated indices. See
+    the docstring of `_apply_sparse_duplicate_indices` for details. By default
+    the correct behavior, to sum non-unique indices and their associated
+    gradients, is enforced by first pre-processing `grad` and `indices` and
+    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
+    with duplicate indices may instead override this method to avoid the
+    overhead of summing.
+
+    Args:
+      grad: a `Tensor` representing the gradient for the affected indices.
+      handle: a `Tensor` of dtype `resource` which points to the variable to be
+        updated.
+      indices: a `Tensor` of integral type representing the indices for which
+        the gradient is nonzero. Indices may be repeated.
+
+    Returns:
+      An `Operation` which updates the value of the variable.
+    """
+    summed_grad, unique_indices = _deduplicate_indexed_slices(
+        values=grad, indices=indices)
+    return self._resource_apply_sparse(summed_grad, handle, unique_indices)
+
+  def _resource_apply_sparse(self, grad, handle, indices):
+    """Add ops to apply sparse gradients to the variable `handle`.
+
+    Similar to `_apply_sparse`, the `indices` argument to this method has been
+    de-duplicated. Optimizers which deal correctly with non-unique indices may
+    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
+    overhead.
+
+    Args:
+      grad: a `Tensor` representing the gradient for the affected indices.
+      handle: a `Tensor` of dtype `resource` which points to the variable to be
+        updated.
+      indices: a `Tensor` of integral type representing the indices for which
+        the gradient is nonzero. Indices are unique.
+
+    Returns:
+      An `Operation` which updates the value of the variable.
+    """
+    raise NotImplementedError()
+

 def _filter_grads(grads_and_vars):
   """Filter out iterable with grad equal to None."""
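The _deduplicate_indexed_slices helper added at the top of this file is what makes _resource_apply_sparse_duplicate_indices safe by default: gradients at repeated indices are summed before the sparse apply. A small standalone sketch of the same transformation (my illustration):

import tensorflow as tf

values = tf.constant([[1.0], [2.0], [4.0]])
indices = tf.constant([0, 0, 3])

unique_indices, positions = tf.unique(indices)
summed = tf.math.unsorted_segment_sum(
    values, positions, tf.shape(unique_indices)[0])
# unique_indices == [0, 3]; summed == [[3.0], [4.0]]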
tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
@@ -46,7 +46,6 @@ from tensorflow.python.keras.optimizer_v2 import gradient_descent
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
-from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
@@ -64,8 +63,6 @@ class OptimizerTest(test.TestCase):
       var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
       var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
       loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
-      if not context.executing_eagerly():
-        loss = loss()
       sgd = gradient_descent.SGD(3.0)
       self.evaluate(variables.global_variables_initializer())
@@ -116,33 +113,6 @@ class OptimizerTest(test.TestCase):
       # var1 = [0., 1.] - 0.5 * [3, 3]
       self.assertAllClose([-1.5, -0.5], self.evaluate(var1))

-  @test_util.run_in_graph_and_eager_modes
-  def testAggregationMethod(self):
-    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
-        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
-        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
-        loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
-        if not context.executing_eagerly():
-          loss = loss()
-        sgd = gradient_descent.SGD(3.0)
-        self.evaluate(variables.global_variables_initializer())
-        # Fetch params to validate initial values
-        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
-        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
-        # Run 1 step of sgd through optimizer
-        opt_op = sgd.minimize(
-            loss,
-            var_list=[var0, var1],
-            aggregation_method=gradients_impl.AggregationMethod
-            .EXPERIMENTAL_ACCUMULATE_N)
-        self.evaluate(variables.global_variables_initializer())
-        self.evaluate(opt_op)
-        # Validate updated params
-        self.assertAllClose([-14., -13.], self.evaluate(var0))
-        self.assertAllClose([-6., -5.], self.evaluate(var1))
-
   @test_util.run_in_graph_and_eager_modes
   def testPrecomputedGradient(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
@@ -150,8 +120,6 @@ class OptimizerTest(test.TestCase):
       var0 = variables.Variable([1.0, 2.0], dtype=dtype)
       var1 = variables.Variable([3.0, 4.0], dtype=dtype)
       loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
-      if not context.executing_eagerly():
-        loss = loss()
       grad_loss = constant_op.constant([42, -42], dtype=dtype)
       sgd = gradient_descent.SGD(3.0)
@@ -176,8 +144,6 @@ class OptimizerTest(test.TestCase):
       var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
       var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
       loss = lambda: 5 * var0  # pylint: disable=cell-var-from-loop
-      if not context.executing_eagerly():
-        loss = loss()
       sgd_op = gradient_descent.SGD(3.0)
       with self.assertRaisesRegexp(ValueError, 'No gradients'):
         # var1 has no gradient
@@ -190,8 +156,6 @@ class OptimizerTest(test.TestCase):
       var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
       var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
       loss = lambda: constant_op.constant(5.0)
-      if not context.executing_eagerly():
-        loss = loss()
       sgd_op = gradient_descent.SGD(3.0)
       with self.assertRaisesRegexp(ValueError,
@@ -216,11 +180,9 @@ class OptimizerTest(test.TestCase):
       var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
       var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
       loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
-      if not context.executing_eagerly():
-        loss = loss()
       sgd = gradient_descent.SGD(3.0)
-      grads_and_vars = sgd.compute_gradients(loss, [var0, var1])
+      grads_and_vars = sgd._compute_gradients(loss, [var0, var1])
       # Convert gradients to tf.Variables
       converted_grads = [
           resource_variable_ops.ResourceVariable(
@@ -259,7 +221,7 @@ class OptimizerTest(test.TestCase):
       return x * x

     sgd = gradient_descent.SGD(3.0)
-    grads_and_vars = sgd.compute_gradients(f, [x])
+    grads_and_vars = sgd._compute_gradients(f, [x])
     self.assertEqual(1, len(grads_and_vars))
     grad, x_as_var = grads_and_vars[0]
     self.assertIs(x, x_as_var)
@@ -278,8 +240,6 @@ class OptimizerTest(test.TestCase):
       var1 = variables.Variable([3.0, 4.0], constraint=constraint_0)
       loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
-      if not context.executing_eagerly():
-        loss = loss()
       sgd = gradient_descent.SGD(3.0)
       self.evaluate(variables.global_variables_initializer())
@@ -338,6 +298,28 @@ class OptimizerTest(test.TestCase):
         self.evaluate(opt._get_hyper('learning_rate')),
         opt3._get_hyper('learning_rate'))

+  @test_util.run_in_graph_and_eager_modes
+  def testGradClipValue(self):
+    with self.cached_session():
+      var = resource_variable_ops.ResourceVariable([1.0, 2.0])
+      loss = lambda: 3 * var
+      opt = gradient_descent.SGD(learning_rate=1.0, clipvalue=1.0)
+      opt_op = opt.minimize(loss, [var])
+      self.evaluate(variables.global_variables_initializer())
+      self.evaluate(opt_op)
+      self.assertAllClose([0., 1.], self.evaluate(var))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testGradClipNorm(self):
+    with self.cached_session():
+      var = resource_variable_ops.ResourceVariable([1.0])
+      loss = lambda: 3 * var
+      opt = gradient_descent.SGD(learning_rate=1.0, clipnorm=1.0)
+      opt_op = opt.minimize(loss, [var])
+      self.evaluate(variables.global_variables_initializer())
+      self.evaluate(opt_op)
+      self.assertAllClose([0.], self.evaluate(var))
+
   @test_util.run_in_graph_and_eager_modes
   def testWeights(self):
     with self.cached_session():
tensorflow/python/keras/optimizer_v2/rmsprop.py
浏览文件 @
0bdd941c
...
...
@@ -20,8 +20,10 @@ from __future__ import print_function
from
tensorflow.python.framework
import
ops
from
tensorflow.python.keras.optimizer_v2
import
optimizer_v2
from
tensorflow.python.training
import
training_ops
from
tensorflow.python.util.tf_export
import
tf_export
@
tf_export
(
"keras.optimizers.RMSprop"
)
class
RMSprop
(
optimizer_v2
.
OptimizerV2
):
r
"""Optimizer that implements the RMSprop algorithm.
...
...
@@ -91,7 +93,7 @@ class RMSprop(optimizer_v2.OptimizerV2):
**kwargs: keyword arguments. Allowed to be {`decay`}
"""
super
(
RMSprop
,
self
).
__init__
(
name
,
**
kwargs
)
self
.
_set_hyper
(
"learning_rate"
,
learning_rate
)
self
.
_set_hyper
(
"learning_rate"
,
kwargs
.
get
(
"lr"
,
learning_rate
)
)
self
.
_set_hyper
(
"decay"
,
self
.
_initial_decay
)
self
.
_set_hyper
(
"rho"
,
rho
)
...
...
@@ -103,13 +105,13 @@ class RMSprop(optimizer_v2.OptimizerV2):
self
.
_set_hyper
(
"momentum"
,
momentum
)
self
.
_set_hyper
(
"epsilon"
,
epsilon
)
self
.
_
centered
=
centered
self
.
centered
=
centered
def
_create_slots
(
self
,
var_list
):
for
var
in
var_list
:
self
.
add_slot
(
var
,
"rms"
)
self
.
add_slot
(
var
,
"momentum"
)
if
self
.
_
centered
:
if
self
.
centered
:
self
.
add_slot
(
var
,
"mg"
)
def
_resource_apply_dense
(
self
,
grad
,
var
):
...
...
@@ -120,7 +122,7 @@ class RMSprop(optimizer_v2.OptimizerV2):
     rho = self._get_hyper("rho", var_dtype)
     momentum = self._get_hyper("momentum", var_dtype)
     epsilon = self._get_hyper("epsilon", var_dtype)
-    if self._centered:
+    if self.centered:
       mg = self.get_slot(var, "mg")
       return training_ops.resource_apply_centered_rms_prop(
           var.handle,
...
...
@@ -153,7 +155,7 @@ class RMSprop(optimizer_v2.OptimizerV2):
     rho = self._get_hyper("rho", var_dtype)
     momentum = self._get_hyper("momentum", var_dtype)
     epsilon = self._get_hyper("epsilon", var_dtype)
-    if self._centered:
+    if self.centered:
       mg = self.get_slot(var, "mg")
       return training_ops.resource_sparse_apply_centered_rms_prop(
           var.handle,
...
...
@@ -188,7 +190,7 @@ class RMSprop(optimizer_v2.OptimizerV2):
         "rho": self._serialize_hyperparameter("rho"),
         "momentum": self._serialize_hyperparameter("momentum"),
         "epsilon": self._serialize_hyperparameter("epsilon"),
-        "centered": self._centered,
+        "centered": self.centered,
     })
     return config
...
...
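Renaming `self._centered` to `self.centered` makes the flag part of the optimizer's public surface, and the `get_config` change above serializes it under the same public name. A quick round-trip sketch, again assuming `tf.keras.optimizers.RMSprop` resolves to this class:

    import tensorflow as tf

    opt = tf.keras.optimizers.RMSprop(learning_rate=0.01, centered=True)
    print(opt.centered)        # True -- public attribute, no leading underscore

    config = opt.get_config()
    print(config["centered"])  # True -- written by the get_config change above

    restored = tf.keras.optimizers.RMSprop.from_config(config)
    print(restored.centered)   # True -- survives the round trip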
tensorflow/python/keras/optimizer_v2/rmsprop_test.py
...
...
@@ -233,8 +233,11 @@ class RMSpropOptimizerTest(test.TestCase):
     with self.cached_session():
       var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
       x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
-      pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
-      loss = pred * pred
+      def loss():
+        pred = math_ops.matmul(
+            embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
+        return pred * pred
       sgd_op = rmsprop.RMSprop(
           learning_rate=1.0,
           rho=0.0,
...
...
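This test rewrite is the recurring migration pattern in the commit: the v1 optimizers accepted a precomputed loss tensor, while the v2 `minimize`/`_compute_gradients` take a zero-argument callable so the optimizer can run the forward pass (and tape the gradients) itself. A simplified version of the pattern, dropping the `embedding_lookup` the real test uses to exercise sparse gradients:

    import tensorflow as tf

    var0 = tf.Variable([[1.0, 2.0]])
    x = tf.constant([[4.0], [5.0]])

    def loss():
        # Recomputed on every step; the v2 optimizer calls this itself.
        pred = tf.matmul(var0, x)
        return pred * pred

    opt = tf.keras.optimizers.RMSprop(learning_rate=1.0, rho=0.0)
    opt.minimize(loss, var_list=[var0])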
@@ -258,8 +261,12 @@ class RMSpropOptimizerTest(test.TestCase):
     with self.cached_session():
       var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
       x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
-      pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
-      loss = pred * pred
+      def loss():
+        pred = math_ops.matmul(
+            embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
+        return pred * pred
+      # loss = lambda: pred * pred  # pylint: disable=cell-var-from-loop
       sgd_op = rmsprop.RMSprop(
           learning_rate=1.0,
           rho=0.0,
...
...
@@ -405,6 +412,14 @@ class RMSpropOptimizerTest(test.TestCase):
             (0.01 * 2.0 / math.sqrt(0.00001 * 0.9 + 1e-5 + 1.0))
         ]), self.evaluate(var1))

+  def testConstructRMSpropWithLR(self):
+    opt = rmsprop.RMSprop(lr=1.0)
+    self.assertEqual(opt.lr, 1.0)
+    opt_2 = rmsprop.RMSprop(learning_rate=0.1, lr=1.0)
+    self.assertEqual(opt_2.lr, 1.0)
+    opt_3 = rmsprop.RMSprop(learning_rate=0.1)
+    self.assertEqual(opt_3.lr, 0.1)
+

 if __name__ == "__main__":
   test.main()
tensorflow/python/keras/optimizers.py
...
...
@@ -45,7 +45,6 @@ from tensorflow.python.training.checkpointable import base as checkpointable
 from tensorflow.python.util.tf_export import tf_export


-@tf_export('keras.optimizers.Optimizer')
 class Optimizer(object):
   """Abstract optimizer base class.
...
...
@@ -159,7 +158,6 @@ class Optimizer(object):
     return cls(**config)


-@tf_export('keras.optimizers.SGD')
 class SGD(Optimizer):
   """Stochastic gradient descent optimizer.
...
...
@@ -224,7 +222,6 @@ class SGD(Optimizer):
     return dict(list(base_config.items()) + list(config.items()))


-@tf_export('keras.optimizers.RMSprop')
 class RMSprop(Optimizer):
   """RMSProp optimizer.
...
...
@@ -291,7 +288,6 @@ class RMSprop(Optimizer):
     return dict(list(base_config.items()) + list(config.items()))


-@tf_export('keras.optimizers.Adagrad')
 class Adagrad(Optimizer):
   """Adagrad optimizer.
...
...
@@ -358,7 +354,6 @@ class Adagrad(Optimizer):
     return dict(list(base_config.items()) + list(config.items()))


-@tf_export('keras.optimizers.Adadelta')
 class Adadelta(Optimizer):
   """Adadelta optimizer.
...
...
@@ -442,7 +437,6 @@ class Adadelta(Optimizer):
     return dict(list(base_config.items()) + list(config.items()))


-@tf_export('keras.optimizers.Adam')
 class Adam(Optimizer):
   """Adam optimizer.
...
...
@@ -539,7 +533,6 @@ class Adam(Optimizer):
     return dict(list(base_config.items()) + list(config.items()))


-@tf_export('keras.optimizers.Adamax')
 class Adamax(Optimizer):
   """Adamax optimizer from Adam paper's Section 7.
...
...
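All of the `@tf_export('keras.optimizers.*')` decorators removed above previously exported these v1 Keras classes; the same public names are now claimed by decorators on the `optimizer_v2` modules, as seen in rmsprop.py earlier in this diff. A quick sanity check of the switch (hypothetical snippet; the exact import path is internal and may move):

    import tensorflow as tf
    from tensorflow.python.keras.optimizer_v2 import rmsprop

    opt = tf.keras.optimizers.RMSprop(learning_rate=0.001)
    # The public symbol now points at the v2 implementation, which is what
    # the updated golden .pbtxt files below assert via is_instance.
    print(isinstance(opt, rmsprop.RMSprop))  # True at this commit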
tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt
path: "tensorflow.keras.optimizers.Adadelta"
tf_class {
is_instance: "<class \'tensorflow.python.keras.optimizers.Adadelta\'>"
is_instance: "<class \'tensorflow.python.keras.optimizers.Optimizer\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.adadelta.Adadelta\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'1.0\', \'0.95\', \'None\', \'0.0\'], "
argspec: "args=[\'self\', \'learning_rate\', \'rho\', \'epsilon\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.95\', \'1e-07\', \'Adadelta\'], "
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'
], varargs=None, keywords=None, defaults=None
"
argspec: "args=[\'cls\', \'config\'
, \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'],
"
}
member_method {
name: "get_config"
...
...
@@ -19,6 +40,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
...
...
@@ -27,8 +52,16 @@ tf_class {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
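The `minimize` entry added to this golden file records the v2 signature, `minimize(loss, var_list, grad_loss=None, name=None)`: `loss` is a callable and `var_list` has no default, so it must be passed explicitly. A minimal usage sketch under those assumptions:

    import tensorflow as tf

    var = tf.Variable(2.0)
    loss = lambda: (var - 1.0) ** 2  # minimized at var == 1.0

    opt = tf.keras.optimizers.Adadelta(learning_rate=1.0)
    # var_list is required; the v2 optimizer does not scan the graph for
    # trainable variables the way tf.train.Optimizer.minimize could.
    opt.minimize(loss, var_list=[var])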
tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt
path: "tensorflow.keras.optimizers.Adagrad"
tf_class {
is_instance: "<class \'tensorflow.python.keras.optimizers.Adagrad\'>"
is_instance: "<class \'tensorflow.python.keras.optimizers.Optimizer\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.adagrad.Adagrad\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'lr\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'None\', \'0.0\'], "
argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator_value\', \'epsilon\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.1\', \'1e-07\', \'Adagrad\'], "
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'
], varargs=None, keywords=None, defaults=None
"
argspec: "args=[\'cls\', \'config\'
, \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'],
"
}
member_method {
name: "get_config"
...
...
@@ -19,6 +40,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
...
...
@@ -27,8 +52,16 @@ tf_class {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt
path: "tensorflow.keras.optimizers.Adam"
tf_class {
is_instance: "<class \'tensorflow.python.keras.optimizers.Adam\'>"
is_instance: "<class \'tensorflow.python.keras.optimizers.Optimizer\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.adam.Adam\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\', \'amsgrad\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'None\', \'0.0\', \'False\'], "
argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'amsgrad\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'False\', \'Adam\'], "
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'
], varargs=None, keywords=None, defaults=None
"
argspec: "args=[\'cls\', \'config\'
, \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'],
"
}
member_method {
name: "get_config"
...
...
@@ -19,6 +40,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
...
...
@@ -27,8 +52,16 @@ tf_class {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt
path: "tensorflow.keras.optimizers.Adamax"
tf_class {
is_instance: "<class \'tensorflow.python.keras.optimizers.Adamax\'>"
is_instance: "<class \'tensorflow.python.keras.optimizers.Optimizer\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.adamax.Adamax\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.adam.Adam\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'None\', \'0.0\'], "
argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'Adamax\'], "
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'
], varargs=None, keywords=None, defaults=None
"
argspec: "args=[\'cls\', \'config\'
, \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'],
"
}
member_method {
name: "get_config"
...
...
@@ -19,6 +41,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
...
...
@@ -27,8 +53,16 @@ tf_class {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt
path: "tensorflow.keras.optimizers.Optimizer"
tf_class {
is_instance: "<class \'tensorflow.python.keras.optimizers.Optimizer\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'
], varargs=None, keywords=None, defaults=None
"
argspec: "args=[\'cls\', \'config\'
, \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'],
"
}
member_method {
name: "get_config"
...
...
@@ -18,6 +39,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
...
...
@@ -26,8 +51,16 @@ tf_class {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
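`from_config` gaining a `custom_objects` argument, visible in every golden file in this commit, brings optimizer deserialization in line with the rest of the Keras `from_config` conventions: it maps names referenced in the config to user-defined objects. A hedged round-trip sketch:

    import tensorflow as tf

    opt = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)
    config = opt.get_config()

    # Passing nothing (or an empty dict) for custom_objects matches the old
    # single-argument behavior.
    restored = tf.keras.optimizers.SGD.from_config(config, custom_objects={})
    print(restored.get_config()["momentum"])  # 0.9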
tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
path: "tensorflow.keras.optimizers.RMSprop"
tf_class {
is_instance: "<class \'tensorflow.python.keras.optimizers.RMSprop\'>"
is_instance: "<class \'tensorflow.python.keras.optimizers.Optimizer\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.rmsprop.RMSprop\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'None\', \'0.0\'], "
argspec: "args=[\'self\', \'learning_rate\', \'rho\', \'momentum\', \'epsilon\', \'centered\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.0\', \'1e-07\', \'False\', \'RMSprop\'], "
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'
], varargs=None, keywords=None, defaults=None
"
argspec: "args=[\'cls\', \'config\'
, \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'],
"
}
member_method {
name: "get_config"
...
...
@@ -19,6 +40,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
...
...
@@ -27,8 +52,16 @@ tf_class {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt
path: "tensorflow.keras.optimizers.SGD"
tf_class {
is_instance: "<class \'tensorflow.python.keras.optimizers.SGD\'>"
is_instance: "<class \'tensorflow.python.keras.optimizers.Optimizer\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.gradient_descent.SGD\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'lr\', \'momentum\', \'decay\', \'nesterov\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'0.0\', \'0.0\', \'False\'], "
argspec: "args=[\'self\', \'learning_rate\', \'momentum\', \'nesterov\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.0\', \'False\', \'SGD\'], "
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'
], varargs=None, keywords=None, defaults=None
"
argspec: "args=[\'cls\', \'config\'
, \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'],
"
}
member_method {
name: "get_config"
...
...
@@ -19,6 +40,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
...
...
@@ -27,8 +52,16 @@ tf_class {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
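Comparing the two `__init__` argspecs above: `decay` drops out of SGD's named parameters, `lr` becomes `learning_rate`, the default falls from 0.01 to 0.001, and a `name` argument appears. The legacy spellings survive only through `**kwargs` (the `Allowed to be {decay}` docstring in rmsprop.py above shows the pattern; assuming SGD handles it the same way). New-style construction:

    import tensorflow as tf

    # Named parameters per the new argspec: learning_rate, momentum, nesterov, name.
    opt = tf.keras.optimizers.SGD(
        learning_rate=0.001, momentum=0.9, nesterov=True, name="SGD")
    print(opt.get_config()["nesterov"])  # True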
tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt
path: "tensorflow.keras.optimizers.Adadelta"
tf_class {
is_instance: "<class \'tensorflow.python.keras.optimizers.Adadelta\'>"
is_instance: "<class \'tensorflow.python.keras.optimizers.Optimizer\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.adadelta.Adadelta\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'1.0\', \'0.95\', \'None\', \'0.0\'], "
argspec: "args=[\'self\', \'learning_rate\', \'rho\', \'epsilon\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.95\', \'1e-07\', \'Adadelta\'], "
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'
], varargs=None, keywords=None, defaults=None
"
argspec: "args=[\'cls\', \'config\'
, \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'],
"
}
member_method {
name: "get_config"
...
...
@@ -19,6 +40,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
...
...
@@ -27,8 +52,16 @@ tf_class {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt
path: "tensorflow.keras.optimizers.Adagrad"
tf_class {
is_instance: "<class \'tensorflow.python.keras.optimizers.Adagrad\'>"
is_instance: "<class \'tensorflow.python.keras.optimizers.Optimizer\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.adagrad.Adagrad\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'lr\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'None\', \'0.0\'], "
argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator_value\', \'epsilon\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.1\', \'1e-07\', \'Adagrad\'], "
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'
], varargs=None, keywords=None, defaults=None
"
argspec: "args=[\'cls\', \'config\'
, \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'],
"
}
member_method {
name: "get_config"
...
...
@@ -19,6 +40,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
...
...
@@ -27,8 +52,16 @@ tf_class {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt
path: "tensorflow.keras.optimizers.Adam"
tf_class {
is_instance: "<class \'tensorflow.python.keras.optimizers.Adam\'>"
is_instance: "<class \'tensorflow.python.keras.optimizers.Optimizer\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.adam.Adam\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\', \'amsgrad\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'None\', \'0.0\', \'False\'], "
argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'amsgrad\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'False\', \'Adam\'], "
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'
], varargs=None, keywords=None, defaults=None
"
argspec: "args=[\'cls\', \'config\'
, \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'],
"
}
member_method {
name: "get_config"
...
...
@@ -19,6 +40,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
...
...
@@ -27,8 +52,16 @@ tf_class {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt
path: "tensorflow.keras.optimizers.Adamax"
tf_class {
is_instance: "<class \'tensorflow.python.keras.optimizers.Adamax\'>"
is_instance: "<class \'tensorflow.python.keras.optimizers.Optimizer\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.adamax.Adamax\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.adam.Adam\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'None\', \'0.0\'], "
argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'Adamax\'], "
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'
], varargs=None, keywords=None, defaults=None
"
argspec: "args=[\'cls\', \'config\'
, \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'],
"
}
member_method {
name: "get_config"
...
...
@@ -19,6 +41,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
...
...
@@ -27,8 +53,16 @@ tf_class {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt
path: "tensorflow.keras.optimizers.Optimizer"
tf_class {
is_instance: "<class \'tensorflow.python.keras.optimizers.Optimizer\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'
], varargs=None, keywords=None, defaults=None
"
argspec: "args=[\'cls\', \'config\'
, \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'],
"
}
member_method {
name: "get_config"
...
...
@@ -18,6 +39,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
...
...
@@ -26,8 +51,16 @@ tf_class {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
path: "tensorflow.keras.optimizers.RMSprop"
tf_class {
is_instance: "<class \'tensorflow.python.keras.optimizers.RMSprop\'>"
is_instance: "<class \'tensorflow.python.keras.optimizers.Optimizer\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.rmsprop.RMSprop\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'None\', \'0.0\'], "
argspec: "args=[\'self\', \'learning_rate\', \'rho\', \'momentum\', \'epsilon\', \'centered\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.0\', \'1e-07\', \'False\', \'RMSprop\'], "
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'
], varargs=None, keywords=None, defaults=None
"
argspec: "args=[\'cls\', \'config\'
, \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'],
"
}
member_method {
name: "get_config"
...
...
@@ -19,6 +40,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
...
...
@@ -27,8 +52,16 @@ tf_class {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt
path: "tensorflow.keras.optimizers.SGD"
tf_class {
is_instance: "<class \'tensorflow.python.keras.optimizers.SGD\'>"
is_instance: "<class \'tensorflow.python.keras.optimizers.Optimizer\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.gradient_descent.SGD\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'lr\', \'momentum\', \'decay\', \'nesterov\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'0.0\', \'0.0\', \'False\'], "
argspec: "args=[\'self\', \'learning_rate\', \'momentum\', \'nesterov\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.0\', \'False\', \'SGD\'], "
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'
], varargs=None, keywords=None, defaults=None
"
argspec: "args=[\'cls\', \'config\'
, \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'],
"
}
member_method {
name: "get_config"
...
...
@@ -19,6 +40,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
...
...
@@ -27,8 +52,16 @@ tf_class {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
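The `iterations`/`weights` properties and the `variables()` method added across these golden files give the v2 optimizers a layer-like state surface: the step counter plus all slot variables, readable and restorable as plain arrays. A sketch, assuming eager execution:

    import tensorflow as tf

    var = tf.Variable(1.0)
    opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
    opt.minimize(lambda: var * var, var_list=[var])

    print(opt.iterations.numpy())  # 1 -- the step counter
    state = opt.get_weights()      # counter value plus the momentum slot for var
    opt.set_weights(state)         # restore optimizer state, e.g. after checkpointing
    print(len(opt.variables()))    # 2: the counter and one momentum slot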