From a128dee224c19d60dd9b29a5ab54c9ec2a2c9470 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Mon, 11 Apr 2016 10:05:51 -0800 Subject: [PATCH] Clean up benchmark.py after previous modifications and add unit test. Change: 119549296 --- .../python/kernel_tests/benchmark_test.py | 13 +++++++ tensorflow/python/platform/benchmark.py | 36 ++++++++----------- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/tensorflow/python/kernel_tests/benchmark_test.py b/tensorflow/python/kernel_tests/benchmark_test.py index 2cddbe98724..9a96038618e 100644 --- a/tensorflow/python/kernel_tests/benchmark_test.py +++ b/tensorflow/python/kernel_tests/benchmark_test.py @@ -103,6 +103,19 @@ class BenchmarkTest(tf.test.TestCase): self.assertTrue(_ran_somebenchmark_2[0]) self.assertFalse(_ran_somebenchmark_but_shouldnt[0]) + _ran_somebenchmark_1[0] = False + _ran_somebenchmark_2[0] = False + _ran_somebenchmark_but_shouldnt[0] = False + + # Test running a specific method of SomeRandomBenchmark + if benchmark.TEST_REPORTER_TEST_ENV in os.environ: + del os.environ[benchmark.TEST_REPORTER_TEST_ENV] + benchmark._run_benchmarks("SomeRandom.*1$") + + self.assertTrue(_ran_somebenchmark_1[0]) + self.assertFalse(_ran_somebenchmark_2[0]) + self.assertFalse(_ran_somebenchmark_but_shouldnt[0]) + def testReportingBenchmark(self): tempdir = tf.test.get_temp_dir() try: diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py index af6ea10ef37..65c8d100d29 100644 --- a/tensorflow/python/platform/benchmark.py +++ b/tensorflow/python/platform/benchmark.py @@ -245,27 +245,13 @@ class TensorFlowBenchmark(Benchmark): name=name) -def _run_specific_benchmark(benchmark_class): - benchmark = benchmark_class() - attrs = dir(benchmark) - # Only run methods of this class whose names start with "benchmark" - for attr in attrs: - if not attr.startswith("benchmark"): - continue - benchmark_fn = getattr(benchmark, attr) - if not callable(benchmark_fn): - continue - # Call this benchmark method - benchmark_fn() - - def _run_benchmarks(regex): """Run benchmarks that match regex `regex`. This function goes through the global benchmark registry, and matches - benchmark **classe names** of the form "module.name.BenchmarkClass" to - the given regex. If a class matches, all of its benchmark methods - are run. + benchmark class and method names of the form + `module.name.BenchmarkClass.benchmarkMethod` to the given regex. + If a method matches, it is run. Args: regex: The string regular expression to match Benchmark classes against. @@ -275,18 +261,24 @@ def _run_benchmarks(regex): # Match benchmarks in registry against regex for benchmark in registry: benchmark_name = "%s.%s" % (benchmark.__module__, benchmark.__name__) + attrs = dir(benchmark) + # Don't instantiate the benchmark class unless necessary + benchmark_instance = None - benchmark_class = benchmark() - attrs = dir(benchmark_class) for attr in attrs: if not attr.startswith("benchmark"): continue - benchmark_fn = getattr(benchmark_class, attr) - if not callable(benchmark_fn): + candidate_benchmark_fn = getattr(benchmark, attr) + if not callable(candidate_benchmark_fn): continue full_benchmark_name = "%s.%s" % (benchmark_name, attr) if regex == "all" or re.search(regex, full_benchmark_name): - benchmark_fn() + # Instantiate the class if it hasn't been instantiated + benchmark_instance = benchmark_instance or benchmark() + # Get the method tied to the class + instance_benchmark_fn = getattr(benchmark_instance, attr) + # Call the instance method + instance_benchmark_fn() def benchmarks_main(true_main): -- GitLab