Commit a629af4c authored by A. Unique TensorFlower

Merge pull request #7531 from vinhngx:amp_bert

PiperOrigin-RevId: 267629422
......@@ -227,6 +227,43 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path)

  def benchmark_1_gpu_amp_mrpc_no_dist_strat(self):
    """Performance for 1 GPU no DS with automatic mixed precision."""
    self._setup()
    self.num_gpus = 1
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_1_gpu_amp_mrpc_no_dist_strat')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.train_batch_size = 4
    FLAGS.eval_batch_size = 4
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path, use_ds=False)

  def benchmark_8_gpu_amp_mrpc(self):
    """Test BERT model performance with 8 GPUs with automatic mixed precision."""
    self._setup()
    self.num_gpus = 8
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_mrpc')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.train_batch_size = 32
    FLAGS.eval_batch_size = 32
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path, use_ds=False)


class BertClassifyAccuracy(BertClassifyBenchmarkBase):
  """Short accuracy test for BERT model.
......
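These classifier benchmarks configure AMP purely through absl flags: each method sets `dtype` to `fp16` and `fp16_implementation` to `graph_rewrite` before calling the runner. Below is a minimal, self-contained sketch of that flag-driven pattern; the flag definitions, defaults, and helper name here are illustrative stand-ins, not values taken from this repository.

```python
# A minimal sketch of the FLAGS-driven setup used above, assuming absl flags.
# Defaults and help strings are placeholders.
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('model_dir', None, 'Checkpoint/summary directory.')
flags.DEFINE_integer('train_batch_size', 32, 'Global train batch size.')
flags.DEFINE_string('dtype', 'fp32', 'Compute dtype: fp16 or fp32.')
flags.DEFINE_string('fp16_implementation', 'keras',
                    'How fp16 is applied: keras or graph_rewrite.')


def configure_amp_run(model_dir, batch_size):
  """Sets the same flag combination the AMP benchmarks above set."""
  FLAGS.model_dir = model_dir
  FLAGS.train_batch_size = batch_size
  FLAGS.dtype = 'fp16'
  FLAGS.fp16_implementation = 'graph_rewrite'


if __name__ == '__main__':
  FLAGS(['demo'])  # Parse with defaults so the flags can be assigned.
  configure_amp_run('/tmp/amp_demo', batch_size=32)
  print(FLAGS.dtype, FLAGS.fp16_implementation)
```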
......@@ -281,6 +281,42 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
    self._run_and_report_benchmark()

  def benchmark_1_gpu_amp(self):
    """Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
    self._setup()
    self.num_gpus = 1
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp_squad')
    FLAGS.train_batch_size = 4
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    self._run_and_report_benchmark()

  def benchmark_4_gpu_amp(self):
    """Tests BERT SQuAD model performance with 4 GPUs with automatic mixed precision."""
    self._setup()
    self.num_gpus = 4
    FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_amp_squad')
    FLAGS.train_batch_size = 16
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    self._run_and_report_benchmark()

  def benchmark_8_gpu_amp(self):
    """Tests BERT SQuAD model performance with 8 GPUs with automatic mixed precision."""
    self._setup()
    self.num_gpus = 8
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_squad')
    FLAGS.train_batch_size = 32
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
    self._run_and_report_benchmark()


class BertSquadAccuracy(BertSquadBenchmarkBase):
  """Short accuracy test for BERT SQuAD model.
......
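Across the three SQuAD AMP benchmarks above, the global batch grows linearly with GPU count at four examples per GPU (4, 16, 32). A tiny sketch of that relationship, derived only from the values in this diff:

```python
# Derived from the methods above: 1 GPU -> 4, 4 GPUs -> 16, 8 GPUs -> 32.
PER_GPU_BATCH = 4


def global_batch_size(num_gpus):
  """Global train batch size used by the AMP SQuAD benchmarks above."""
  return PER_GPU_BATCH * num_gpus


assert global_batch_size(1) == 4
assert global_batch_size(4) == 16
assert global_batch_size(8) == 32
```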
......@@ -69,7 +69,8 @@ def define_common_bert_flags():
      loss_scale=True,
      all_reduce_alg=False,
      num_packs=False,
      enable_xla=True,
      fp16_implementation=True,
  )
......
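The flags hunk above turns on one more switch in the shared performance-flag helper. A hedged sketch of what `fp16_implementation=True` plausibly causes `flags_core.define_performance` to define; the enum values match how the flag is used elsewhere in this diff, but the default and help text are assumptions:

```python
# Hedged sketch; not copied from the repository's flags helper.
from absl import flags

flags.DEFINE_enum(
    name='fp16_implementation', default='keras',
    enum_values=('keras', 'graph_rewrite'),
    help='When --dtype=fp16, how float16 should be implemented.')
```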
......@@ -111,11 +111,20 @@ def run_customized_training(strategy,
      drop_remainder=False)

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(bert_config, tf.float32, num_classes,
                                     max_seq_length))
    classifier_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", the dtype
      # determined by flags_core.get_tf_dtype(flags_obj) is 'float32', which
      # ensures tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      classifier_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          classifier_model.optimizer)
    return classifier_model, core_model

  loss_fn = get_loss_fn(
......
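For context, here is a minimal, self-contained sketch of the graph-rewrite AMP call used in `_get_classifier_model` above, applied to a toy Keras model rather than the BERT classifier (TF 2.0-era experimental API):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5)
# The rewrite casts eligible ops to float16 inside tf.function graphs and
# returns the optimizer wrapped for dynamic loss scaling, so float16
# gradients do not underflow.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    optimizer)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy')
```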
......@@ -123,10 +123,19 @@ def run_customized_training(strategy,
      train_batch_size, strategy)

  def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)
    pretrain_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", the dtype
      # determined by flags_core.get_tf_dtype(flags_obj) is 'float32', which
      # ensures tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      pretrain_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          pretrain_model.optimizer)
    return pretrain_model, core_model

  trained_model = model_training_utils.run_customized_training_loop(
......
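The repeated "do not double up" comment hinges on the dtype helper reporting float32 whenever the graph rewrite is selected. A hedged reconstruction of that guard (the function name mirrors `flags_core.get_tf_dtype`, but the body and dtype map below are illustrative, not copied from the repository):

```python
import tensorflow as tf

DTYPE_MAP = {'fp16': tf.float16, 'fp32': tf.float32}  # illustrative


def get_tf_dtype(flags_obj):
  if getattr(flags_obj, 'fp16_implementation', None) == 'graph_rewrite':
    # The graph rewrite inserts float16 casts itself, so the model is built
    # in float32 and the Keras mixed-precision path stays off.
    return tf.float32
  return DTYPE_MAP[flags_obj.dtype]
```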
......@@ -226,6 +226,14 @@ def train_squad(strategy,
      squad_model.optimizer = (
          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
              squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", the dtype
      # determined by flags_core.get_tf_dtype(flags_obj) is 'float32', which
      # ensures tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          squad_model.optimizer)
    return squad_model, core_model

  # The original BERT model does not scale the loss by
......
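The `train_squad` hunk above shows both fp16 paths side by side: the Keras path wraps the optimizer in a LossScaleOptimizer explicitly, while the graph-rewrite path returns an already loss-scale-wrapped optimizer. A standalone, hedged contrast of the two (TF 2.0-era experimental APIs):

```python
import tensorflow as tf

# Keras path: explicit loss-scale wrapping; float16 compute comes from a
# Keras mixed-precision policy configured elsewhere.
keras_opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    tf.keras.optimizers.Adam(), loss_scale='dynamic')

# Graph-rewrite path: TensorFlow inserts float16 casts into the graph and
# hands back an optimizer already wrapped for dynamic loss scaling.
rewrite_opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    tf.keras.optimizers.Adam())
```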