提交 e67a2064 编写于 作者: A A. Unique TensorFlower

Change summary directory and model checkpoint directory so that training via...

Change summary directory and model checkpoint directory so that training via Keras Compile/Fit() and custom training loop is consistent.

PiperOrigin-RevId: 274202793
上级 ad1a37c9
......@@ -148,7 +148,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.train_batch_size = 4
FLAGS.eval_batch_size = 4
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_1_gpu_mrpc_xla(self):
......@@ -165,7 +166,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.eval_batch_size = 4
FLAGS.enable_xla = True
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_1_gpu_mrpc_no_dist_strat(self):
......@@ -181,7 +183,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.train_batch_size = 4
FLAGS.eval_batch_size = 4
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path, use_ds=False)
def benchmark_2_gpu_mrpc(self):
......@@ -197,7 +200,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.train_batch_size = 8
FLAGS.eval_batch_size = 8
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_4_gpu_mrpc(self):
......@@ -212,7 +216,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.bert_config_file = self.bert_config_file
FLAGS.train_batch_size = 16
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_8_gpu_mrpc(self):
......@@ -225,7 +230,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.input_meta_data_path = self.input_meta_data_path
FLAGS.bert_config_file = self.bert_config_file
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_1_gpu_amp_mrpc_no_dist_strat(self):
......@@ -243,7 +249,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path, use_ds=False)
def benchmark_8_gpu_amp_mrpc(self):
......@@ -262,7 +269,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path, use_ds=False)
......@@ -320,7 +328,8 @@ class BertClassifyAccuracy(BertClassifyBenchmarkBase):
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_8_gpu_mrpc_xla(self):
......@@ -328,7 +337,8 @@ class BertClassifyAccuracy(BertClassifyBenchmarkBase):
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc_xla')
FLAGS.enable_xla = True
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
......
......@@ -52,7 +52,8 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
def _read_training_summary_from_file(self):
"""Reads the training summary from a file."""
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
with tf.io.gfile.GFile(summary_path, 'rb') as reader:
return json.loads(reader.read().decode('utf-8'))
......
......@@ -122,7 +122,8 @@ class XLNetClassifyAccuracy(XLNetClassifyBenchmarkBase):
# Sets timer_callback to None as we do not use it now.
self.timer_callback = None
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
......
......@@ -72,9 +72,9 @@ def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
return steps_per_loop
def write_txt_summary(training_summary, model_dir):
def write_txt_summary(training_summary, summary_dir):
"""Writes a summary text file to record stats."""
summary_path = os.path.join(model_dir, _SUMMARY_TXT)
summary_path = os.path.join(summary_dir, _SUMMARY_TXT)
with tf.io.gfile.GFile(summary_path, 'wb') as f:
logging.info('Training Summary: \n%s', str(training_summary))
f.write(json.dumps(training_summary, indent=4))
......@@ -221,13 +221,14 @@ def run_customized_training_loop(
]
# Create summary writers
summary_dir = os.path.join(model_dir, 'summaries')
eval_summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, 'summaries/eval'))
os.path.join(summary_dir, 'eval'))
if steps_per_loop >= _MIN_SUMMARY_STEPS:
# Only writes summary when the stats are collected sufficiently over
# enough steps.
train_summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, 'summaries/train'))
os.path.join(summary_dir, 'train'))
else:
train_summary_writer = None
......@@ -415,6 +416,6 @@ def run_customized_training_loop(
train_metrics[0])
training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
write_txt_summary(training_summary, model_dir)
write_txt_summary(training_summary, summary_dir)
return model
......@@ -185,7 +185,8 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
# Two checkpoints should be saved after two epochs.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(model_dir, 'ctl_step_*')))
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'training_summary*')))
tf.io.gfile.glob(
os.path.join(model_dir, 'summaries/training_summary*')))
# Loss and accuracy values should be written into summaries.
self.assertTrue(
......
......@@ -26,10 +26,10 @@ import tensorflow as tf
import typing
def export_bert_model(
model_export_path: typing.Text,
model: tf.keras.Model,
checkpoint_dir: typing.Optional[typing.Text] = None) -> None:
def export_bert_model(model_export_path: typing.Text,
model: tf.keras.Model,
checkpoint_dir: typing.Optional[typing.Text] = None,
restore_model_using_load_weights: bool = False) -> None:
"""Export BERT model for serving which does not include the optimizer.
Arguments:
......@@ -37,6 +37,14 @@ def export_bert_model(
model: Keras model object to export.
checkpoint_dir: Path from which model weights will be loaded, if
specified.
restore_model_using_load_weights: If True, restore the weights with the
model.load_weights() API; if False, restore them with the
checkpoint.restore() API for custom (tf.train.Checkpoint) checkpoints.
There are 2 different ways to save checkpoints. One is using
tf.train.Checkpoint and another is using Keras model.save_weights().
Custom training loop implementation uses tf.train.Checkpoint API
and Keras ModelCheckpoint callback internally uses model.save_weights()
API. Since these two APIs cannot be used together, model loading logic
must take into account how the model checkpoint was saved.
Raises:
ValueError when either model_export_path or model is not specified.
......@@ -47,13 +55,24 @@ def export_bert_model(
raise ValueError('model must be a tf.keras.Model object.')
if checkpoint_dir:
# Restores the model from latest checkpoint.
checkpoint = tf.train.Checkpoint(model=model)
latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
assert latest_checkpoint_file
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file).assert_existing_objects_matched()
# Keras compile/fit() was used to save checkpoint using
# model.save_weights().
if restore_model_using_load_weights:
model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
assert tf.io.gfile.exists(model_weight_path)
model.load_weights(model_weight_path)
# tf.train.Checkpoint API was used via custom training loop logic.
else:
checkpoint = tf.train.Checkpoint(model=model)
# Restores the model from latest checkpoint.
latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
assert latest_checkpoint_file
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(
latest_checkpoint_file).assert_existing_objects_matched()
model.save(model_export_path, include_optimizer=False, save_format='tf')
......
......@@ -92,7 +92,8 @@ def run_bert_classifier(strategy,
initial_lr,
init_checkpoint,
custom_callbacks=None,
run_eagerly=False):
run_eagerly=False,
use_keras_compile_fit=False):
"""Run BERT classifier training using low-level API."""
max_seq_length = input_meta_data['max_seq_length']
num_classes = input_meta_data['num_labels']
......@@ -142,7 +143,7 @@ def run_bert_classifier(strategy,
return tf.keras.metrics.SparseCategoricalAccuracy(
'test_accuracy', dtype=tf.float32)
if FLAGS.use_keras_compile_fit:
if use_keras_compile_fit:
# Start training using Keras compile/fit API.
logging.info('Training using TF 2.0 Keras compile/fit API with '
'distrubuted strategy.')
......@@ -206,9 +207,11 @@ def run_keras_compile_fit(model_dir,
bert_model.compile(optimizer=optimizer, loss=loss_fn, metrics=[metric_fn()])
summary_callback = tf.keras.callbacks.TensorBoard(model_dir)
checkpoint_dir = os.path.join(model_dir, 'model_checkpoint.{epoch:02d}')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir)
summary_dir = os.path.join(model_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
checkpoint_path = os.path.join(model_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True)
if custom_callbacks is not None:
custom_callbacks += [summary_callback, checkpoint_callback]
......@@ -226,12 +229,21 @@ def run_keras_compile_fit(model_dir,
return bert_model
def export_classifier(model_export_path, input_meta_data):
def export_classifier(model_export_path, input_meta_data,
restore_model_using_load_weights):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
restore_model_using_load_weights: If True, restore the weights with the
model.load_weights() API; if False, restore them with the
checkpoint.restore() API for custom (tf.train.Checkpoint) checkpoints.
There are 2 different ways to save checkpoints. One is using
tf.train.Checkpoint and another is using Keras model.save_weights().
Custom training loop implementation uses tf.train.Checkpoint API
and Keras ModelCheckpoint callback internally uses model.save_weights()
API. Since these two APIs cannot be used together, model loading logic
must take into account how the model checkpoint was saved.
Raises:
Export path is not specified, got an empty string or None.
......@@ -243,14 +255,22 @@ def export_classifier(model_export_path, input_meta_data):
classifier_model = bert_models.classifier_model(
bert_config, tf.float32, input_meta_data['num_labels'],
input_meta_data['max_seq_length'])[0]
model_saving_utils.export_bert_model(
model_export_path, model=classifier_model, checkpoint_dir=FLAGS.model_dir)
model_export_path,
model=classifier_model,
checkpoint_dir=FLAGS.model_dir,
restore_model_using_load_weights=restore_model_using_load_weights)
def run_bert(strategy, input_meta_data):
"""Run BERT training."""
if FLAGS.mode == 'export_only':
export_classifier(FLAGS.model_export_path, input_meta_data)
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
# internally uses model.save_weights() to save checkpoints, we must
# use model.load_weights() when Keras compile/fit() is used.
export_classifier(FLAGS.model_export_path, input_meta_data,
FLAGS.use_keras_compile_fit)
return
if FLAGS.mode != 'train_and_eval':
......@@ -281,11 +301,17 @@ def run_bert(strategy, input_meta_data):
warmup_steps,
FLAGS.learning_rate,
FLAGS.init_checkpoint,
run_eagerly=FLAGS.run_eagerly)
run_eagerly=FLAGS.run_eagerly,
use_keras_compile_fit=FLAGS.use_keras_compile_fit)
if FLAGS.model_export_path:
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
# internally uses model.save_weights() to save checkpoints, we must
# use model.load_weights() when Keras compile/fit() is used.
model_saving_utils.export_bert_model(
FLAGS.model_export_path, model=trained_model)
FLAGS.model_export_path,
model=trained_model,
restore_model_using_load_weights=FLAGS.use_keras_compile_fit)
return trained_model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册