提交 c7504db1 编写于 作者: R Reed Wanderman-Milne 提交者: TensorFlower Gardener

Add better error message for LossScaleOptimizer AttributeError.

Before, the error message would be something like "LossScaleOptimizer object has no attribute _hyper", even if the accessed attribute was not _hyper.

PiperOrigin-RevId: 258462751
上级 c51eb496
......@@ -136,6 +136,9 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
self._track_trackable(self._optimizer, 'base_optimizer')
self._track_trackable(self._loss_scale, 'loss_scale')
# Needed because the superclass's __getattribute__ checks this.
self._hyper = {}
@property
def loss_scale(self):
"""The `LossScale` instance associated with this optimizer."""
......
......@@ -265,10 +265,14 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
opt.beta_1 # pylint: disable=pointless-statement
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=10.)
# Test that attributes defined by OptimizerV2 subclasses are not exposed in
# LossScaleOptimizer.
with self.assertRaises(AttributeError):
# LossScaleOptimizer, and that the error message is sensible.
with self.assertRaisesRegexp(
AttributeError,
"'LossScaleOptimizer' object has no attribute 'epsilon'"):
opt.epsilon # pylint: disable=pointless-statement
with self.assertRaises(AttributeError):
with self.assertRaisesRegexp(
AttributeError,
"'LossScaleOptimizer' object has no attribute 'beta_1'"):
opt.beta_1 # pylint: disable=pointless-statement
@test_util.run_in_graph_and_eager_modes
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册