Commit a8b6963c authored by Jing Li, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 272915002
Parent b045ce7d
@@ -71,6 +71,20 @@ def build_stats(train_result, eval_result, time_callback):
def get_input_dataset(flags_obj, strategy):
"""Returns the test and train input datasets."""
dtype = flags_core.get_tf_dtype(flags_obj)
use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
batch_size = flags_obj.batch_size
if use_dataset_fn:
if batch_size % strategy.num_replicas_in_sync != 0:
raise ValueError(
'Batch size must be divisible by number of replicas : {}'.format(
strategy.num_replicas_in_sync))
# As auto rebatching is not supported in
# `experimental_distribute_datasets_from_function()` API, which is
# required when cloning dataset to multiple workers in eager mode,
# we use per-replica batch size.
batch_size = int(batch_size / strategy.num_replicas_in_sync)
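# Illustrative example (not part of this commit): with a global batch size of
# 1024 and strategy.num_replicas_in_sync == 8, the dataset function below is
# built with the per-replica batch size:
#   batch_size = int(1024 / 8)  # == 128 examples per replica per step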
if flags_obj.use_synthetic_data:
input_fn = common.get_synth_input_fn(
height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
@@ -82,34 +96,51 @@ def get_input_dataset(flags_obj, strategy):
else:
input_fn = imagenet_preprocessing.input_fn
train_ds = input_fn(
is_training=True,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
parse_record_fn=imagenet_preprocessing.parse_record,
datasets_num_private_threads=flags_obj.datasets_num_private_threads,
dtype=dtype)
def _train_dataset_fn(ctx=None):
train_ds = input_fn(
is_training=True,
data_dir=flags_obj.data_dir,
batch_size=batch_size,
parse_record_fn=imagenet_preprocessing.parse_record,
datasets_num_private_threads=flags_obj.datasets_num_private_threads,
dtype=dtype,
input_context=ctx,
drop_remainder=True)
return train_ds
if strategy:
train_ds = strategy.experimental_distribute_dataset(train_ds)
if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
train_ds = strategy.experimental_distribute_datasets_from_function(_train_dataset_fn)
else:
train_ds = strategy.experimental_distribute_dataset(_train_dataset_fn())
else:
train_ds = _train_dataset_fn()
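# Illustrative note (not part of this commit): when the dataset function is
# distributed via experimental_distribute_datasets_from_function, `ctx` is a
# tf.distribute.InputContext, which can be used to shard the input across
# hosts, e.g. with ctx.input_pipeline_id and ctx.num_input_pipelines.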
test_ds = None
if not flags_obj.skip_eval:
test_ds = input_fn(
is_training=False,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
parse_record_fn=imagenet_preprocessing.parse_record,
dtype=dtype)
def _test_data_fn(ctx=None):
test_ds = input_fn(
is_training=False,
data_dir=flags_obj.data_dir,
batch_size=batch_size,
parse_record_fn=imagenet_preprocessing.parse_record,
dtype=dtype,
input_context=ctx)
return test_ds
if strategy:
test_ds = strategy.experimental_distribute_dataset(test_ds)
if strategy:
if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
test_ds = strategy.experimental_distribute_datasets_from_function(_test_data_fn)
else:
test_ds = strategy.experimental_distribute_dataset(_test_data_fn())
else:
test_ds = _test_data_fn()
return train_ds, test_ds
def get_num_train_iterations(flags_obj):
"""Returns the number of training stesps, train and test epochs."""
"""Returns the number of training steps, train and test epochs."""
train_steps = (
imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
train_epochs = flags_obj.train_epochs
@@ -124,6 +155,15 @@ def get_num_train_iterations(flags_obj):
return train_steps, train_epochs, eval_steps
def _steps_to_run(steps_in_current_epoch, steps_per_epoch, steps_per_loop):
"""Calculates steps to run on device."""
if steps_per_loop <= 0:
raise ValueError('steps_per_loop should be positive integer.')
if steps_per_loop == 1:
return steps_per_loop
return min(steps_per_loop, steps_per_epoch - steps_in_current_epoch)
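# Illustrative example (not part of this commit): with per_epoch_steps == 100
# and steps_per_loop == 32, successive calls yield 32, 32, 32 and then 4, so
# the final inner loop is shortened to end exactly on the epoch boundary:
#   _steps_to_run(96, 100, 32)  # == min(32, 100 - 96) == 4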
def run(flags_obj):
"""Run ResNet ImageNet training and eval loop using custom training loops.
@@ -152,33 +192,45 @@ def run(flags_obj):
num_gpus=flags_obj.num_gpus,
num_workers=distribution_utils.configure_cluster(),
all_reduce_alg=flags_obj.all_reduce_alg,
num_packs=flags_obj.num_packs)
num_packs=flags_obj.num_packs,
tpu_address=flags_obj.tpu)
train_ds, test_ds = get_input_dataset(flags_obj, strategy)
train_steps, train_epochs, eval_steps = get_num_train_iterations(flags_obj)
per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
flags_obj)
steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
logging.info("Training %d epochs, each epoch has %d steps, "
"total steps: %d; Eval %d steps",
train_epochs, per_epoch_steps, train_epochs * per_epoch_steps,
eval_steps)
time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
flags_obj.log_steps)
strategy_scope = distribution_utils.get_strategy_scope(strategy)
with strategy_scope:
with distribution_utils.get_strategy_scope(strategy):
model = resnet_model.resnet50(
num_classes=imagenet_preprocessing.NUM_CLASSES,
batch_size=flags_obj.batch_size,
use_l2_regularizer=not flags_obj.single_l2_loss_op)
optimizer = tf.keras.optimizers.SGD(
learning_rate=common.BASE_LEARNING_RATE, momentum=0.9,
nesterov=True)
if flags_obj.fp16_implementation == "graph_rewrite":
lr_schedule = common.PiecewiseConstantDecayWithWarmup(
batch_size=flags_obj.batch_size,
epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
warmup_epochs=common.LR_SCHEDULE[0][1],
boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
multipliers=list(p[0] for p in common.LR_SCHEDULE),
compute_lr_on_cpu=True)
optimizer = common.get_optimizer(lr_schedule)
if flags_obj.fp16_implementation == 'graph_rewrite':
if not flags_obj.use_tf_function:
raise ValueError("--fp16_implementation=graph_rewrite requires "
"--use_tf_function to be true")
raise ValueError('--fp16_implementation=graph_rewrite requires '
'--use_tf_function to be true')
loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer, loss_scale)
train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
'training_accuracy', dtype=tf.float32)
test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
@@ -187,55 +239,56 @@ def run(flags_obj):
trainable_variables = model.trainable_variables
def train_step(train_ds_inputs):
"""Training StepFn."""
def step_fn(inputs):
"""Per-Replica StepFn."""
images, labels = inputs
with tf.GradientTape() as tape:
logits = model(images, training=True)
prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, logits)
loss = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
if flags_obj.single_l2_loss_op:
filtered_variables = [
tf.reshape(v, (-1,))
for v in trainable_variables
if 'bn' not in v.name
]
l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
tf.concat(filtered_variables, axis=0))
loss += (l2_loss / num_replicas)
else:
loss += (tf.reduce_sum(model.losses) / num_replicas)
# Scale the loss
if flags_obj.dtype == "fp16":
loss = optimizer.get_scaled_loss(loss)
grads = tape.gradient(loss, trainable_variables)
# Unscale the grads
def step_fn(inputs):
"""Per-Replica StepFn."""
images, labels = inputs
with tf.GradientTape() as tape:
logits = model(images, training=True)
prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, logits)
loss = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
if flags_obj.single_l2_loss_op:
filtered_variables = [
tf.reshape(v, (-1,))
for v in trainable_variables
if 'bn' not in v.name
]
l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
tf.concat(filtered_variables, axis=0))
loss += (l2_loss / num_replicas)
else:
loss += (tf.reduce_sum(model.losses) / num_replicas)
# Scale the loss
if flags_obj.dtype == "fp16":
grads = optimizer.get_unscaled_gradients(grads)
loss = optimizer.get_scaled_loss(loss)
optimizer.apply_gradients(zip(grads, trainable_variables))
grads = tape.gradient(loss, trainable_variables)
training_accuracy.update_state(labels, logits)
return loss
# Unscale the grads
if flags_obj.dtype == "fp16":
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(zip(grads, trainable_variables))
train_loss.update_state(loss)
training_accuracy.update_state(labels, logits)
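# Illustrative note (not part of this commit): when flags_obj.dtype == "fp16",
# the loss is scaled before tape.gradient() and the gradients are unscaled
# before apply_gradients(), which keeps small gradient values from
# underflowing in float16 while leaving the applied update unchanged.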
@tf.function
def train_steps(iterator, steps):
"""Performs distributed training steps in a loop."""
for _ in tf.range(steps):
strategy.experimental_run_v2(step_fn, args=(next(iterator),))
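# Illustrative note (not part of this commit): `steps` is passed in as a
# tf.int32 Tensor (see the tf.convert_to_tensor call further below) so that
# tf.function traces this loop body once; calling it with plain Python ints
# such as 32 and then 4 would retrace the function for each distinct value.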
def train_single_step(iterator):
if strategy:
per_replica_losses = strategy.experimental_run_v2(
step_fn, args=(train_ds_inputs,))
return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
axis=None)
strategy.experimental_run_v2(step_fn, args=(next(iterator),))
else:
return step_fn(train_ds_inputs)
return step_fn(next(iterator))
def test_step(test_ds_inputs):
def test_step(iterator):
"""Evaluation StepFn."""
def step_fn(inputs):
images, labels = inputs
@@ -247,34 +300,39 @@ def run(flags_obj):
test_accuracy.update_state(labels, logits)
if strategy:
strategy.experimental_run_v2(step_fn, args=(test_ds_inputs,))
strategy.experimental_run_v2(step_fn, args=(next(iterator),))
else:
step_fn(test_ds_inputs)
step_fn(next(iterator))
if flags_obj.use_tf_function:
train_step = tf.function(train_step)
train_single_step = tf.function(train_single_step)
test_step = tf.function(test_step)
train_iter = iter(train_ds)
time_callback.on_train_begin()
for epoch in range(train_epochs):
train_iter = iter(train_ds)
total_loss = 0.0
train_loss.reset_states()
training_accuracy.reset_states()
for step in range(train_steps):
optimizer.lr = common.learning_rate_schedule(
epoch, step, train_steps, flags_obj.batch_size)
time_callback.on_batch_begin(step+epoch*train_steps)
total_loss += train_step(next(train_iter))
time_callback.on_batch_end(step+epoch*train_steps)
train_loss = total_loss / train_steps
logging.info('Training loss: %s, accuracy: %s%% at epoch: %d',
train_loss.numpy(),
steps_in_current_epoch = 0
while steps_in_current_epoch < per_epoch_steps:
time_callback.on_batch_begin(
steps_in_current_epoch+epoch*per_epoch_steps)
steps = _steps_to_run(steps_in_current_epoch, per_epoch_steps,
steps_per_loop)
if steps == 1:
train_single_step(train_iter)
else:
# Converts steps to a Tensor to avoid tf.function retracing.
train_steps(train_iter, tf.convert_to_tensor(steps, dtype=tf.int32))
time_callback.on_batch_end(
steps_in_current_epoch+epoch*per_epoch_steps)
steps_in_current_epoch += steps
logging.info('Training loss: %s, accuracy: %s%% at epoch %d',
train_loss.result().numpy(),
training_accuracy.result().numpy(),
epoch)
epoch + 1)
if (not flags_obj.skip_eval and
(epoch + 1) % flags_obj.epochs_between_evals == 0):
@@ -283,12 +341,12 @@ def run(flags_obj):
test_iter = iter(test_ds)
for _ in range(eval_steps):
test_step(next(test_iter))
test_step(test_iter)
logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
test_loss.result().numpy(),
test_accuracy.result().numpy(),
epoch)
epoch + 1)
time_callback.on_train_end()
@@ -297,7 +355,7 @@ def run(flags_obj):
if not flags_obj.skip_eval:
eval_result = [test_loss.result().numpy(),
test_accuracy.result().numpy()]
train_result = [train_loss.numpy(),
train_result = [train_loss.result().numpy(),
training_accuracy.result().numpy()]
stats = build_stats(train_result, eval_result, time_callback)
@@ -307,7 +365,8 @@ def run(flags_obj):
def main(_):
model_helpers.apply_clean(flags.FLAGS)
with logger.benchmark_context(flags.FLAGS):
return run(flags.FLAGS)
stats = run(flags.FLAGS)
logging.info('Run stats:\n%s', stats)
if __name__ == '__main__':
......
@@ -353,6 +353,13 @@ def define_keras_flags(dynamic_loss_scale=True):
flags.DEFINE_boolean(
name='enable_checkpoint_and_export', default=False,
help='Whether to enable a checkpoint callback and export the savedmodel.')
flags.DEFINE_string(
name='tpu', default='', help='TPU address to connect to.')
flags.DEFINE_integer(
name='steps_per_loop', default=1,
help='Number of steps per graph-mode loop. Only training step happens '
'inside the loop. Callbacks will not be called inside. Will be capped at '
'steps per epoch.')
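# Illustrative usage (hypothetical invocation, not part of this commit; the
# script name and TPU address are assumptions):
#   python resnet_ctl_imagenet_main.py --tpu=grpc://<tpu-host>:8470 \
#     --steps_per_loop=500 --use_tf_function=true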
def get_synth_input_fn(height, width, num_channels, num_classes,
......
@@ -115,9 +115,9 @@ def process_record_dataset(dataset,
if is_training:
# Shuffles records before repeating to respect epoch boundaries.
dataset = dataset.shuffle(buffer_size=shuffle_buffer)
# Repeats the dataset for the number of epochs to train.
dataset = dataset.repeat()
# Repeats the dataset for the number of epochs to train.
dataset = dataset.repeat(num_epochs)
# Parses the raw records into images and labels.
dataset = dataset.map(
@@ -133,10 +133,10 @@ def process_record_dataset(dataset,
# on how many devices are present.
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
if tf_data_experimental_slack:
options = tf.data.Options()
options.experimental_slack = True
dataset = dataset.with_options(options)
options = tf.data.Options()
options.experimental_slack = tf_data_experimental_slack
options.experimental_allow_stateful = True
dataset = dataset.with_options(options)
return dataset
......