Commit c115444f authored by Jaehong Kim, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 283962490
Parent ef8aed79
@@ -24,6 +24,7 @@ import numpy as np
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
import tensorflow_model_optimization as tfmot
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
@@ -180,7 +181,12 @@ def get_optimizer(learning_rate=0.1):
# TODO(hongkuny,haoyuzhang): make cifar model use_tensor_lr to clean up code.
def get_callbacks(steps_per_epoch, learning_rate_schedule_fn=None):
def get_callbacks(
steps_per_epoch,
learning_rate_schedule_fn=None,
pruning_method='',
enable_checkpoint_and_export=False,
model_dir=''):
"""Returns common callbacks."""
time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
callbacks = [time_callback]
@@ -205,6 +211,17 @@ def get_callbacks(steps_per_epoch, learning_rate_schedule_fn=None):
steps_per_epoch)
callbacks.append(profiler_callback)
if model_dir:
if pruning_method == 'polynomial_decay':
callbacks.append(tfmot.sparsity.keras.PruningSummaries(
log_dir=model_dir, profile_batch=0))
callbacks.append(tfmot.sparsity.keras.UpdatePruningStep())
if enable_checkpoint_and_export:
ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
callbacks.append(
tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
save_weights_only=True))
return callbacks
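For context, a minimal usage sketch of the extended get_callbacks signature. The module path and values are assumptions, not part of this commit, and absl flags are presumed already parsed:

# Hypothetical call site; values are illustrative, not from this commit.
from official.vision.image_classification import common

callbacks = common.get_callbacks(
    steps_per_epoch=1281167 // 256,      # ImageNet train images // batch size
    learning_rate_schedule_fn=common.learning_rate_schedule,
    pruning_method='polynomial_decay',   # '' skips the pruning callbacks
    enable_checkpoint_and_export=True,
    model_dir='/tmp/resnet_model')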
@@ -358,6 +375,31 @@ def get_synth_data(height, width, num_channels, num_classes, dtype):
return inputs, labels
def define_pruning_flags():
"""Define flags for pruning methods."""
flags.DEFINE_string('pruning_method', '',
'Pruning method. '
'Empty string (no pruning) or polynomial_decay.')
flags.DEFINE_float('pruning_initial_sparsity', 0.0,
'Initial sparsity for pruning.')
flags.DEFINE_float('pruning_final_sparsity', 0.5,
'Final sparsity for pruning.')
flags.DEFINE_integer('pruning_begin_step', 0,
'Begin step for pruning.')
flags.DEFINE_integer('pruning_end_step', 100000,
'End step for pruning.')
flags.DEFINE_integer('pruning_frequency', 100,
'Frequency for pruning.')
flags.DEFINE_string('model', 'resnet50_v1.5',
'Name of model preset. (mobilenet, resnet50_v1.5)')
flags.DEFINE_string('optimizer', 'resnet50_default',
'Name of optimizer preset. '
'(mobilenet_default, resnet50_default)')
flags.DEFINE_string('pretrained_filepath', '',
'Pretrained file path.')
def get_synth_input_fn(height, width, num_channels, num_classes,
dtype=tf.float32, drop_remainder=True):
"""Returns an input function that returns a dataset with random data.
......
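A hedged sketch of exercising the pruning flags defined above with absl's parser; the module path is an assumption based on this diff's imports:

# Hypothetical: define and parse the pruning flags outside the trainer.
from absl import flags

from official.vision.image_classification import common

common.define_pruning_flags()
FLAGS = flags.FLAGS
FLAGS(['demo',
       '--pruning_method=polynomial_decay',
       '--model=mobilenet',
       '--optimizer=mobilenet_default'])
print(FLAGS.pruning_final_sparsity)  # 0.5, the default defined above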
@@ -246,6 +246,24 @@ def parse_record(raw_record, is_training, dtype):
return image, label
def get_parse_record_fn(use_keras_image_data_format=False):
"""Get function to use for parsing the records.
Args:
use_keras_image_data_format: A boolean denoting whether the parsed images
should be returned in the Keras backend image data format.
Returns:
Function to use for parsing the records.
"""
def parse_record_fn(raw_record, is_training, dtype):
image, label = parse_record(raw_record, is_training, dtype)
if use_keras_image_data_format:
if tf.keras.backend.image_data_format() == 'channels_first':
image = tf.transpose(image, perm=[2, 0, 1])
return image, label
return parse_record_fn
def input_fn(is_training,
data_dir,
batch_size,
......
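The channels_first branch in get_parse_record_fn is a plain HWC-to-CHW transpose; a self-contained illustration (shape values hypothetical):

# Standalone illustration of the perm=[2, 0, 1] transpose used above.
import tensorflow as tf

image = tf.zeros([224, 224, 3])            # height, width, channels (HWC)
chw = tf.transpose(image, perm=[2, 0, 1])  # channels, height, width (CHW)
print(chw.shape)                           # (3, 224, 224)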
@@ -25,6 +25,8 @@ from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.benchmark.models import trivial_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
@@ -44,6 +46,7 @@ def run(flags_obj):
Raises:
ValueError: If fp16 is passed as it is not currently supported.
NotImplementedError: If some features are not currently supported.
Returns:
Dictionary of training and eval stats.
@@ -120,12 +123,20 @@ def run(flags_obj):
# in the dataset, as XLA-GPU doesn't support dynamic shapes.
drop_remainder = flags_obj.enable_xla
# Current resnet_model.resnet50 input format is always channels-last.
# We use the Keras application MobileNet model, whose input format depends
# on the Keras backend image data format.
# The use_keras_image_data_format flag indicates whether the image
# preprocessor output should match the Keras backend image data format or
# always be channels-last.
use_keras_image_data_format = (flags_obj.model == 'mobilenet')
train_input_dataset = input_fn(
is_training=True,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=imagenet_preprocessing.parse_record,
parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
use_keras_image_data_format=use_keras_image_data_format),
datasets_num_private_threads=flags_obj.datasets_num_private_threads,
dtype=dtype,
drop_remainder=drop_remainder,
@@ -140,7 +151,8 @@ def run(flags_obj):
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=imagenet_preprocessing.parse_record,
parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
use_keras_image_data_format=use_keras_image_data_format),
dtype=dtype,
drop_remainder=drop_remainder)
@@ -153,9 +165,27 @@ def run(flags_obj):
boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
multipliers=list(p[0] for p in common.LR_SCHEDULE),
compute_lr_on_cpu=True)
steps_per_epoch = (
imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
learning_rate_schedule_fn = None
with strategy_scope:
optimizer = common.get_optimizer(lr_schedule)
if flags_obj.optimizer == 'resnet50_default':
optimizer = common.get_optimizer(lr_schedule)
learning_rate_schedule_fn = common.learning_rate_schedule
elif flags_obj.optimizer == 'mobilenet_default':
lr_decay_factor = 0.94
num_epochs_per_decay = 2.5
initial_learning_rate_per_sample = 0.000007
initial_learning_rate = (
initial_learning_rate_per_sample * flags_obj.batch_size)
optimizer = tf.keras.optimizers.SGD(
learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate,
decay_steps=steps_per_epoch * num_epochs_per_decay,
decay_rate=lr_decay_factor,
staircase=True),
momentum=0.9)
if flags_obj.fp16_implementation == 'graph_rewrite':
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
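To make the MobileNet optimizer branch above concrete, a standalone sketch of the decayed learning rate; the batch size is hypothetical:

# Sketch of the ExponentialDecay schedule above: per-sample LR scaled by
# batch size, decayed by 0.94 every 2.5 epochs (staircase).
import tensorflow as tf

batch_size = 256                             # hypothetical
steps_per_epoch = 1281167 // batch_size      # ImageNet train images // batch
schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    0.000007 * batch_size,                   # initial LR
    decay_steps=steps_per_epoch * 2.5,
    decay_rate=0.94,
    staircase=True)
print(float(schedule(0)))                    # 0.001792
print(float(schedule(3 * steps_per_epoch)))  # ~0.001684 after one decay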
@@ -169,9 +199,30 @@ def run(flags_obj):
if flags_obj.use_trivial_model:
model = trivial_model.trivial_model(
imagenet_preprocessing.NUM_CLASSES)
else:
elif flags_obj.model == 'resnet50_v1.5':
model = resnet_model.resnet50(
num_classes=imagenet_preprocessing.NUM_CLASSES)
elif flags_obj.model == 'mobilenet':
model = tf.keras.applications.mobilenet.MobileNet(
weights=None,
classes=imagenet_preprocessing.NUM_CLASSES)
if flags_obj.pretrained_filepath:
model.load_weights(flags_obj.pretrained_filepath)
if flags_obj.pruning_method == 'polynomial_decay':
if dtype != tf.float32:
raise NotImplementedError(
'Pruning is currently only supported on dtype=tf.float32.')
pruning_params = {
'pruning_schedule':
tfmot.sparsity.keras.PolynomialDecay(
initial_sparsity=flags_obj.pruning_initial_sparsity,
final_sparsity=flags_obj.pruning_final_sparsity,
begin_step=flags_obj.pruning_begin_step,
end_step=flags_obj.pruning_end_step,
frequency=flags_obj.pruning_frequency),
}
model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
# TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
# a valid arg for this model. Also remove as a valid flag.
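For intuition, tfmot's PolynomialDecay ramps sparsity along a cubic curve (power=3 is the documented default); a small sketch using the flag defaults above:

# Pure-Python sketch of the sparsity targeted at a given step (assumes
# tfmot's default power=3).
def polynomial_sparsity(step, initial=0.0, final=0.5,
                        begin=0, end=100000, power=3):
  progress = min(max((step - begin) / float(end - begin), 0.0), 1.0)
  return final + (initial - final) * (1.0 - progress) ** power

print(polynomial_sparsity(0))       # 0.0: dense at begin_step
print(polynomial_sparsity(50000))   # 0.4375: most pruning happens early
print(polynomial_sparsity(100000))  # 0.5: final sparsity at end_step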
@@ -191,16 +242,14 @@ def run(flags_obj):
if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly)
steps_per_epoch = (
imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
train_epochs = flags_obj.train_epochs
callbacks = common.get_callbacks(steps_per_epoch,
common.learning_rate_schedule)
if flags_obj.enable_checkpoint_and_export:
ckpt_full_path = os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}')
callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
save_weights_only=True))
callbacks = common.get_callbacks(
steps_per_epoch=steps_per_epoch,
learning_rate_schedule_fn=learning_rate_schedule_fn,
pruning_method=flags_obj.pruning_method,
enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
model_dir=flags_obj.model_dir)
# If multiple epochs, ignore the train_steps flag.
if train_epochs <= 1 and flags_obj.train_steps:
@@ -236,13 +285,6 @@ def run(flags_obj):
validation_data=validation_data,
validation_freq=flags_obj.epochs_between_evals,
verbose=2)
if flags_obj.enable_checkpoint_and_export:
if dtype == tf.bfloat16:
logging.warning("Keras model.save does not support bfloat16 dtype.")
else:
# Keras model.save assumes a float32 input signature.
export_path = os.path.join(flags_obj.model_dir, 'saved_model')
model.save(export_path, include_optimizer=False)
eval_output = None
if not flags_obj.skip_eval:
@@ -250,6 +292,16 @@ def run(flags_obj):
steps=num_eval_steps,
verbose=2)
if flags_obj.pruning_method == 'polynomial_decay':
model = tfmot.sparsity.keras.strip_pruning(model)
if flags_obj.enable_checkpoint_and_export:
if dtype == tf.bfloat16:
logging.warning('Keras model.save does not support bfloat16 dtype.')
else:
# Keras model.save assumes a float32 input signature.
export_path = os.path.join(flags_obj.model_dir, 'saved_model')
model.save(export_path, include_optimizer=False)
if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__()
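A toy sketch of the export path above: strip_pruning removes the pruning wrappers so the SavedModel contains ordinary Keras layers (the model and path here are hypothetical):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Wrap a toy model for pruning, then strip the wrappers before export.
model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(4,))])
model = tfmot.sparsity.keras.prune_low_magnitude(model)
model = tfmot.sparsity.keras.strip_pruning(model)  # plain Dense again
model.save('/tmp/saved_model', include_optimizer=False)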
@@ -259,6 +311,7 @@ def run(flags_obj):
def define_imagenet_keras_flags():
common.define_keras_flags()
common.define_pruning_flags()
flags_core.set_defaults()
flags.adopt_module_key_flags(common)
......
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.eager import context
@@ -27,14 +28,45 @@ from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_imagenet_main
@parameterized.parameters(
"resnet",
"resnet_polynomial_decay",
"mobilenet",
"mobilenet_polynomial_decay")
class KerasImagenetTest(tf.test.TestCase):
"""Unit tests for Keras ResNet with ImageNet."""
_extra_flags = [
"-batch_size", "4",
"-train_steps", "1",
"-use_synthetic_data", "true"
]
"""Unit tests for Keras Models with ImageNet."""
_extra_flags_dict = {
"resnet": [
"-batch_size", "4",
"-train_steps", "1",
"-use_synthetic_data", "true"
"-model", "resnet50_v1.5",
"-optimizer", "resnet50_default",
],
"resnet_polynomial_decay": [
"-batch_size", "4",
"-train_steps", "1",
"-use_synthetic_data", "true",
"-model", "resnet50_v1.5",
"-optimizer", "resnet50_default",
"-pruning_method", "polynomial_decay",
],
"mobilenet": [
"-batch_size", "4",
"-train_steps", "1",
"-use_synthetic_data", "true"
"-model", "mobilenet",
"-optimizer", "mobilenet_default",
],
"mobilenet_polynomial_decay": [
"-batch_size", "4",
"-train_steps", "1",
"-use_synthetic_data", "true",
"-model", "mobilenet",
"-optimizer", "mobilenet_default",
"-pruning_method", "polynomial_decay",
],
}
_tempdir = None
@classmethod
@@ -50,7 +82,7 @@ class KerasImagenetTest(tf.test.TestCase):
super(KerasImagenetTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
def test_end_to_end_no_dist_strat(self):
def test_end_to_end_no_dist_strat(self, flags_key):
"""Test Keras model with 1 GPU, no distribution strategy."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
@@ -59,7 +91,7 @@ class KerasImagenetTest(tf.test.TestCase):
"-distribution_strategy", "off",
"-data_format", "channels_last",
]
extra_flags = extra_flags + self._extra_flags
extra_flags = extra_flags + self._extra_flags_dict[flags_key]
integration.run_synthetic(
main=resnet_imagenet_main.run,
@@ -67,14 +99,14 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags=extra_flags
)
def test_end_to_end_graph_no_dist_strat(self):
def test_end_to_end_graph_no_dist_strat(self, flags_key):
"""Test Keras model in legacy graph mode with 1 GPU, no dist strat."""
extra_flags = [
"-enable_eager", "false",
"-distribution_strategy", "off",
"-data_format", "channels_last",
]
extra_flags = extra_flags + self._extra_flags
extra_flags = extra_flags + self._extra_flags_dict[flags_key]
integration.run_synthetic(
main=resnet_imagenet_main.run,
@@ -82,7 +114,7 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags=extra_flags
)
def test_end_to_end_1_gpu(self):
def test_end_to_end_1_gpu(self, flags_key):
"""Test Keras model with 1 GPU."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
@@ -98,7 +130,7 @@ class KerasImagenetTest(tf.test.TestCase):
"-data_format", "channels_last",
"-enable_checkpoint_and_export", "1",
]
extra_flags = extra_flags + self._extra_flags
extra_flags = extra_flags + self._extra_flags_dict[flags_key]
integration.run_synthetic(
main=resnet_imagenet_main.run,
@@ -106,7 +138,7 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags=extra_flags
)
def test_end_to_end_1_gpu_fp16(self):
def test_end_to_end_1_gpu_fp16(self, flags_key):
"""Test Keras model with 1 GPU and fp16."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
@@ -122,7 +154,10 @@ class KerasImagenetTest(tf.test.TestCase):
"-distribution_strategy", "mirrored",
"-data_format", "channels_last",
]
extra_flags = extra_flags + self._extra_flags
extra_flags = extra_flags + self._extra_flags_dict[flags_key]
if "polynomial_decay" in extra_flags:
self.skipTest("Pruning with fp16 is not currently supported.")
integration.run_synthetic(
main=resnet_imagenet_main.run,
@@ -130,8 +165,7 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags=extra_flags
)
def test_end_to_end_2_gpu(self):
def test_end_to_end_2_gpu(self, flags_key):
"""Test Keras model with 2 GPUs."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
@@ -145,7 +179,7 @@ class KerasImagenetTest(tf.test.TestCase):
"-num_gpus", "2",
"-distribution_strategy", "mirrored",
]
extra_flags = extra_flags + self._extra_flags
extra_flags = extra_flags + self._extra_flags_dict[flags_key]
integration.run_synthetic(
main=resnet_imagenet_main.run,
@@ -153,7 +187,7 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags=extra_flags
)
def test_end_to_end_xla_2_gpu(self):
def test_end_to_end_xla_2_gpu(self, flags_key):
"""Test Keras model with XLA and 2 GPUs."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
@@ -168,7 +202,7 @@ class KerasImagenetTest(tf.test.TestCase):
"-enable_xla", "true",
"-distribution_strategy", "mirrored",
]
extra_flags = extra_flags + self._extra_flags
extra_flags = extra_flags + self._extra_flags_dict[flags_key]
integration.run_synthetic(
main=resnet_imagenet_main.run,
@@ -176,7 +210,7 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags=extra_flags
)
def test_end_to_end_2_gpu_fp16(self):
def test_end_to_end_2_gpu_fp16(self, flags_key):
"""Test Keras model with 2 GPUs and fp16."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
@@ -191,7 +225,7 @@ class KerasImagenetTest(tf.test.TestCase):
"-dtype", "fp16",
"-distribution_strategy", "mirrored",
]
extra_flags = extra_flags + self._extra_flags
extra_flags = extra_flags + self._extra_flags_dict[flags_key]
if "polynomial_decay" in extra_flags:
self.skipTest("Pruning with fp16 is not currently supported.")
integration.run_synthetic(
main=resnet_imagenet_main.run,
@@ -199,7 +236,7 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags=extra_flags
)
def test_end_to_end_xla_2_gpu_fp16(self):
def test_end_to_end_xla_2_gpu_fp16(self, flags_key):
"""Test Keras model with XLA, 2 GPUs and fp16."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
@@ -215,7 +252,10 @@ class KerasImagenetTest(tf.test.TestCase):
"-enable_xla", "true",
"-distribution_strategy", "mirrored",
]
extra_flags = extra_flags + self._extra_flags
extra_flags = extra_flags + self._extra_flags_dict[flags_key]
if "polynomial_decay" in extra_flags:
self.skipTest("Pruning with fp16 is not currently supported.")
integration.run_synthetic(
main=resnet_imagenet_main.run,
......
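The class-level parameterization used in these tests follows absl's pattern; a minimal self-contained sketch:

# Each value becomes its own test case and is passed to every test method.
from absl.testing import absltest
from absl.testing import parameterized


@parameterized.parameters("resnet", "mobilenet")
class DemoTest(parameterized.TestCase):

  def test_flags_key(self, flags_key):
    self.assertIn(flags_key, ("resnet", "mobilenet"))


if __name__ == "__main__":
  absltest.main()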