未验证 提交 e7284deb 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

add download scalars (#879)

* add download scalars

add content-type for scalar data

* delete _register_reader
Co-authored-by: Nwuzewu <wuzewu@baidu.com>
上级 cf66d953
......@@ -71,14 +71,15 @@ class LogReader(object):
self.file_readers = {}
# {'run': {'scalar': {'tag1': data, 'tag2': data}}}
self._log_datas = collections.defaultdict(lambda: collections.defaultdict(lambda: collections.defaultdict(list)))
if file_path:
self._log_data = collections.defaultdict(lambda: collections.defaultdict(list))
self.get_file_reader(file_path=file_path)
remain = self.get_remain()
self.read_log_data(remain=remain)
components_name = components.keys()
for name in components_name:
exec("self.get_%s=partial(self.get_data, '%s')" % (name, name))
elif logdir:
......@@ -108,6 +109,14 @@ class LogReader(object):
def logdir(self):
return self.dir
def parsing_from_proto(self, component, proto_datas):
data = []
if 'scalar' == component:
for item in proto_datas:
data.append([item.id, item.tag, item.timestamp, item.value])
return data
return proto_datas
def _get_log_tags(self):
component_keys = self._log_data.keys()
log_tags = {}
......@@ -118,6 +127,18 @@ class LogReader(object):
return log_tags
def get_log_data(self, component, run, tag):
if run in self._log_datas.keys() and component in self._log_datas[run].keys() and tag in self._log_datas[run][component].keys():
return self._log_datas[run][component][tag]
else:
file_path = bfile.join(run, self.walks[run])
reader = self._get_file_reader(file_path=file_path, update=False)
remain = self.get_remain(reader=reader)
data = self.read_log_data(remain=remain, update=False)[component][tag]
data = self.parsing_from_proto(component, data)
self._log_datas[run][component][tag] = data
return data
def get_tags(self):
return self._get_log_tags()
......@@ -209,7 +230,7 @@ class LogReader(object):
filepath = bfile.join(dir, log)
if filepath not in self.readers.keys():
self._register_reader(filepath, dir)
self.register_reader(filepath, dir)
self.reader = self.readers[filepath]
return self.reader
......@@ -221,15 +242,25 @@ class LogReader(object):
Args:
file_path: Vdl log file path.
"""
self._register_reader(file_path)
return self._get_file_reader(file_path, True)
def _get_file_reader(self, file_path, update=True):
if update:
self.register_reader(file_path)
self.reader = self.readers[file_path]
self.reader.dir = file_path
return self.reader
else:
reader = RecordReader(filepath=file_path)
return reader
def _register_reader(self, path, dir=None):
def register_reader(self, path, dir=None, update=True):
if update:
if path not in list(self.readers.keys()):
reader = RecordReader(filepath=path, dir=dir)
self.readers[path] = reader
else:
pass
def register_readers(self, update=False):
"""Register all readers for all vdl log files.
......@@ -240,7 +271,7 @@ class LogReader(object):
self.logs(update)
for dir, path in self.walks.items():
filepath = bfile.join(dir, path)
self._register_reader(filepath, dir)
self.register_reader(filepath, dir)
def add_remain(self):
"""Add remain data to data_manager.
......@@ -257,22 +288,26 @@ class LogReader(object):
self.data_manager.add_item(component, self.reader.dir, tag,
record)
def get_remain(self):
def get_remain(self, reader=None):
"""Get all remain data by self.reader.
"""
if self.reader is None:
if self.reader is None and reader is None:
raise RuntimeError("Please specify log path!")
return self.reader.get_remain()
return self.reader.get_remain() if reader is None else reader.get_remain()
def read_log_data(self, remain):
def read_log_data(self, remain, update=True):
"""Parse data from log file without sampling.
Args:
remain: Raw data from log file.
"""
_log_data = collections.defaultdict(lambda: collections.defaultdict(list))
for item in remain:
component, dir, tag, record = self.parse_from_bin(item)
self._log_data[component][tag].append(record)
_log_data[component][tag].append(record)
if update:
self._log_data = _log_data
return _log_data
@property
def log_data(self):
......
......@@ -118,6 +118,11 @@ class Api(object):
key = os.path.join('data/plugin/scalars/scalars', run, tag)
return self._get_with_retry(key, lib.get_scalar, run, tag)
@result('text/tab-separated-values')
def scalar_data(self, run, tag):
key = os.path.join('data/plugin/scalars/data', run, tag)
return self._get_with_retry(key, lib.get_scalar_data, run, tag)
@result()
def image_list(self, mode, tag):
key = os.path.join('data/plugin/images/images', mode, tag)
......@@ -199,6 +204,7 @@ def create_api_call(logdir, model, cache_timeout):
'histogram/tags': (api.histogram_tags, []),
'pr-curve/tags': (api.pr_curve_tags, []),
'scalar/list': (api.scalar_list, ['run', 'tag']),
'scalar/data': (api.scalar_data, ['run', 'tag']),
'image/list': (api.image_list, ['run', 'tag']),
'image/image': (api.image_image, ['run', 'tag', 'index']),
'audio/list': (api.audio_list, ['run', 'tag']),
......
......@@ -125,6 +125,18 @@ def get_scalar(log_reader, run, tag):
return results
def get_scalar_data(log_reader, run, tag):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data()
result = log_reader.get_log_data('scalar', run, tag)
with io.StringIO() as fp:
csv_writer = csv.writer(fp, delimiter='\t')
csv_writer.writerow(['id', 'tag', 'timestamp', 'value'])
csv_writer.writerows(result)
result = fp.getvalue()
return result
def get_image_tag_steps(log_reader, run, tag):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册