提交 26bbda73 编写于 作者: D David Chen 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 284874717
上级 38e48f91
......@@ -35,19 +35,28 @@ class PerfZeroBenchmark(tf.test.Benchmark):
"""
local_flags = None
def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
def __init__(self,
output_dir=None,
default_flags=None,
flag_methods=None,
tpu=None):
"""Initialize class.
Args:
output_dir: Base directory to store all output for the test.
default_flags:
flag_methods:
default_flags: Set of flags to pass to model.
flag_methods: Set of flag methods to run during setup.
tpu: (optional) TPU name to use in a TPU benchmark.
"""
if not output_dir:
output_dir = '/tmp'
self.output_dir = output_dir
self.default_flags = default_flags or {}
self.flag_methods = flag_methods or {}
if tpu:
# TPU models are expected to accept a --tpu=name flag. PerfZero creates
# the TPU at runtime and passes the TPU's name to this flag.
self.default_flags['tpu'] = tpu
def _get_model_dir(self, folder_name):
"""Returns directory to store info, e.g. saved model and event log."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册