提交 41293260 编写于 作者: A A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 273371605
上级 fa5b66e5
......@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app as absl_app
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
......@@ -181,6 +181,12 @@ def run(flags_obj):
enable_eager=flags_obj.enable_eager,
enable_xla=flags_obj.enable_xla)
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == tf.bfloat16:
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16')
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
# TODO(anj-s): Set data_format without using Keras.
data_format = flags_obj.data_format
if data_format is None:
......@@ -375,4 +381,4 @@ if __name__ == '__main__':
common.define_keras_flags()
ctl_common.define_ctl_flags()
flags.adopt_module_key_flags(ctl_common)
absl_app.run(main)
app.run(main)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test the ResNet model with ImageNet data using CTL."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tempfile import mkdtemp
import tensorflow as tf
from tensorflow.python.platform import googletest
from official.resnet.ctl import ctl_common
from official.resnet.ctl import ctl_imagenet_main
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import common
from official.utils.misc import keras_utils
from official.utils.testing import integration
class CtlImagenetTest(googletest.TestCase):
"""Unit tests for Keras ResNet with ImageNet using CTL."""
_extra_flags = [
'-batch_size', '4',
'-train_steps', '4',
'-use_synthetic_data', 'true'
]
_tempdir = None
def get_temp_dir(self):
if not self._tempdir:
self._tempdir = mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(CtlImagenetTest, cls).setUpClass()
common.define_keras_flags()
ctl_common.define_ctl_flags()
def setUp(self):
super(CtlImagenetTest, self).setUp()
if not keras_utils.is_v2_0():
tf.compat.v1.enable_v2_behavior()
imagenet_preprocessing.NUM_IMAGES['validation'] = 4
def tearDown(self):
super(CtlImagenetTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
def test_end_to_end_tpu(self):
"""Test Keras model with TPU distribution strategy."""
extra_flags = [
'-distribution_strategy', 'tpu',
'-model_dir', 'ctl_imagenet_tpu_dist_strat',
'-data_format', 'channels_last',
'-use_tf_function', 'true',
'-single_l2_loss_op', 'true',
]
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=ctl_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
def test_end_to_end_tpu_bf16(self):
"""Test Keras model with TPU and bfloat16 activation."""
extra_flags = [
'-distribution_strategy', 'tpu',
'-model_dir', 'ctl_imagenet_tpu_dist_strat_bf16',
'-data_format', 'channels_last',
'-use_tf_function', 'true',
'-single_l2_loss_op', 'true',
'-dtype', 'bf16',
]
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=ctl_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
if __name__ == '__main__':
googletest.main()
......@@ -35,7 +35,7 @@ def inf(dtype):
Returns:
A very large value.
"""
if dtype == "float32":
if dtype == "float32" or dtype == "bfloat16":
return 1e7
elif dtype == "float16":
# Disable no-member lint error, as the linter thinks np.float16 does not
......
......@@ -386,7 +386,7 @@ class LayerNormalization(tf.keras.layers.Layer):
def call(self, x, epsilon=1e-6):
input_dtype = x.dtype
if input_dtype == tf.float16:
if input_dtype == tf.float16 or input_dtype == tf.bfloat16:
x = tf.cast(x, tf.float32)
mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
......
......@@ -171,6 +171,11 @@ class TransformerTask(object):
"mixed_float16", loss_scale=loss_scale)
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
if params["dtype"] == tf.bfloat16:
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
"mixed_bfloat16")
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
self.distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=num_gpus,
......
......@@ -29,6 +29,7 @@ from official.utils.flags._conventions import help_wrap
# Map string to TensorFlow dtype
DTYPE_MAP = {
"fp16": tf.float16,
"bf16": tf.bfloat16,
"fp32": tf.float32,
}
......
......@@ -67,6 +67,10 @@ def run(flags_obj):
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
if not keras_utils.is_v2_0():
raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
elif dtype == tf.bfloat16:
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16')
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
data_format = flags_obj.data_format
if data_format is None:
......
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test the keras ResNet model with ImageNet data on TPU."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_imagenet_main
class KerasImagenetTest(tf.test.TestCase):
"""Unit tests for Keras ResNet with ImageNet."""
_extra_flags = [
"-batch_size", "4",
"-train_steps", "1",
"-use_synthetic_data", "true"
]
_tempdir = None
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(KerasImagenetTest, cls).setUpClass()
resnet_imagenet_main.define_imagenet_keras_flags()
def setUp(self):
super(KerasImagenetTest, self).setUp()
imagenet_preprocessing.NUM_IMAGES["validation"] = 4
def tearDown(self):
super(KerasImagenetTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
def test_end_to_end_tpu(self):
"""Test Keras model with TPU distribution strategy."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
extra_flags = [
"-distribution_strategy", "tpu",
"-data_format", "channels_last",
]
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
def test_end_to_end_tpu_bf16(self):
"""Test Keras model with TPU and bfloat16 activation."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
extra_flags = [
"-distribution_strategy", "tpu",
"-data_format", "channels_last",
"-dtype", "bf16",
]
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=resnet_imagenet_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
if __name__ == "__main__":
tf.compat.v1.enable_v2_behavior()
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册