Commit 53f95a95 authored by Reed Wanderman-Milne, committed by TensorFlower Gardener

Fix issue where AutoCastVariables wouldn't cast in tf.functions.

I tried to fix this in 1ec90f6a, but it was rolled back in e737182e because it broke cases where one layer tried to access another layer's variables while AutoCastVariables were in use, such as in RNNs. This is a much simpler fix.

I added mixed precision tests for RNNs, because RNNs use tf.functions, which did not work with mixed precision before this change.

PiperOrigin-RevId: 257684192
Parent 7c416718
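
For illustration, a minimal sketch of the behavior this change enables, written against the modern public mixed precision API ('mixed_float16') rather than the experimental 'infer_float32_vars' policy used in the tests below; the layer and its names are illustrative only, not part of the commit:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')


class AddFloat32Variable(tf.keras.layers.Layer):
  """Adds a float32 variable to float16 inputs."""

  def build(self, input_shape):
    self.v = self.add_weight(name='v', shape=(), initializer='ones')

  @tf.function  # variable reads now happen inside a FuncGraph
  def call(self, inputs):
    # Under the policy, self.v is an AutoCastVariable stored in float32.
    # Inside call() it should read as float16 to match the inputs; the bug
    # fixed here was that reads inside a tf.function stayed float32.
    return inputs + self.v


layer = AddFloat32Variable()
y = layer(tf.ones((2,), dtype=tf.float16))
assert layer.v.dtype == tf.float32  # storage dtype
assert y.dtype == tf.float16        # compute dtype inside the tf.function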
@@ -373,6 +373,10 @@ class FuncGraph(ops.Graph):
       # optimizers.
       old_graph_key = self._graph_key
       self._graph_key = graph._graph_key
+      # Inherit the auto_cast_variable_read_dtype, since this should not change
+      # inside a function.
+      old_auto_cast_var_read_dtype = self._auto_cast_variable_read_dtype
+      self._auto_cast_variable_read_dtype = graph._auto_cast_variable_read_dtype
       # pylint: enable=protected-access

     with outer_cm as g:
@@ -383,6 +387,7 @@ class FuncGraph(ops.Graph):
           self._device_function_stack = old_device_stack
           self._variable_creator_stack = old_creator_stack
           self._graph_key = old_graph_key
+          self._auto_cast_variable_read_dtype = old_auto_cast_var_read_dtype

     return inner_cm()

   @property
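For illustration, the FuncGraph change above is a plain save-override-restore pattern applied to the graph attribute that drives AutoCastVariable reads. A standalone sketch of that pattern, with hypothetical names:

import contextlib


@contextlib.contextmanager
def inherit_attr(func_graph, outer_graph, name):
  # Save the function graph's current value, adopt the outer graph's value
  # while the scope is active, and restore the saved value on exit.
  old = getattr(func_graph, name)
  setattr(func_graph, name, getattr(outer_graph, name))
  try:
    yield func_graph
  finally:
    setattr(func_graph, name, old)

Entering a function's graph under inherit_attr(func_graph, outer_graph, '_auto_cast_variable_read_dtype') mirrors what the hunk above does inline.
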
@@ -27,6 +27,7 @@ from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
@@ -38,6 +39,7 @@ from tensorflow.python.keras import regularizers
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.layers import recurrent
 from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
 from tensorflow.python.keras.mixed_precision.experimental import policy
 from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
@@ -124,6 +126,14 @@ class AddLayerWithoutAutoCast(AddLayer):
     return self._add(inputs, math_ops.cast(self.v, inputs.dtype))


+class AddLayerWithFunction(AddLayer):
+  """Same as AddLayer, but _add is decorated with a tf.function."""
+
+  @def_function.function
+  def _add(self, x, y):
+    return super(AddLayerWithFunction, self)._add(x, y)
+
+
 class IdentityRegularizer(regularizers.Regularizer):

   def __call__(self, x):
@@ -181,6 +191,19 @@ class KerasLayerTest(keras_parameterized.TestCase):
         self.evaluate(variables.global_variables_initializer())
         self.assertEqual(self.evaluate(y), 2.)

+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_layer_calling_tf_function(self, strategy_fn):
+    x = constant_op.constant([1.], dtype=dtypes.float16)
+    with strategy_fn().scope():
+      with policy.policy_scope('infer_float32_vars'):
+        layer = AddLayerWithFunction(assert_type=dtypes.float16)
+        y = layer(x)
+        self.assertEqual(layer.v.dtype, dtypes.float32)
+        self.assertEqual(y.dtype, dtypes.float16)
+        self.evaluate(variables.global_variables_initializer())
+        self.assertEqual(self.evaluate(y), 2.)
+
   @parameterized.named_parameters(*TESTCASES)
   @test_util.run_in_graph_and_eager_modes
   def test_layer_regularizer_runs_in_float32(self, strategy_fn):
@@ -694,5 +717,52 @@ class KerasModelTest(keras_parameterized.TestCase):
     self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1)


+class RnnTest(keras_parameterized.TestCase):
+  """Test mixed precision with RNNs."""
+
+  # TODO(b/136512020): Support and test recurrent_v2.GRU.
+  @parameterized.named_parameters({
+      'testcase_name': 'base_simple',
+      'strategy_fn': default_strategy_fn,
+      'rnn_class': recurrent.SimpleRNN,
+  }, {
+      'testcase_name': 'distribute_simple',
+      'strategy_fn': create_mirrored_strategy,
+      'rnn_class': recurrent.SimpleRNN,
+  }, {
+      'testcase_name': 'base_gru',
+      'strategy_fn': default_strategy_fn,
+      'rnn_class': recurrent.GRU,
+  }, {
+      'testcase_name': 'distribute_gru',
+      'strategy_fn': create_mirrored_strategy,
+      'rnn_class': recurrent.GRU,
+  })
+  @test_util.run_in_graph_and_eager_modes
+  # RNNs do not work properly with GradientTape in graph mode when V1 control
+  # flow is used.
+  @test_util.enable_control_flow_v2
+  def test_rnn(self, strategy_fn, rnn_class):
+    x = array_ops.ones((2, 3, 4), dtype=dtypes.float16)
+    strategy = strategy_fn()
+    with strategy.scope(), policy.policy_scope('infer_float32_vars'):
+      layer = rnn_class(units=4)
+
+      def run_fn():
+        with backprop.GradientTape() as tape:
+          y = layer(x)
+          self.assertEqual(y.dtype, dtypes.float16)
+        opt = gradient_descent.SGD(1.)
+        grads = tape.gradient(y, layer.trainable_weights)
+        return opt.apply_gradients(zip(grads, layer.trainable_weights))
+
+      op = strategy.experimental_run(run_fn)
+      if not context.executing_eagerly():
+        self.evaluate(variables.global_variables_initializer())
+      self.evaluate(op)
+
+      for v in layer.weights:
+        self.assertEqual(v.dtype, dtypes.float32)
+
+
 if __name__ == '__main__':
   test.main()