未验证 提交 812c142b 编写于 作者: C chenjian 提交者: GitHub

Fix a bug when device info does not exist in the JSON format (#1166)

上级 193dcc8d
...@@ -265,10 +265,16 @@ class ProfilerResult: ...@@ -265,10 +265,16 @@ class ProfilerResult:
def parse_json(self, json_data): def parse_json(self, json_data):
self.schema_version = json_data['schemaVersion'] self.schema_version = json_data['schemaVersion']
self.span_idx = json_data['span_indx'] self.span_idx = json_data['span_indx']
try:
self.device_infos = { self.device_infos = {
device_info['id']: device_info device_info['id']: device_info
for device_info in json_data['deviceProperties'] for device_info in json_data['deviceProperties']
} }
except Exception:
print(
"paddlepaddle-gpu version is needed to get GPU device informations."
)
self.device_infos = {}
hostnodes = [] hostnodes = []
runtimenodes = [] runtimenodes = []
devicenodes = [] devicenodes = []
......
...@@ -1767,6 +1767,8 @@ class DistributedProfilerData: ...@@ -1767,6 +1767,8 @@ class DistributedProfilerData:
data = [] data = []
for profile_data in self.profile_datas: for profile_data in self.profile_datas:
device_infos = profile_data.device_infos device_infos = profile_data.device_infos
if not device_infos:
return data
gpu_id = int(next(iter(profile_data.gpu_ids))) gpu_id = int(next(iter(profile_data.gpu_ids)))
data.append({ data.append({
'worker_name': 'worker_name':
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ======================================================================= # =======================================================================
import os import os
import re import re
from threading import Lock
from threading import Thread from threading import Thread
import packaging.version import packaging.version
...@@ -28,6 +29,7 @@ from .run_manager import RunManager ...@@ -28,6 +29,7 @@ from .run_manager import RunManager
from visualdl.io import bfile from visualdl.io import bfile
_name_pattern = re.compile(r"(.+)_time_(.+)\.paddle_trace\.((pb)|(json))") _name_pattern = re.compile(r"(.+)_time_(.+)\.paddle_trace\.((pb)|(json))")
_lock = Lock()
def is_VDLProfiler_file(path): def is_VDLProfiler_file(path):
...@@ -130,8 +132,10 @@ class ProfilerReader(object): ...@@ -130,8 +132,10 @@ class ProfilerReader(object):
self.run_managers[run] = RunManager(run) self.run_managers[run] = RunManager(run)
self.run_managers[run].set_all_filenames(filenames) self.run_managers[run].set_all_filenames(filenames)
for filename in filenames: for filename in filenames:
with _lock: # we add this to prevent parallel requests for handling a file multiple times
if self.run_managers[run].has_handled(filename): if self.run_managers[run].has_handled(filename):
continue continue
self.run_managers[run].handled_filenames.add(filename)
self._read_data(run, filename) self._read_data(run, filename)
return list(self.walks.keys()) return list(self.walks.keys())
......
...@@ -202,6 +202,8 @@ class ProfilerApi(object): ...@@ -202,6 +202,8 @@ class ProfilerApi(object):
run_manager = self._reader.get_run_manager(run) run_manager = self._reader.get_run_manager(run)
distributed_profiler_data = run_manager.get_distributed_profiler_data( distributed_profiler_data = run_manager.get_distributed_profiler_data(
span) span)
if distributed_profiler_data is None:
return
return distributed_profiler_data.get_distributed_steps() return distributed_profiler_data.get_distributed_steps()
@result() @result()
...@@ -209,6 +211,8 @@ class ProfilerApi(object): ...@@ -209,6 +211,8 @@ class ProfilerApi(object):
run_manager = self._reader.get_run_manager(run) run_manager = self._reader.get_run_manager(run)
distributed_profiler_data = run_manager.get_distributed_profiler_data( distributed_profiler_data = run_manager.get_distributed_profiler_data(
span) span)
if distributed_profiler_data is None:
return
return distributed_profiler_data.get_distributed_histogram( return distributed_profiler_data.get_distributed_histogram(
step, time_unit) step, time_unit)
......
...@@ -104,11 +104,8 @@ class RunManager: ...@@ -104,11 +104,8 @@ class RunManager:
return return
def join(self): def join(self):
if self.has_join:
return
for thread in self.threads.values(): for thread in self.threads.values():
thread.join() thread.join()
self.has_join = True
distributed_profiler_data = defaultdict(list) distributed_profiler_data = defaultdict(list)
for worker_name, span_data in self.profiler_data.items(): for worker_name, span_data in self.profiler_data.items():
for span_idx, profiler_data in span_data.items(): for span_idx, profiler_data in span_data.items():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册