提交 a128dee2 编写于 作者: E Eugene Brevdo 提交者: TensorFlower Gardener

Clean up benchmark.py after previous modifications and add unit test.

Change: 119549296
上级 45b208d1
......@@ -103,6 +103,19 @@ class BenchmarkTest(tf.test.TestCase):
self.assertTrue(_ran_somebenchmark_2[0])
self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
_ran_somebenchmark_1[0] = False
_ran_somebenchmark_2[0] = False
_ran_somebenchmark_but_shouldnt[0] = False
# Test running a specific method of SomeRandomBenchmark
if benchmark.TEST_REPORTER_TEST_ENV in os.environ:
del os.environ[benchmark.TEST_REPORTER_TEST_ENV]
benchmark._run_benchmarks("SomeRandom.*1$")
self.assertTrue(_ran_somebenchmark_1[0])
self.assertFalse(_ran_somebenchmark_2[0])
self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
def testReportingBenchmark(self):
tempdir = tf.test.get_temp_dir()
try:
......
......@@ -245,27 +245,13 @@ class TensorFlowBenchmark(Benchmark):
name=name)
def _run_specific_benchmark(benchmark_class):
benchmark = benchmark_class()
attrs = dir(benchmark)
# Only run methods of this class whose names start with "benchmark"
for attr in attrs:
if not attr.startswith("benchmark"):
continue
benchmark_fn = getattr(benchmark, attr)
if not callable(benchmark_fn):
continue
# Call this benchmark method
benchmark_fn()
def _run_benchmarks(regex):
"""Run benchmarks that match regex `regex`.
This function goes through the global benchmark registry, and matches
benchmark **classe names** of the form "module.name.BenchmarkClass" to
the given regex. If a class matches, all of its benchmark methods
are run.
benchmark class and method names of the form
`module.name.BenchmarkClass.benchmarkMethod` to the given regex.
If a method matches, it is run.
Args:
regex: The string regular expression to match Benchmark classes against.
......@@ -275,18 +261,24 @@ def _run_benchmarks(regex):
# Match benchmarks in registry against regex
for benchmark in registry:
benchmark_name = "%s.%s" % (benchmark.__module__, benchmark.__name__)
attrs = dir(benchmark)
# Don't instantiate the benchmark class unless necessary
benchmark_instance = None
benchmark_class = benchmark()
attrs = dir(benchmark_class)
for attr in attrs:
if not attr.startswith("benchmark"):
continue
benchmark_fn = getattr(benchmark_class, attr)
if not callable(benchmark_fn):
candidate_benchmark_fn = getattr(benchmark, attr)
if not callable(candidate_benchmark_fn):
continue
full_benchmark_name = "%s.%s" % (benchmark_name, attr)
if regex == "all" or re.search(regex, full_benchmark_name):
benchmark_fn()
# Instantiate the class if it hasn't been instantiated
benchmark_instance = benchmark_instance or benchmark()
# Get the method tied to the class
instance_benchmark_fn = getattr(benchmark_instance, attr)
# Call the instance method
instance_benchmark_fn()
def benchmarks_main(true_main):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册