提交 99ce664f 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Fix a bad tuple unpack in the error-message generation that runs when no
gradients are provided for any variable in apply_gradients().
Add tests.
Change: 141480285
上级 bb5b65a8
......@@ -388,7 +388,7 @@ class Optimizer(object):
var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
if not var_list:
raise ValueError("No gradients provided for any variable: %s." %
([str(v) for _, v in converted_grads_and_vars],))
([str(v) for _, _, v in converted_grads_and_vars],))
with ops.control_dependencies(None):
self._create_slots(var_list)
update_ops = []
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional test for optimizer."""
from __future__ import absolute_import
from __future__ import division
......@@ -52,8 +51,7 @@ class OptimizerTest(tf.test.TestCase):
sgd_op = tf.train.GradientDescentOptimizer(3.0)
opt_op = sgd_op.minimize(
cost,
global_step,
[var0, var1],
global_step, [var0, var1],
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
tf.global_variables_initializer().run()
......@@ -75,9 +73,8 @@ class OptimizerTest(tf.test.TestCase):
grad_loss = tf.constant([42, -42], dtype=dtype)
global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
sgd_op = tf.train.GradientDescentOptimizer(3.0)
opt_op = sgd_op.minimize(cost,
global_step, [var0, var1],
grad_loss=grad_loss)
opt_op = sgd_op.minimize(
cost, global_step, [var0, var1], grad_loss=grad_loss)
tf.global_variables_initializer().run()
# Fetch params to validate initial values
......@@ -86,10 +83,10 @@ class OptimizerTest(tf.test.TestCase):
# Run 1 step of sgd through optimizer
opt_op.run()
# Validate updated params
self.assertAllClose(
[1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)], var0.eval())
self.assertAllClose(
[3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)], var1.eval())
self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
var0.eval())
self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
var1.eval())
def testNoVariables(self):
for dtype in [tf.half, tf.float32, tf.float64]:
......@@ -113,6 +110,28 @@ class OptimizerTest(tf.test.TestCase):
# var1 has no gradient
sgd_op.minimize(cost, global_step, [var1])
def testNoGradientsForAnyVariables_Minimize(self):
  """minimize() must raise ValueError when no variable gets a gradient."""
  for dtype in [tf.half, tf.float32, tf.float64]:
    with self.test_session():
      # The cost is a constant, so gradients w.r.t. both variables are None.
      var0 = tf.Variable([1.0, 2.0], dtype=dtype)
      var1 = tf.Variable([3.0, 4.0], dtype=dtype)
      cost = tf.constant(5.0)
      global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
      optimizer = tf.train.GradientDescentOptimizer(3.0)
      # The optimizer should reject the step with a clear error rather
      # than crash while formatting the message (the bug this commit fixes).
      with self.assertRaisesRegexp(ValueError,
                                   'No gradients provided for any variable'):
        optimizer.minimize(cost, global_step, [var0, var1])
def testNoGradientsForAnyVariables_ApplyGradients(self):
  """apply_gradients() must raise ValueError when every gradient is None."""
  for dtype in [tf.half, tf.float32, tf.float64]:
    with self.test_session():
      var0 = tf.Variable([1.0, 2.0], dtype=dtype)
      var1 = tf.Variable([3.0, 4.0], dtype=dtype)
      optimizer = tf.train.GradientDescentOptimizer(3.0)
      # Hand apply_gradients an explicit all-None gradient list and check
      # the error path produces its message instead of a bad-unpack crash.
      all_none_grads_and_vars = [(None, var0), (None, var1)]
      with self.assertRaisesRegexp(ValueError,
                                   'No gradients provided for any variable'):
        optimizer.apply_gradients(all_none_grads_and_vars)
def testGradientsAsVariables(self):
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session() as sess:
......@@ -123,8 +142,13 @@ class OptimizerTest(tf.test.TestCase):
sgd_op = tf.train.GradientDescentOptimizer(3.0)
grads_and_vars = sgd_op.compute_gradients(cost, [var0, var1])
# Convert gradients to tf.Variables
converted_grads = [tf.Variable(tf.zeros([2], dtype)) for i in grads_and_vars]
convert_ops = [tf.assign(converted_grads[i], gv[0]) for i,gv in enumerate(grads_and_vars)]
converted_grads = [
tf.Variable(tf.zeros([2], dtype)) for i in grads_and_vars
]
convert_ops = [
tf.assign(converted_grads[i], gv[0])
for i, gv in enumerate(grads_and_vars)
]
converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
opt_op = sgd_op.apply_gradients(converted_grads_and_vars, global_step)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册