提交 eaaf891c 编写于 作者: S Sami Kama

Pass session config to device_lib list_local_devices() call

上级 34beb7ad
......@@ -307,11 +307,15 @@ def _gather_run_info(model_name, dataset_name, run_params, test_id):
"test_id": test_id,
"run_date": datetime.datetime.utcnow().strftime(
_DATE_TIME_FORMAT_PATTERN)}
if "session_config" in run_params:
session_config=run_params["session_config"]
else:
session_config=None
_collect_tensorflow_info(run_info)
_collect_tensorflow_environment_variables(run_info)
_collect_run_params(run_info, run_params)
_collect_cpu_info(run_info)
_collect_gpu_info(run_info)
_collect_gpu_info(run_info,session_config)
_collect_memory_info(run_info)
_collect_test_environment(run_info)
return run_info
......@@ -385,10 +389,10 @@ def _collect_cpu_info(run_info):
tf.logging.warn("'cpuinfo' not imported. CPU info will not be logged.")
def _collect_gpu_info(run_info):
def _collect_gpu_info(run_info,session_config=None):
"""Collect local GPU information by TF device library."""
gpu_info = {}
local_device_protos = device_lib.list_local_devices()
local_device_protos = device_lib.list_local_devices(session_config)
gpu_info["count"] = len([d for d in local_device_protos
if d.device_type == "GPU"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册