Commit 0612e190 authored by David Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 286087529
Parent 7e67dbbc
@@ -18,17 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags
from absl.testing import flagsaver
import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order

from official.utils.flags import core as flags_core
+from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark

FLAGS = flags.FLAGS
@@ -59,34 +58,20 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
    return self.batch_start_times[0] - program_start_time


-class BertBenchmarkBase(tf.test.Benchmark):
+class BertBenchmarkBase(PerfZeroBenchmark):
  """Base class to hold methods common to test classes."""
  local_flags = None

  def __init__(self, output_dir=None):
+    super(BertBenchmarkBase, self).__init__(output_dir=output_dir)
    self.num_gpus = 8
-    if not output_dir:
-      output_dir = '/tmp'
-    self.output_dir = output_dir
    self.timer_callback = None

-  def _get_model_dir(self, folder_name):
-    """Returns directory to store info, e.g. saved model and event log."""
-    return os.path.join(self.output_dir, folder_name)

  def _setup(self):
    """Sets up and resets flags before each test."""
+    super(BertBenchmarkBase, self)._setup()
    self.timer_callback = BenchmarkTimerCallback()
-    if BertBenchmarkBase.local_flags is None:
-      # Loads flags to get defaults to then override. List cannot be empty.
-      flags.FLAGS(['foo'])
-      saved_flag_values = flagsaver.save_flag_values()
-      BertBenchmarkBase.local_flags = saved_flag_values
-    else:
-      flagsaver.restore_flag_values(BertBenchmarkBase.local_flags)

  def _report_benchmark(self, stats, wall_time_sec, min_accuracy, max_accuracy):
    """Report benchmark results by writing to local protobuf file.
......
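The flag snapshot/restore logic deleted from _setup() above is exactly what the shared PerfZeroBenchmark base class now centralizes. Below is a minimal sketch of that pattern, using only the absl APIs visible in the removed lines; the class name is a hypothetical stand-in, not the real PerfZeroBenchmark source:

from absl import flags
from absl.testing import flagsaver


class FlagSnapshotBenchmark(object):
  """Hypothetical stand-in showing the save/restore-flags pattern."""
  local_flags = None

  def _setup(self):
    if FlagSnapshotBenchmark.local_flags is None:
      # Parse a dummy argv once so every flag is set to its default value.
      flags.FLAGS(['foo'])
      # Snapshot the defaults; later tests restore this instead of re-parsing.
      FlagSnapshotBenchmark.local_flags = flagsaver.save_flag_values()
    else:
      flagsaver.restore_flag_values(FlagSnapshotBenchmark.local_flags)

Because the snapshot is stored on the class, every benchmark method starts from identical flag defaults regardless of what the previous test overrode.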
@@ -52,6 +52,10 @@ FLAGS = flags.FLAGS


class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
  """Base class to hold methods common to test classes in the module."""

+  def __init__(self, output_dir=None, tpu=None):
+    super(BertSquadBenchmarkBase, self).__init__(output_dir=output_dir)
+    self.tpu = tpu

  def _read_training_summary_from_file(self):
    """Reads the training summary from a file."""
    summary_path = os.path.join(FLAGS.model_dir,
@@ -78,9 +82,13 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):

  def _get_distribution_strategy(self, use_ds=True):
    """Gets the distribution strategy."""
-    return distribution_utils.get_distribution_strategy(
-        distribution_strategy='mirrored' if use_ds else 'off',
-        num_gpus=self.num_gpus)
+    if self.tpu:
+      return distribution_utils.get_distribution_strategy(
+          distribution_strategy='tpu', tpu_address=self.tpu)
+    else:
+      return distribution_utils.get_distribution_strategy(
+          distribution_strategy='mirrored' if use_ds else 'off',
+          num_gpus=self.num_gpus)
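The distribution_utils helper keeps the benchmark classes free of tf.distribute boilerplate. For readers unfamiliar with it, the sketch below shows roughly what the two branches resolve to with raw TensorFlow 2 APIs of this era; the function name is hypothetical, and the real helper's exact behavior (for example, what 'off' returns) may differ:

import tensorflow.compat.v2 as tf


def get_distribution_strategy_sketch(tpu=None, use_ds=True, num_gpus=8):
  """Roughly what the branches above resolve to (hypothetical sketch)."""
  if tpu:
    # tpu is the TPU address; the resolver connects to and initializes it.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.experimental.TPUStrategy(resolver)
  if use_ds:
    # Synchronous data parallelism across the GPUs on one host.
    return tf.distribute.MirroredStrategy(
        devices=['device:GPU:%d' % i for i in range(num_gpus)])
  return None  # 'off': run without a distribution strategy.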
  @flagsaver.flagsaver
  def _train_squad(self, use_ds=True, run_eagerly=False):
@@ -117,11 +125,12 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):

  Tests BERT SQuAD performance in different GPU configurations.

  The naming convention of below test cases follow
-  `benchmark_(number of gpus)_gpu` format.
+  `benchmark_(number of gpus)_gpu` format for GPUs and
+  `benchmark_(topology)_tpu` format for TPUs.
  """

-  def __init__(self, output_dir=TMP_DIR, **kwargs):
-    super(BertSquadBenchmarkReal, self).__init__(output_dir=output_dir)
+  def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
+    super(BertSquadBenchmarkReal, self).__init__(output_dir=output_dir, tpu=tpu)

  def _setup(self):
    """Sets up the benchmark and SQuAD flags."""
@@ -322,16 +331,26 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
    self._run_and_report_benchmark()

+  def benchmark_2x2_tpu(self):
+    """Tests BERT SQuAD model performance with 2x2 TPU."""
+    self._setup()
+    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
+    FLAGS.train_batch_size = 48
+    self._run_and_report_benchmark()
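Each benchmark method follows the same recipe: call _setup(), override a few FLAGS, then call _run_and_report_benchmark(). That wrapper is collapsed out of this diff; the sketch below is a plausible shape inferred from the _report_benchmark signature in the first file, with illustrative accuracy bounds rather than the repo's actual values:

import time


class _RunAndReportSketch(object):
  """Hypothetical mixin showing the likely shape of _run_and_report_benchmark."""

  def _run_and_report_benchmark(self, use_ds=True, run_eagerly=False):
    start_time_sec = time.time()
    self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
    wall_time_sec = time.time() - start_time_sec

    summary = self._read_training_summary_from_file()
    # _report_benchmark (bert_benchmark_utils above) checks the accuracy
    # bounds and writes the results to a local protobuf file.
    self._report_benchmark(
        stats=summary,
        wall_time_sec=wall_time_sec,
        min_accuracy=0.0,
        max_accuracy=1.0)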
class BertSquadAccuracy(BertSquadBenchmarkBase):
  """Short accuracy test for BERT SQuAD model.

  Tests BERT SQuAD accuracy. The naming convention of below test cases follow
-  `benchmark_(number of gpus)_gpu` format.
+  `benchmark_(number of gpus)_gpu` format for GPUs and
+  `benchmark_(topology)_tpu` format for TPUs.
  """

-  def __init__(self, output_dir=None, **kwargs):
-    super(BertSquadAccuracy, self).__init__(output_dir=output_dir)
+  def __init__(self, output_dir=None, tpu=None, **kwargs):
+    super(BertSquadAccuracy, self).__init__(output_dir=output_dir, tpu=tpu)

  def _setup(self):
    """Sets up the benchmark and SQuAD flags."""
@@ -407,6 +426,15 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
    self._run_and_report_benchmark()

+  def benchmark_2x2_tpu(self):
+    """Tests BERT SQuAD model accuracy with 2x2 TPU."""
+    self._setup()
+    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
+    FLAGS.train_batch_size = 48
+    self._run_and_report_benchmark()
if __name__ == '__main__':
  tf.test.main()
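With tf.test.main() as the entry point, these classes are normally driven by the PerfZero harness, which instantiates a benchmark class and invokes one benchmark method per run. A hypothetical direct invocation of the new TPU case (module path, output directory, and TPU address are all placeholders):

from official.benchmark import bert_squad_benchmark  # assumed module path

benchmark = bert_squad_benchmark.BertSquadBenchmarkReal(
    output_dir='/tmp/bert_squad',            # placeholder output directory
    tpu='grpc://10.0.0.2:8470')              # placeholder TPU address
benchmark.benchmark_2x2_tpu()

Passing tpu=None instead leaves self.tpu unset, so _get_distribution_strategy falls back to the mirrored/off GPU path shown earlier.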