提交 b2f6e348 编写于 作者: R Reed Wanderman-Milne 提交者: TensorFlower Gardener

Rename the internal `set_policy` function to `set_global_policy`.

Also update some references from the experimental `set_policy` function to the non-experimental `set_global_policy` function.

PiperOrigin-RevId: 393891991
上级 cf930d36
......@@ -22,7 +22,7 @@ import numpy as np
import keras
from keras import keras_parameterized
from keras.applications import imagenet_utils as utils
from keras.mixed_precision.policy import set_policy
from keras.mixed_precision.policy import set_global_policy
class TestImageNetUtils(keras_parameterized.TestCase):
......@@ -160,7 +160,7 @@ class TestImageNetUtils(keras_parameterized.TestCase):
},
])
def test_preprocess_input_symbolic_mixed_precision(self, mode):
set_policy('mixed_float16')
set_global_policy('mixed_float16')
shape = (20, 20, 3)
inputs = keras.layers.Input(shape=shape)
try:
......@@ -168,7 +168,7 @@ class TestImageNetUtils(keras_parameterized.TestCase):
lambda x: utils.preprocess_input(x, mode=mode), output_shape=shape)(
inputs)
finally:
set_policy('float32')
set_global_policy('float32')
@parameterized.named_parameters([
{'testcase_name': 'channels_last_format',
......
......@@ -84,7 +84,7 @@ def set_floatx(value):
Note: It is not recommended to set this to float16 for training, as this will
likely cause numeric stability issues. Instead, mixed precision, which is
using a mix of float16 and float32, can be used by calling
`tf.keras.mixed_precision.experimental.set_policy('mixed_float16')`. See the
`tf.keras.mixed_precision.set_global_policy('mixed_float16')`. See the
[mixed precision guide](
https://www.tensorflow.org/guide/keras/mixed_precision) for details.
......
......@@ -125,13 +125,13 @@ class EmbeddingTest(keras_parameterized.TestCase):
@testing_utils.enable_v2_dtype_behavior
def test_mixed_precision_embedding(self):
try:
policy.set_policy('mixed_float16')
policy.set_global_policy('mixed_float16')
layer = keras.layers.Embedding(input_dim=5, output_dim=2)
self.assertEqual(layer._dtype_policy.name, 'mixed_float16')
outputs = layer(np.array([0, 1, 2]))
self.assertEqual(outputs.dtype, 'float16')
finally:
policy.set_policy('float32')
policy.set_global_policy('float32')
if __name__ == '__main__':
......
......@@ -450,7 +450,7 @@ def _check_if_mixed_precision_graph_rewrite_is_enabled(policy):
'At most, one of the following can be called:\n\n'
' 1. tf.compat.v1.train.enable_mixed_precision_graph_rewrite() '
'(You called this first)\n'
' 2. tf.keras.mixed_precision.experimental.set_policy() with a mixed '
' 2. tf.keras.mixed_precision.set_global_policy() with a mixed '
'precision policy (You called this second)\n\n'
'You called both functions, which is an error, because both functions '
'enable you to use mixed precision. If in doubt which function to use, '
......@@ -460,7 +460,7 @@ def _check_if_mixed_precision_graph_rewrite_is_enabled(policy):
@keras_export('keras.mixed_precision.set_global_policy',
'keras.mixed_precision.experimental.set_policy', v1=[])
def set_policy(policy):
def set_global_policy(policy):
"""Sets the global dtype policy.
The global policy is the default `tf.keras.mixed_precision.Policy` used for
......@@ -509,8 +509,8 @@ def set_policy(policy):
_check_if_mixed_precision_graph_rewrite_is_enabled(policy)
if (policy is not None and policy.compute_dtype is not None and
not tf.as_dtype(policy.compute_dtype).is_floating):
raise ValueError('set_policy can only be used to set the global policy to '
'floating-point policies, such as "float32" and '
raise ValueError('set_global_policy can only be used to set the global '
'policy to floating-point policies, such as "float32" and '
'"mixed_float16", but got policy: %s'
% (policy.name,))
_global_policy = policy
......@@ -530,10 +530,10 @@ def policy_scope(policy):
"""
old_policy = _global_policy
try:
set_policy(policy)
set_global_policy(policy)
yield
finally:
set_policy(old_policy)
set_global_policy(old_policy)
def _is_convertible_to_dtype(dtype):
......
......@@ -141,32 +141,32 @@ class PolicyTest(tf.test.TestCase, parameterized.TestCase):
default_policy = '_infer'
self.assertEqual(mp_policy.global_policy().name, default_policy)
try:
mp_policy.set_policy('mixed_float16')
mp_policy.set_global_policy('mixed_float16')
self.assertEqual(mp_policy.global_policy().name, 'mixed_float16')
with tf.Graph().as_default(): # Policies are not associated with a graph
self.assertEqual(mp_policy.global_policy().name, 'mixed_float16')
mp_policy.set_policy('_infer')
mp_policy.set_global_policy('_infer')
self.assertEqual(mp_policy.global_policy().name, '_infer')
policy = mp_policy.Policy('mixed_bfloat16')
mp_policy.set_policy(policy)
mp_policy.set_global_policy(policy)
self.assertIs(mp_policy.global_policy(), policy)
finally:
mp_policy.set_policy(None)
mp_policy.set_global_policy(None)
@testing_utils.enable_v2_dtype_behavior
def test_global_policy_dtype_error(self):
with self.assertRaisesRegex(
ValueError,
'set_policy can only be used to set the global policy to '
'set_global_policy can only be used to set the global policy to '
'floating-point policies, such as "float32" and "mixed_float16", but '
'got policy: int32'):
mp_policy.set_policy('int32')
mp_policy.set_global_policy('int32')
with self.assertRaisesRegex(
ValueError,
'set_policy can only be used to set the global policy to '
'set_global_policy can only be used to set the global policy to '
'floating-point policies, such as "float32" and "mixed_float16", but '
'got policy: complex64'):
mp_policy.set_policy(mp_policy.Policy('complex64'))
mp_policy.set_global_policy(mp_policy.Policy('complex64'))
@testing_utils.enable_v2_dtype_behavior
def test_loss_scale_warning(self):
......@@ -301,7 +301,7 @@ class PolicyTest(tf.test.TestCase, parameterized.TestCase):
with self.assertRaisesRegex(
ValueError, 'cannot be set to "mixed_float16", .* the mixed '
'precision graph rewrite has already been enabled'):
mp_policy.set_policy('mixed_float16')
mp_policy.set_global_policy('mixed_float16')
with mp_policy.policy_scope('float64'):
pass # Non-mixed policies are allowed
finally:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册