Commit 542ec6d4 authored by Eugene Brevdo, committed by TensorFlower Gardener

[TF optimizers (v1)] Non-slot variables are ResourceVariables iff the input vars are.

This fixes a bug where Adam beta*_power variables were always created as RefVars
even if the optimizer acts on ResourceVars.  This broke certain defun + Adam
use cases.
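
A minimal graph-mode sketch of the guarantee this change adds (assumes a
TF 1.x environment; _get_beta_accumulators() is the private accessor the
unit tests below also use, and the variable names are illustrative only):

    import tensorflow as tf  # TF 1.x
    from tensorflow.python.ops import resource_variable_ops

    var = resource_variable_ops.ResourceVariable([1.0, 2.0], name="var")
    loss = tf.reduce_sum(var * var)
    opt = tf.train.AdamOptimizer(learning_rate=0.001)
    train_op = opt.minimize(loss)

    # After this change the beta*_power accumulators match the kind of the
    # variables being optimized: ResourceVariable in, ResourceVariable out.
    beta1_power, beta2_power = opt._get_beta_accumulators()
    assert resource_variable_ops.is_resource_variable(beta1_power)
    assert resource_variable_ops.is_resource_variable(beta2_power)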

Also fixed the unit tests, which *always* created ResourceVariables
(ever since the variables.Variable() constructor became an alias for ResourceVariable).
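
Concretely, the use_resource=False branches below had stopped covering ref
variables at all; switching them to variables.RefVariable restores that
coverage. A small illustration of the distinction, assuming the aliasing
behavior described above:

    from tensorflow.python.ops import resource_variable_ops
    from tensorflow.python.ops import variables

    v_aliased = variables.Variable([1.0])   # resolves to a ResourceVariable,
                                            # per the aliasing noted above
    v_ref = variables.RefVariable([1.0])    # explicitly the legacy ref kind

    assert resource_variable_ops.is_resource_variable(v_aliased)
    assert not resource_variable_ops.is_resource_variable(v_ref)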

PiperOrigin-RevId: 224869338
Parent c0627e1f
@@ -68,8 +68,8 @@ class AdamOptimizerTest(test.TestCase):
           var0 = resource_variable_ops.ResourceVariable(var0_np)
           var1 = resource_variable_ops.ResourceVariable(var1_np)
         else:
-          var0 = variables.Variable(var0_np)
-          var1 = variables.Variable(var1_np)
+          var0 = variables.RefVariable(var0_np)
+          var1 = variables.RefVariable(var1_np)
         grads0_np_indices = np.array([0, 1], dtype=np.int32)
         grads0 = ops.IndexedSlices(
             constant_op.constant(grads0_np),
@@ -156,6 +156,9 @@ class AdamOptimizerTest(test.TestCase):
                         self.evaluate(repeated_index_update_var))
 
   def doTestBasic(self, use_resource=False, use_callable_params=False):
+    if context.executing_eagerly() and not use_resource:
+      self.skipTest(
+          "Skipping test with use_resource=False and executing eagerly.")
     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
       with self.session(graph=ops.Graph()):
         # Initialize variables for numpy implementation.
@@ -171,8 +174,8 @@ class AdamOptimizerTest(test.TestCase):
           var1 = resource_variable_ops.ResourceVariable(
               var1_np, name="var1_%d" % i)
         else:
-          var0 = variables.Variable(var0_np)
-          var1 = variables.Variable(var1_np)
+          var0 = variables.RefVariable(var0_np)
+          var1 = variables.RefVariable(var1_np)
         grads0 = constant_op.constant(grads0_np)
         grads1 = constant_op.constant(grads1_np)
@@ -194,6 +197,14 @@ class AdamOptimizerTest(test.TestCase):
         self.assertTrue(beta2_power is not None)
         self.assertIn(beta1_power, opt_variables)
         self.assertIn(beta2_power, opt_variables)
+        # Ensure that non-slot variables are the same type as the requested
+        # variables.
+        self.assertEqual(
+            use_resource,
+            resource_variable_ops.is_resource_variable(beta1_power))
+        self.assertEqual(
+            use_resource,
+            resource_variable_ops.is_resource_variable(beta2_power))
 
         if not context.executing_eagerly():
           with ops.Graph().as_default():
@@ -822,7 +822,10 @@ class Optimizer(
               name=name, shape=None)
           if restored_initial_value is not None:
             initial_value = restored_initial_value
-      v = variable_scope.variable(initial_value, name=name, trainable=False)
+      v = variable_scope.variable(
+          initial_value, name=name, trainable=False,
+          use_resource=resource_variable_ops.is_resource_variable(
+              colocate_with))
       # Restore this variable by name if necessary, but don't add a
       # Checkpointable dependency. Optimizers return the current graph's
       # non-slot variables from _checkpoint_dependencies explicitly rather
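
The optimizer.py hunk above is the substance of the fix: variable_scope.variable
accepts a use_resource argument, and resource_variable_ops.is_resource_variable(colocate_with)
is true exactly when the variable the accumulator is colocated with is
resource-backed. A standalone sketch of that propagation (the helper name here
is hypothetical, not part of the change):

    from tensorflow.python.ops import resource_variable_ops
    from tensorflow.python.ops import variable_scope
    from tensorflow.python.ops import variables

    def make_non_slot(initial_value, name, colocate_with):
      # Mirror of the updated _create_non_slot_variable logic: the new
      # variable is a ResourceVariable iff colocate_with is one.
      return variable_scope.variable(
          initial_value, name=name, trainable=False,
          use_resource=resource_variable_ops.is_resource_variable(
              colocate_with))

    ref_var = variables.RefVariable([1.0])
    acc = make_non_slot(0.9, "beta1_power", ref_var)
    assert not resource_variable_ops.is_resource_variable(acc)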