From f175abc3457b6c8aec01a8b5c53adc8de5a4111d Mon Sep 17 00:00:00 2001 From: Taylor Robie Date: Wed, 24 Oct 2018 16:05:47 -0700 Subject: [PATCH] Move version check to a function (#5601) * move version check to a function * delint * tweak pip check * delint --- official/utils/logs/mlperf_helper.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/official/utils/logs/mlperf_helper.py b/official/utils/logs/mlperf_helper.py index 832bd2b78..aea903167 100644 --- a/official/utils/logs/mlperf_helper.py +++ b/official/utils/logs/mlperf_helper.py @@ -86,19 +86,22 @@ def unparse_line(parsed_line): # type: (ParsedLine) -> str def get_mlperf_log(): """Shielded import of mlperf_log module.""" try: - import pkg_resources import mlperf_compliance - version = pkg_resources.get_distribution("mlperf_compliance") - version = tuple(int(i) for i in version.version.split(".")) - if version < _MIN_VERSION: - tf.logging.warning( - "mlperf_compliance is version {}, must be at least version {}".format( - ".".join([str(i) for i in version]), - ".".join([str(i) for i in _MIN_VERSION]))) - raise ImportError - - mlperf_log = mlperf_compliance.mlperf_log + def test_mlperf_log_pip_version(): + """Check that mlperf_compliance is up to date.""" + import pkg_resources + version = pkg_resources.get_distribution("mlperf_compliance") + version = tuple(int(i) for i in version.version.split(".")) + if version < _MIN_VERSION: + tf.logging.warning( + "mlperf_compliance is version {}, must be >= {}".format( + ".".join([str(i) for i in version]), + ".".join([str(i) for i in _MIN_VERSION]))) + raise ImportError + return mlperf_compliance.mlperf_log + + mlperf_log = test_mlperf_log_pip_version() except ImportError: mlperf_log = None -- GitLab