Commit 026a6870 authored by Zhenyu Tan, committed by TensorFlower Gardener

Implement Keras V2 Adagrad optimizer.

Several notes regarding this change:

1) initial_accumulator_value is kept in the public signature, since
internal search suggested that many people use it.

2) Included the new epsilon argument from Keras, consistent with the other optimizers.
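
As a rough illustration of the new public signature, a minimal construction sketch (using the module path from this change; not an official usage guide):

    from tensorflow.python.keras.optimizer_v2 import adagrad

    # All three hyperparameters shown are the defaults introduced by this change.
    opt = adagrad.Adagrad(learning_rate=0.001,
                          initial_accumulator_value=0.1,
                          epsilon=1e-7)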

PiperOrigin-RevId: 222150112
Parent 43a573e6
......@@ -14,6 +14,7 @@ py_library(
name = "optimizer_v2",
srcs = [
"adadelta.py",
"adagrad.py",
"adam.py",
"adamax.py",
"gradient_descent.py",
......@@ -35,6 +36,25 @@ py_library(
],
)
cuda_py_test(
name = "adagrad_test",
size = "medium",
srcs = ["adagrad_test.py"],
additional_deps = [
":optimizer_v2",
"//tensorflow/python:client_testlib",
"//tensorflow/python:embedding_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:resources",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
],
shard_count = 4,
)
cuda_py_test(
name = "adam_test",
size = "medium",
......
......@@ -54,7 +54,7 @@ class Adadelta(optimizer_v2.OptimizerV2):
def __init__(self,
learning_rate=0.001,
rho=0.95,
epsilon=1e-8,
epsilon=1e-7,
name='Adadelta'):
"""Construct a new Adadelta optimizer.
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adagrad for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
class Adagrad(optimizer_v2.OptimizerV2):
r"""Optimizer that implements the Adagrad algorithm.
Adagrad is an optimizer with parameter-specific learning rates,
which are adapted relative to how frequently a parameter gets
updated during training. The more updates a parameter receives,
the smaller the updates.
Initialization:
$$accum_{g_0} := \text{initial\_accumulator\_value}$$
Update step:
$$t := t + 1$$
$$accum_{g_t} := accum_{g_{t-1}} + g^2$$
$$\theta_t := \theta_{t-1} - lr \cdot g / (\sqrt{accum_{g_t}} + \epsilon)$$
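For example, a single update with $lr = 1.0$, a negligible $\epsilon$,
the default $initial\_accumulator\_value = 0.1$, and a constant gradient
$g = 0.1$ gives
$$accum_{g_1} = 0.1 + 0.1^2 = 0.11$$
$$\theta_1 = \theta_0 - 1.0 \cdot 0.1 / \sqrt{0.11} \approx \theta_0 - 0.3015,$$
and each subsequent step shrinks as the accumulator grows.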
References
See [paper]
(http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
or this
[intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
"""
def __init__(self,
learning_rate=0.001,
initial_accumulator_value=0.1,
epsilon=1e-7,
name='Adagrad'):
"""Construct a new Adagrad optimizer.
Args:
learning_rate: A `Tensor` or a floating point value. The learning rate.
initial_accumulator_value: A floating point value.
Starting value for the accumulators, must be positive.
epsilon: A small floating point value added to the denominator to
maintain numerical stability; must not be smaller than 1e-7.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "Adagrad".
Raises:
ValueError: If the `initial_accumulator_value` or `epsilon` is invalid.
@compatibility(eager)
When eager execution is enabled, `learning_rate` can be a callable that
takes no arguments and returns the actual value to use. This can be useful
for changing these values across different invocations of optimizer
functions.
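For example, `learning_rate=lambda: 0.1` defers reading the value until
gradients are applied.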
@end_compatibility
"""
if initial_accumulator_value <= 0.0:
raise ValueError('initial_accumulator_value must be positive: %s' %
initial_accumulator_value)
if epsilon < 1e-7:
raise ValueError('epsilon must be at least 1e-7: %s' % epsilon)
super(Adagrad, self).__init__(name)
self._set_hyper('learning_rate', learning_rate)
self._initial_accumulator_value = initial_accumulator_value
self._set_hyper('epsilon', epsilon)
def _create_slots(self, var_list):
for var in var_list:
dtype = var.dtype.base_dtype
init = init_ops.constant_initializer(
self._initial_accumulator_value, dtype=dtype)
self.add_slot(var, 'accumulator', init)
def _init_constant_op(self, v, dtype):
def init():
# Use a Tensor instead of initializer if variable does not have
# static shape.
init_constant = gen_array_ops.fill(array_ops.shape(v),
self._initial_accumulator_value)
return math_ops.cast(init_constant, dtype)
return init
def _resource_apply_dense(self, grad, var):
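# Dense path, matching the update rule in the class docstring:
# accumulator += grad ** 2; var -= lr * grad / (sqrt(accumulator) + epsilon).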
var_dtype = var.dtype.base_dtype
learning_rate = math_ops.cast(self._get_hyper('learning_rate'), var_dtype)
epsilon = math_ops.cast(self._get_hyper('epsilon'), var_dtype)
acc = self.get_slot(var, 'accumulator')
acc_t = state_ops.assign_add(
acc, math_ops.square(grad), use_locking=self._use_locking)
var_update = state_ops.assign_sub(
var, learning_rate * grad / (math_ops.sqrt(acc_t) + epsilon))
return var_update
def _resource_apply_sparse(self, grad, var, indices):
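# Scatter-adds v into x at indices i; the control dependency ensures the
# returned value reflects the completed update.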
def _resource_scatter_add(x, i, v):
with ops.control_dependencies(
[resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
return x.value()
var_dtype = var.dtype.base_dtype
learning_rate = math_ops.cast(self._get_hyper('learning_rate'), var_dtype)
epsilon = math_ops.cast(self._get_hyper('epsilon'), var_dtype)
acc = self.get_slot(var, 'accumulator')
acc_t = _resource_scatter_add(acc, indices, math_ops.square(grad))
acc_t_slice = array_ops.gather(acc_t, indices)
var_update = _resource_scatter_add(
var, indices,
-learning_rate * grad / (math_ops.sqrt(acc_t_slice) + epsilon))
return var_update
def get_config(self):
config = super(Adagrad, self).get_config()
config.update({
'learning_rate': self._serialize_hyperparameter('learning_rate'),
'initial_accumulator_value': self._initial_accumulator_value,
'epsilon': self._serialize_hyperparameter('epsilon'),
})
return config
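
Since `get_config` returns plain Python values, an equivalent optimizer can be rebuilt from the result; a minimal round-trip sketch (the `opt` name is hypothetical, and only the keys serialized above are assumed):

    config = opt.get_config()
    restored = adagrad.Adagrad(
        learning_rate=config['learning_rate'],
        initial_accumulator_value=config['initial_accumulator_value'],
        epsilon=config['epsilon'])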
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for aggregate operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.optimizer_v2 import adagrad
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
def adagrad_update_numpy(param, accum, g_t, lr=0.001, epsilon=1e-7):
accum_t = accum + g_t * g_t
param_t = param - lr * g_t / (np.sqrt(accum_t) + epsilon)
return param_t, accum_t
def sparse_adagrad_update_numpy(param,
accum,
gindexs,
gvalues,
lr=0.001,
epsilon=1e-7):
accum_t = copy.deepcopy(accum)
param_t = copy.deepcopy(param)
# first loop accumulates repeated indices if necessary.
for i in range(len(gindexs)):
gindex = gindexs[i]
gvalue = gvalues[i]
accum_t[gindex] = accum_t[gindex] + gvalue * gvalue
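# second loop applies the parameter update, so every occurrence of a
# repeated index divides by the same fully accumulated value.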
for i in range(len(gindexs)):
gindex = gindexs[i]
gvalue = gvalues[i]
param_t[gindex] = param_t[gindex] - lr * gvalue / (
np.sqrt(accum_t[gindex]) + epsilon)
return param_t, accum_t
class AdagradOptimizerTest(test.TestCase):
def doTestBasic(self, use_callable_params=False):
for dtype in [dtypes.float32, dtypes.float64]:
with self.cached_session():
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
var0 = resource_variable_ops.ResourceVariable(var0_np)
var1 = resource_variable_ops.ResourceVariable(var1_np)
grads0 = constant_op.constant(grads0_np)
grads1 = constant_op.constant(grads1_np)
learning_rate = lambda: 3.0
if not use_callable_params:
learning_rate = learning_rate()
ada_opt = adagrad.Adagrad(learning_rate)
accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
if not context.executing_eagerly():
ada_update = ada_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose([1.0, 2.0], v0_val)
self.assertAllClose([3.0, 4.0], v1_val)
# Run 3 steps of adagrad
for _ in range(3):
if not context.executing_eagerly():
self.evaluate(ada_update)
else:
ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np,
grads0_np, 3.0)
var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np,
grads1_np, 3.0)
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
@test_util.run_in_graph_and_eager_modes(reset_test=True)
def testBasic(self):
self.doTestBasic()
def testBasicCallableParams(self):
with context.eager_mode():
self.doTestBasic(use_callable_params=True)
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
var0 = resource_variable_ops.ResourceVariable(
[[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
loss = pred * pred
ada_op = adagrad.Adagrad(1.0).minimize(loss, var_list=[var0])
variables.global_variables_initializer().run()
# Fetch params to validate initial values
self.assertAllCloseAccordingToType(
[[1.0, 2.0], [3.0, 4.0]], var0.eval())
# Run 1 step of adagrad
ada_op.run()
# Validate updated params
self.assertAllCloseAccordingToType(
[[0, 1], [3, 4]], var0.eval(), atol=0.01)
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
var0 = resource_variable_ops.ResourceVariable(var0_np)
var1 = resource_variable_ops.ResourceVariable(var1_np)
grads0 = constant_op.constant(grads0_np)
grads1 = constant_op.constant(grads1_np)
learning_rate = constant_op.constant(3.0)
ada_opt = adagrad.Adagrad(learning_rate)
ada_update = ada_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], var0.eval())
self.assertAllClose([3.0, 4.0], var1.eval())
accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
# Run 3 steps of adagrad
for _ in range(3):
ada_update.run()
var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np,
grads0_np, learning_rate)
var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np,
grads1_np, learning_rate)
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
def testSparseBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype)
grads0_np = np.array([0.1, 0, 0.1], dtype=dtype.as_numpy_dtype)
var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype)
grads1_np = np.array([0.01, 0, 0.01], dtype=dtype.as_numpy_dtype)
var0 = resource_variable_ops.ResourceVariable(var0_np)
var1 = resource_variable_ops.ResourceVariable(var1_np)
grads0_np_indices = np.array([0, 2], dtype=np.int32)
grads0 = ops.IndexedSlices(
constant_op.constant(grads0_np[grads0_np_indices]),
constant_op.constant(grads0_np_indices), constant_op.constant([3]))
grads1_np_indices = np.array([0, 2], dtype=np.int32)
grads1 = ops.IndexedSlices(
constant_op.constant(grads1_np[grads1_np_indices]),
constant_op.constant(grads1_np_indices), constant_op.constant([3]))
learning_rate = 3.0
ada_opt = adagrad.Adagrad(learning_rate)
ada_update = ada_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
self.assertAllClose([1.0, 1.0, 2.0], var0.eval())
self.assertAllClose([3.0, 3.0, 4.0], var1.eval())
accum0_np = np.array([0.1, 0.1, 0.1], dtype=dtype.as_numpy_dtype)
accum1_np = np.array([0.1, 0.1, 0.1], dtype=dtype.as_numpy_dtype)
# Run 3 steps of adagrad
for _ in range(3):
ada_update.run()
var0_np, accum0_np = sparse_adagrad_update_numpy(
var0_np, accum0_np, grads0_np_indices,
grads0_np[grads0_np_indices], learning_rate)
var1_np, accum1_np = sparse_adagrad_update_numpy(
var1_np, accum1_np, grads1_np_indices,
grads1_np[grads1_np_indices], learning_rate)
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
def testSparseRepeatedIndices(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
var_np = np.array([[1.0], [2.0]], dtype=dtype.as_numpy_dtype)
repeated_index_update_var = resource_variable_ops.ResourceVariable(
var_np, dtype=dtype)
aggregated_update_var = resource_variable_ops.ResourceVariable(
var_np, dtype=dtype)
grad_repeated_index = ops.IndexedSlices(
constant_op.constant(
[0.1, 0.1], shape=[2, 1], dtype=dtype),
constant_op.constant([1, 1]),
constant_op.constant([2, 1]))
grad_aggregated = ops.IndexedSlices(
constant_op.constant(
[0.2], shape=[1, 1], dtype=dtype),
constant_op.constant([1]),
constant_op.constant([2, 1]))
repeated_update = adagrad.Adagrad(3.0).apply_gradients(
[(grad_repeated_index, repeated_index_update_var)])
aggregated_update = adagrad.Adagrad(3.0).apply_gradients(
[(grad_aggregated, aggregated_update_var)])
variables.global_variables_initializer().run()
self.assertAllClose(aggregated_update_var.eval(),
repeated_index_update_var.eval())
for _ in range(3):
repeated_update.run()
aggregated_update.run()
self.assertAllClose(aggregated_update_var.eval(),
repeated_index_update_var.eval())
def testSparseRepeatedIndicesByEmbeddingLookUp(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
var_repeated = resource_variable_ops.ResourceVariable(
[1.0, 2.0], dtype=dtype)
loss_repeated = math_ops.reduce_sum(
embedding_ops.embedding_lookup(var_repeated, [0, 0]))
var_aggregated = resource_variable_ops.ResourceVariable(
[1.0, 2.0], dtype=dtype)
loss_aggregated = 2 * math_ops.reduce_sum(
embedding_ops.embedding_lookup(var_aggregated, [0]))
update_op_repeated = adagrad.Adagrad(2.0).minimize(
loss_repeated, var_list=[var_repeated])
update_op_aggregated = adagrad.Adagrad(2.0).minimize(
loss_aggregated, var_list=[var_aggregated])
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(
var_repeated.eval(), var_aggregated.eval())
for _ in range(3):
update_op_repeated.run()
update_op_aggregated.run()
self.assertAllCloseAccordingToType(
var_repeated.eval(), var_aggregated.eval())
def testSparseStability(self):
for dtype in [dtypes.half]:
with self.cached_session():
shape = [1, 6]
var0_np = np.array([[
0.00872496, -0.106952, 0.110467, 0.226505, -0.0147257, -0.0105945
]],
dtype=dtype.as_numpy_dtype)
var0 = resource_variable_ops.ResourceVariable(var0_np)
grads0_np = np.array([[
-5.91278e-05, 5.31673e-05, -2.5779e-06, 4.29153e-05, -8.4877e-05,
-9.48906e-05
]],
dtype=dtype.as_numpy_dtype)
grads0 = ops.IndexedSlices(
constant_op.constant(grads0_np), constant_op.constant([0]),
constant_op.constant(shape))
ada_opt = adagrad.Adagrad(1.0)
ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
slot0 = ada_opt.get_slot(var0, "accumulator")
init = variables.global_variables_initializer()
for _ in range(100):
init.run()
ada_update.run()
self.assertAllCloseAccordingToType(
np.array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]), slot0.eval())
self.assertAllCloseAccordingToType(
np.array([[
0.00891194, -0.10712013, 0.11047515, 0.22636929, -0.0144573,
-0.01029443
]]), var0.eval())
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
var0 = resource_variable_ops.ResourceVariable(var0_np)
var1 = resource_variable_ops.ResourceVariable(var1_np)
grads0 = constant_op.constant(grads0_np)
grads1 = constant_op.constant(grads1_np)
learning_rate = 3.0
ada_opt = adagrad.Adagrad(learning_rate)
# Apply the optimizer twice. Both applications will use
# the same accums.
ada_update1 = ada_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
ada_update2 = ada_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
slot0 = ada_opt.get_slot(var0, "accumulator")
self.assertEqual(slot0.get_shape(), var0.get_shape())
slot1 = ada_opt.get_slot(var1, "accumulator")
self.assertEqual(slot1.get_shape(), var1.get_shape())
variables.global_variables_initializer().run()
# Fetch params to validate initial values.
self.assertAllClose([1.0, 2.0], var0.eval())
self.assertAllClose([3.0, 4.0], var1.eval())
# Mix the first and the second adagrad for 3 steps.
ada_update1.run()
ada_update2.run()
ada_update1.run()
accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
for _ in range(3):
var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np,
grads0_np, learning_rate)
var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np,
grads1_np, learning_rate)
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
if __name__ == "__main__":
test.main()
......@@ -44,7 +44,7 @@ class Adam(optimizer_v2.OptimizerV2):
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-8,
epsilon=1e-7,
name='Adam'):
r"""Construct a new Adam optimizer.
......
......@@ -41,7 +41,7 @@ def adam_update_numpy(param,
alpha=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8):
epsilon=1e-7):
alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)
m_t = beta1 * m + (1 - beta1) * g_t
......
......@@ -43,7 +43,7 @@ class AdaMax(adam.Adam):
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-8,
epsilon=1e-7,
name='AdaMax'):
"""Construct a new AdaMax optimizer.
......
......@@ -374,12 +374,16 @@ class OptimizerV2(optimizer_v1.Optimizer):
else:
super(OptimizerV2, self).__setattr__(name, value)
def add_slot(self, var, slot_name):
def add_slot(self, var, slot_name, initializer="zeros"):
var_key = _var_key(var)
slot_dict = self._slots.setdefault(var_key, {})
if slot_name not in slot_dict:
slot_key = _get_slot_key_from_var(var, slot_name)
weight = self.add_weight(name=slot_key, shape=var.shape, dtype=var.dtype)
weight = self.add_weight(
name=slot_key,
shape=var.shape,
dtype=var.dtype,
initializer=initializer)
slot_dict[slot_name] = weight
self._weights.append(weight)
......
......@@ -56,7 +56,7 @@ class RMSProp(optimizer_v2.OptimizerV2):
learning_rate=0.001,
rho=0.9,
momentum=0.0,
epsilon=1e-10,
epsilon=1e-7,
centered=False,
name="RMSProp"):
"""Construct a new RMSProp optimizer.
......