Unverified commit aee49bbd authored by Hongkun Yu, committed by GitHub

Merged commit includes the following changes: (#7357)

261202754  by hongkuny<hongkuny@google.com>:

    Use enable_xla flag for classifier and squad, so xla option is exposed to users.

--

PiperOrigin-RevId: 261202754
Parent 87542800
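
For context, the keras_utils.set_config_v2 call that this change moves out of the benchmark classes and into run_classifier/run_squad is expected to toggle XLA JIT compilation through the TF 2.x config API. A minimal sketch of the assumed behavior (not a verbatim copy of official.utils.misc.keras_utils):

    import tensorflow as tf

    def set_config_v2(enable_xla=False):
      """Assumed helper: globally enable XLA auto-clustering for tf.function graphs."""
      if enable_xla:
        # Per the diff comment below, this should not be set for TPU runs,
        # which compile through a separate XLA path.
        tf.config.optimizer.set_jit(True)

With this wiring, a user requests XLA by passing the exposed --enable_xla flag to run_classifier or run_squad, and the benchmarks set FLAGS.enable_xla = True instead of threading an enable_xla argument through their helpers.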
@@ -33,7 +33,6 @@ from official.bert import modeling
 from official.bert import run_classifier
 from official.bert.benchmark import benchmark_utils
 from official.utils.misc import distribution_utils
-from official.utils.misc import keras_utils
 
 # pylint: disable=line-too-long
 PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_model.ckpt'
@@ -55,7 +54,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
     self.num_steps_per_epoch = None
 
   @flagsaver.flagsaver
-  def _run_bert_classifier(self, callbacks=None, use_ds=True, enable_xla=False):
+  def _run_bert_classifier(self, callbacks=None, use_ds=True):
     """Starts BERT classification task."""
     with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
       input_meta_data = json.loads(reader.read().decode('utf-8'))
@@ -73,8 +72,6 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
     strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy='mirrored' if use_ds else 'off',
         num_gpus=self.num_gpus)
-    # TODO(hongkuny): Enable XLA once we are confident with its performance.
-    keras_utils.set_config_v2(enable_xla)
 
     steps_per_loop = 1
@@ -119,13 +116,10 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
   def _run_and_report_benchmark(self,
                                 training_summary_path,
                                 min_accuracy=0,
                                 max_accuracy=1,
-                                use_ds=True,
-                                enable_xla=False):
+                                use_ds=True):
     """Starts BERT performance benchmark test."""
     start_time_sec = time.time()
-    self._run_bert_classifier(
-        callbacks=[self.timer_callback], use_ds=use_ds, enable_xla=enable_xla)
+    self._run_bert_classifier(callbacks=[self.timer_callback], use_ds=use_ds)
     wall_time_sec = time.time() - start_time_sec
 
     with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
@@ -168,9 +162,10 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
     FLAGS.bert_config_file = self.bert_config_file
     FLAGS.train_batch_size = 4
     FLAGS.eval_batch_size = 4
+    FLAGS.enable_xla = True
     summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
-    self._run_and_report_benchmark(summary_path, enable_xla=True)
+    self._run_and_report_benchmark(summary_path)
 
   def benchmark_1_gpu_mrpc_no_dist_strat(self):
     """Test BERT model performance with 1 GPU, no distribution strategy."""
@@ -253,13 +248,11 @@ class BertClassifyAccuracy(BertClassifyBenchmarkBase):
   def _run_and_report_benchmark(self,
                                 training_summary_path,
                                 min_accuracy=0.84,
-                                max_accuracy=0.88,
-                                enable_xla=False):
+                                max_accuracy=0.88):
     """Starts BERT accuracy benchmark test."""
     start_time_sec = time.time()
-    self._run_bert_classifier(
-        callbacks=[self.timer_callback], enable_xla=enable_xla)
+    self._run_bert_classifier(callbacks=[self.timer_callback])
     wall_time_sec = time.time() - start_time_sec
 
     with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
@@ -296,9 +289,9 @@ class BertClassifyAccuracy(BertClassifyBenchmarkBase):
     """Run BERT model accuracy test with 8 GPUs with XLA."""
     self._setup()
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc_xla')
+    FLAGS.enable_xla = True
     summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
-    self._run_and_report_benchmark(summary_path, enable_xla=True)
+    self._run_and_report_benchmark(summary_path)
 
 if __name__ == '__main__':
......
@@ -32,7 +32,6 @@ from official.bert import run_squad
 from official.bert.benchmark import benchmark_utils
 from official.bert.benchmark import squad_evaluate_v1_1
 from official.utils.misc import distribution_utils
-from official.utils.misc import keras_utils
 
 # pylint: disable=line-too-long
 PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_model.ckpt'
@@ -131,10 +130,8 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
   def _run_and_report_benchmark(self,
                                 use_ds=True,
-                                enable_xla=False,
                                 run_eagerly=False):
     """Runs the benchmark and reports various metrics."""
-    keras_utils.set_config_v2(enable_xla)
     start_time_sec = time.time()
     self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
     wall_time_sec = time.time() - start_time_sec
@@ -164,8 +161,9 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
     self.num_gpus = 1
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla_squad')
     FLAGS.train_batch_size = 4
+    FLAGS.enable_xla = True
-    self._run_and_report_benchmark(enable_xla=True)
+    self._run_and_report_benchmark()
 
   def benchmark_1_gpu_no_dist_strat(self):
     """Tests BERT SQuAD model performance with 1 GPU without DS."""
@@ -291,10 +289,8 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
   def _run_and_report_benchmark(self,
                                 use_ds=True,
-                                enable_xla=False,
                                 run_eagerly=False):
     """Runs the benchmark and reports various metrics."""
-    keras_utils.set_config_v2(enable_xla)
     start_time_sec = time.time()
     self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
     self._evaluate_squad()
@@ -348,8 +344,9 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
     self.num_gpus = 8
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_xla')
     FLAGS.train_batch_size = 32
+    FLAGS.enable_xla = True
-    self._run_and_report_benchmark(enable_xla=True)
+    self._run_and_report_benchmark()
 
 if __name__ == '__main__':
......
@@ -58,7 +58,7 @@ def define_common_bert_flags():
       loss_scale=True,
       all_reduce_alg=False,
       num_packs=False,
-      enable_xla=False
+      enable_xla=True
   )
......
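
The hunk above flips enable_xla from False to True inside define_common_bert_flags, in what looks like a call to a flags_core.define_performance-style helper. Assuming (not verified here) that each keyword argument controls whether the corresponding performance flag gets registered, this is the change that actually surfaces a user-visible --enable_xla flag. A hedged sketch of that registration pattern, with the helper name and signature as assumptions only:

    from absl import flags

    def define_performance(enable_xla=False, **other_flag_switches):
      """Hypothetical helper: each boolean decides whether a flag is defined at all."""
      if enable_xla:
        flags.DEFINE_bool(
            'enable_xla', default=False,
            help='Whether to enable XLA auto-jit compilation.')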
@@ -36,6 +36,7 @@ from official.bert import model_training_utils
 from official.bert import modeling
 from official.bert import optimization
 from official.bert import tpu_lib
+from official.utils.misc import keras_utils
 
 flags.DEFINE_enum(
     'mode', 'train_and_eval', ['train_and_eval', 'export_only'],
@@ -174,6 +175,8 @@ def run_bert(strategy, input_meta_data):
   if FLAGS.mode != 'train_and_eval':
     raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
 
+  # Enables XLA in Session Config. Should not be set for TPU.
+  keras_utils.set_config_v2(FLAGS.enable_xla)
   bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
   epochs = FLAGS.num_train_epochs
......
@@ -37,6 +37,7 @@ from official.bert import optimization
 from official.bert import squad_lib
 from official.bert import tokenization
 from official.bert import tpu_lib
+from official.utils.misc import keras_utils
 
 flags.DEFINE_bool('do_train', False, 'Whether to run training.')
 flags.DEFINE_bool('do_predict', False, 'Whether to run eval on the dev set.')
@@ -181,6 +182,8 @@ def train_squad(strategy,
   if strategy:
     logging.info('Training using customized training loop with distribution'
                  ' strategy.')
+  # Enables XLA in Session Config. Should not be set for TPU.
+  keras_utils.set_config_v2(FLAGS.enable_xla)
   use_float16 = common_flags.use_float16()
   if use_float16:
......