From 68146271959c20dd70c37f5cec7315b3c8521fbd Mon Sep 17 00:00:00 2001
From: David Chen
Date: Tue, 17 Dec 2019 09:04:41 -0800
Subject: [PATCH] Internal change

PiperOrigin-RevId: 285991432
---
 official/benchmark/bert_benchmark.py       |  2 ++
 official/benchmark/bert_benchmark_utils.py | 30 +++++++++++++---------
 official/benchmark/bert_squad_benchmark.py |  1 +
 3 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/official/benchmark/bert_benchmark.py b/official/benchmark/bert_benchmark.py
index cc495504b..89a4e4367 100644
--- a/official/benchmark/bert_benchmark.py
+++ b/official/benchmark/bert_benchmark.py
@@ -152,6 +152,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
     # Since we do not load from any pretrained checkpoints, we ignore all
     # accuracy metrics.
     summary.pop('eval_metrics', None)
+    summary['start_time_sec'] = start_time_sec
+
     super(BertClassifyBenchmarkReal, self)._report_benchmark(
         stats=summary,
         wall_time_sec=wall_time_sec,
diff --git a/official/benchmark/bert_benchmark_utils.py b/official/benchmark/bert_benchmark_utils.py
index ab606e2c3..d281d9290 100644
--- a/official/benchmark/bert_benchmark_utils.py
+++ b/official/benchmark/bert_benchmark_utils.py
@@ -38,24 +38,25 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
 
   def __init__(self, num_batches_to_skip=10):
     super(BenchmarkTimerCallback, self).__init__()
-    self.num_batches_to_skip = num_batches_to_skip
-    self.timer_records = []
-    self.start_time = None
+    self.batch_start_times = {}
+    self.batch_stop_times = {}
 
   def on_batch_begin(self, batch, logs=None):
-    if batch < self.num_batches_to_skip:
-      return
-    self.start_time = time.time()
+    self.batch_start_times[batch] = time.time()
 
   def on_batch_end(self, batch, logs=None):
-    if batch < self.num_batches_to_skip:
-      return
+    self.batch_stop_times[batch] = time.time()
 
-    assert self.start_time
-    self.timer_records.append(time.time() - self.start_time)
+  def get_examples_per_sec(self, batch_size, num_batches_to_skip=10):
+    batch_durations = []
+    for batch in self.batch_start_times:
+      if batch in self.batch_stop_times and batch >= num_batches_to_skip:
+        batch_durations.append(self.batch_stop_times[batch] -
+                               self.batch_start_times[batch])
+    return batch_size / np.mean(batch_durations)
 
-  def get_examples_per_sec(self, batch_size):
-    return batch_size / np.mean(self.timer_records)
+  def get_startup_time(self, program_start_time):
+    return self.batch_start_times[0] - program_start_time
 
 
 class BertBenchmarkBase(tf.test.Benchmark):
@@ -113,6 +114,11 @@ class BertBenchmarkBase(tf.test.Benchmark):
         'name': 'exp_per_second',
         'value': 0.0,
     })
+    if self.timer_callback and 'start_time_sec' in stats:
+      metrics.append({
+          'name': 'startup_time',
+          'value': self.timer_callback.get_startup_time(stats['start_time_sec'])
+      })
 
     if 'eval_metrics' in stats:
       metrics.append({
diff --git a/official/benchmark/bert_squad_benchmark.py b/official/benchmark/bert_squad_benchmark.py
index 7e90ddbb0..e9b0da999 100644
--- a/official/benchmark/bert_squad_benchmark.py
+++ b/official/benchmark/bert_squad_benchmark.py
@@ -144,6 +144,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
     wall_time_sec = time.time() - start_time_sec
 
     summary = self._read_training_summary_from_file()
+    summary['start_time_sec'] = start_time_sec
 
     super(BertSquadBenchmarkReal, self)._report_benchmark(
         stats=summary,
--
GitLab
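
Note on what this patch changes: the callback now timestamps every batch on
both begin and end, so skipping warm-up batches becomes a parameter of the
examples/sec query rather than a constructor-time decision, and the start time
of batch 0 can be compared against a program-level start time to report
startup latency. The sketch below is a minimal, self-contained way to exercise
the reworked BenchmarkTimerCallback as defined in the patch; the toy Dense
model, the synthetic data, and the program_start variable are illustrative
stand-ins only, since the real benchmarks drive this callback through the BERT
training loops.

import time

import numpy as np
import tensorflow as tf


class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
  """Records per-batch start/stop times, as in the patched version."""

  def __init__(self, num_batches_to_skip=10):
    super(BenchmarkTimerCallback, self).__init__()
    # The constructor argument is retained for API compatibility but is no
    # longer used; warm-up skipping now happens in get_examples_per_sec.
    self.batch_start_times = {}
    self.batch_stop_times = {}

  def on_batch_begin(self, batch, logs=None):
    self.batch_start_times[batch] = time.time()

  def on_batch_end(self, batch, logs=None):
    self.batch_stop_times[batch] = time.time()

  def get_examples_per_sec(self, batch_size, num_batches_to_skip=10):
    # Average only over batches that both started and finished, excluding the
    # first num_batches_to_skip warm-up batches.
    batch_durations = []
    for batch in self.batch_start_times:
      if batch in self.batch_stop_times and batch >= num_batches_to_skip:
        batch_durations.append(self.batch_stop_times[batch] -
                               self.batch_start_times[batch])
    return batch_size / np.mean(batch_durations)

  def get_startup_time(self, program_start_time):
    # Time from program start until the first training batch begins.
    return self.batch_start_times[0] - program_start_time


program_start = time.time()  # analogous to start_time_sec in the benchmarks

# Illustrative toy model and data, not part of the actual benchmarks.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

timer = BenchmarkTimerCallback()
x = np.random.rand(512, 4).astype('float32')
y = np.random.rand(512, 1).astype('float32')
model.fit(x, y, batch_size=16, epochs=1, verbose=0, callbacks=[timer])

print('examples/sec:', timer.get_examples_per_sec(batch_size=16))
print('startup_time:', timer.get_startup_time(program_start))

One design point worth noting: the reporting side only emits the startup_time
metric when a timer callback exists and the stats dict carries start_time_sec,
which is why both BertClassifyBenchmarkReal and BertSquadBenchmarkReal add
summary['start_time_sec'] before calling _report_benchmark.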