提交 a617b671 编写于 作者: D David Chen 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 285322154
上级 0f5bdd0e
......@@ -51,10 +51,11 @@ FLAGS = flags.FLAGS
class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module."""
def __init__(self, output_dir=None):
def __init__(self, output_dir=None, tpu=None):
super(BertClassifyBenchmarkBase, self).__init__(output_dir)
self.num_epochs = None
self.num_steps_per_epoch = None
self.tpu = tpu
@flagsaver.flagsaver
def _run_bert_classifier(self, callbacks=None, use_ds=True):
......@@ -72,9 +73,13 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
warmup_steps = int(epochs * steps_per_epoch * 0.1)
eval_steps = int(
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy='mirrored' if use_ds else 'off',
num_gpus=self.num_gpus)
if self.tpu:
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.tpu)
else:
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy='mirrored' if use_ds else 'off',
num_gpus=self.num_gpus)
steps_per_loop = 1
......@@ -109,13 +114,15 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
"""Short benchmark performance tests for BERT model.
Tests BERT classification performance in different GPU configurations.
Tests BERT classification performance in different GPU, TPU configurations.
The naming convention of below test cases follow
`benchmark_(number of gpus)_gpu_(dataset type)` format.
`benchmark_(number of gpus)_gpu_(dataset type)` for GPUs and
`benchmark_(topology)_tpu_(dataset type)` for TPUs.
"""
def __init__(self, output_dir=TMP_DIR, **kwargs):
super(BertClassifyBenchmarkReal, self).__init__(output_dir=output_dir)
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
super(BertClassifyBenchmarkReal, self).__init__(
output_dir=output_dir, tpu=tpu)
self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
......@@ -289,6 +296,22 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path, use_ds=False)
def benchmark_2x2_tpu_mrpc(self):
"""Test BERT model performance with 2x2 TPU."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mrpc')
FLAGS.train_data_path = self.train_data_path
FLAGS.eval_data_path = self.eval_data_path
FLAGS.input_meta_data_path = self.input_meta_data_path
FLAGS.bert_config_file = self.bert_config_file
FLAGS.train_batch_size = 32
FLAGS.eval_batch_size = 32
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path, use_ds=False)
class BertClassifyAccuracy(BertClassifyBenchmarkBase):
"""Short accuracy test for BERT model.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册