提交 68146271 编写于 作者: D David Chen 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 285991432
上级 7a69f962
......@@ -152,6 +152,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
# Since we do not load from any pretrained checkpoints, we ignore all
# accuracy metrics.
summary.pop('eval_metrics', None)
summary['start_time_sec'] = start_time_sec
super(BertClassifyBenchmarkReal, self)._report_benchmark(
stats=summary,
wall_time_sec=wall_time_sec,
......
......@@ -38,24 +38,25 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
def __init__(self, num_batches_to_skip=10):
super(BenchmarkTimerCallback, self).__init__()
self.num_batches_to_skip = num_batches_to_skip
self.timer_records = []
self.start_time = None
self.batch_start_times = {}
self.batch_stop_times = {}
def on_batch_begin(self, batch, logs=None):
if batch < self.num_batches_to_skip:
return
self.start_time = time.time()
self.batch_start_times[batch] = time.time()
def on_batch_end(self, batch, logs=None):
if batch < self.num_batches_to_skip:
return
self.batch_stop_times[batch] = time.time()
assert self.start_time
self.timer_records.append(time.time() - self.start_time)
def get_examples_per_sec(self, batch_size, num_batches_to_skip=10):
batch_durations = []
for batch in self.batch_start_times:
if batch in self.batch_stop_times and batch >= num_batches_to_skip:
batch_durations.append(self.batch_stop_times[batch] -
self.batch_start_times[batch])
return batch_size / np.mean(batch_durations)
def get_examples_per_sec(self, batch_size):
return batch_size / np.mean(self.timer_records)
def get_startup_time(self, program_start_time):
return self.batch_start_times[0] - program_start_time
class BertBenchmarkBase(tf.test.Benchmark):
......@@ -113,6 +114,11 @@ class BertBenchmarkBase(tf.test.Benchmark):
'name': 'exp_per_second',
'value': 0.0,
})
if self.timer_callback and 'start_time_sec' in stats:
metrics.append({
'name': 'startup_time',
'value': self.timer_callback.get_startup_time(stats['start_time_sec'])
})
if 'eval_metrics' in stats:
metrics.append({
......
......@@ -144,6 +144,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
wall_time_sec = time.time() - start_time_sec
summary = self._read_training_summary_from_file()
summary['start_time_sec'] = start_time_sec
super(BertSquadBenchmarkReal, self)._report_benchmark(
stats=summary,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册