未验证 提交 85cb577d 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

add fake run (#733)

上级 0ecfe3bf
...@@ -89,6 +89,7 @@ class LogWriter(logdir=None, ...@@ -89,6 +89,7 @@ class LogWriter(logdir=None,
| flush_secs | int | 日志记录消息队列的最大缓存时间,达到此时间则立即写入到日志文件 | | flush_secs | int | 日志记录消息队列的最大缓存时间,达到此时间则立即写入到日志文件 |
| filename_suffix | string | 为默认的日志文件名添加后缀 | | filename_suffix | string | 为默认的日志文件名添加后缀 |
| write_to_disk | boolean | 是否写入到磁盘 | | write_to_disk | boolean | 是否写入到磁盘 |
| display_name | string | 实际展示在面板中的`runs`名称,可隐藏路径信息 |
#### 示例 #### 示例
......
...@@ -36,6 +36,27 @@ def scalar(tag, value, step, walltime=None): ...@@ -36,6 +36,27 @@ def scalar(tag, value, step, walltime=None):
]) ])
def meta_data(tag='meta_data_tag', display_name="", step=0, walltime=None):
"""Package data to one meta_data.
Meta data is info for one record file, include `display_name` etc.
Args:
tag (string): Data identifier
display_name (string): Replace
step (int): Step of scalar
walltime (int): Wall time of scalar
Return:
Package with format of record_pb2.Record
"""
meta = Record.MetaData(display_name=display_name)
return Record(values=[
Record.Value(id=step, tag=tag, timestamp=walltime,
meta_data=meta)
])
def imgarray2bytes(np_array): def imgarray2bytes(np_array):
"""Convert image ndarray to bytes. """Convert image ndarray to bytes.
......
...@@ -43,6 +43,10 @@ message Record { ...@@ -43,6 +43,10 @@ message Record {
repeated double recall = 6; repeated double recall = 6;
} }
message MetaData {
string display_name = 1;
}
message Value { message Value {
int64 id = 1; int64 id = 1;
string tag = 2; string tag = 2;
...@@ -54,6 +58,7 @@ message Record { ...@@ -54,6 +58,7 @@ message Record {
Embeddings embeddings = 7; Embeddings embeddings = 7;
Histogram histogram = 8; Histogram histogram = 8;
PRCurve pr_curve = 9; PRCurve pr_curve = 9;
MetaData meta_data = 10;
} }
} }
......
...@@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='visualdl', package='visualdl',
syntax='proto3', syntax='proto3',
serialized_options=None, serialized_options=None,
serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xe2\x06\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1a<\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1al\n\x07PRCurve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x11\n\tprecision\x18\x05 \x03(\x01\x12\x0e\n\x06recall\x18\x06 \x03(\x01\x1a\xb5\x02\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x12,\n\x08pr_curve\x18\t \x01(\x0b\x32\x18.visualdl.Record.PRCurveH\x00\x42\x0b\n\tone_valueb\x06proto3' serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xb4\x07\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1a<\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1al\n\x07PRCurve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x11\n\tprecision\x18\x05 \x03(\x01\x12\x0e\n\x06recall\x18\x06 \x03(\x01\x1a \n\x08MetaData\x12\x14\n\x0c\x64isplay_name\x18\x01 \x01(\t\x1a\xe5\x02\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x12,\n\x08pr_curve\x18\t \x01(\x0b\x32\x18.visualdl.Record.PRCurveH\x00\x12.\n\tmeta_data\x18\n \x01(\x0b\x32\x19.visualdl.Record.MetaDataH\x00\x42\x0b\n\tone_valueb\x06proto3'
) )
...@@ -318,6 +318,36 @@ _RECORD_PRCURVE = _descriptor.Descriptor( ...@@ -318,6 +318,36 @@ _RECORD_PRCURVE = _descriptor.Descriptor(
serialized_end=581, serialized_end=581,
) )
_RECORD_METADATA = _descriptor.Descriptor(
name='MetaData',
full_name='visualdl.Record.MetaData',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='display_name', full_name='visualdl.Record.MetaData.display_name', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=583,
serialized_end=615,
)
_RECORD_VALUE = _descriptor.Descriptor( _RECORD_VALUE = _descriptor.Descriptor(
name='Value', name='Value',
full_name='visualdl.Record.Value', full_name='visualdl.Record.Value',
...@@ -388,6 +418,13 @@ _RECORD_VALUE = _descriptor.Descriptor( ...@@ -388,6 +418,13 @@ _RECORD_VALUE = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='meta_data', full_name='visualdl.Record.Value.meta_data', index=9,
number=10, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
], ],
extensions=[ extensions=[
], ],
...@@ -403,8 +440,8 @@ _RECORD_VALUE = _descriptor.Descriptor( ...@@ -403,8 +440,8 @@ _RECORD_VALUE = _descriptor.Descriptor(
name='one_value', full_name='visualdl.Record.Value.one_value', name='one_value', full_name='visualdl.Record.Value.one_value',
index=0, containing_type=None, fields=[]), index=0, containing_type=None, fields=[]),
], ],
serialized_start=584, serialized_start=618,
serialized_end=893, serialized_end=975,
) )
_RECORD = _descriptor.Descriptor( _RECORD = _descriptor.Descriptor(
...@@ -424,7 +461,7 @@ _RECORD = _descriptor.Descriptor( ...@@ -424,7 +461,7 @@ _RECORD = _descriptor.Descriptor(
], ],
extensions=[ extensions=[
], ],
nested_types=[_RECORD_IMAGE, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_HISTOGRAM, _RECORD_PRCURVE, _RECORD_VALUE, ], nested_types=[_RECORD_IMAGE, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_HISTOGRAM, _RECORD_PRCURVE, _RECORD_METADATA, _RECORD_VALUE, ],
enum_types=[ enum_types=[
], ],
serialized_options=None, serialized_options=None,
...@@ -434,7 +471,7 @@ _RECORD = _descriptor.Descriptor( ...@@ -434,7 +471,7 @@ _RECORD = _descriptor.Descriptor(
oneofs=[ oneofs=[
], ],
serialized_start=27, serialized_start=27,
serialized_end=893, serialized_end=975,
) )
_RECORD_IMAGE.containing_type = _RECORD _RECORD_IMAGE.containing_type = _RECORD
...@@ -445,11 +482,13 @@ _RECORD_EMBEDDINGS.containing_type = _RECORD ...@@ -445,11 +482,13 @@ _RECORD_EMBEDDINGS.containing_type = _RECORD
_RECORD_BYTES_EMBEDDINGS.containing_type = _RECORD _RECORD_BYTES_EMBEDDINGS.containing_type = _RECORD
_RECORD_HISTOGRAM.containing_type = _RECORD _RECORD_HISTOGRAM.containing_type = _RECORD
_RECORD_PRCURVE.containing_type = _RECORD _RECORD_PRCURVE.containing_type = _RECORD
_RECORD_METADATA.containing_type = _RECORD
_RECORD_VALUE.fields_by_name['image'].message_type = _RECORD_IMAGE _RECORD_VALUE.fields_by_name['image'].message_type = _RECORD_IMAGE
_RECORD_VALUE.fields_by_name['audio'].message_type = _RECORD_AUDIO _RECORD_VALUE.fields_by_name['audio'].message_type = _RECORD_AUDIO
_RECORD_VALUE.fields_by_name['embeddings'].message_type = _RECORD_EMBEDDINGS _RECORD_VALUE.fields_by_name['embeddings'].message_type = _RECORD_EMBEDDINGS
_RECORD_VALUE.fields_by_name['histogram'].message_type = _RECORD_HISTOGRAM _RECORD_VALUE.fields_by_name['histogram'].message_type = _RECORD_HISTOGRAM
_RECORD_VALUE.fields_by_name['pr_curve'].message_type = _RECORD_PRCURVE _RECORD_VALUE.fields_by_name['pr_curve'].message_type = _RECORD_PRCURVE
_RECORD_VALUE.fields_by_name['meta_data'].message_type = _RECORD_METADATA
_RECORD_VALUE.containing_type = _RECORD _RECORD_VALUE.containing_type = _RECORD
_RECORD_VALUE.oneofs_by_name['one_value'].fields.append( _RECORD_VALUE.oneofs_by_name['one_value'].fields.append(
_RECORD_VALUE.fields_by_name['value']) _RECORD_VALUE.fields_by_name['value'])
...@@ -469,6 +508,9 @@ _RECORD_VALUE.fields_by_name['histogram'].containing_oneof = _RECORD_VALUE.oneof ...@@ -469,6 +508,9 @@ _RECORD_VALUE.fields_by_name['histogram'].containing_oneof = _RECORD_VALUE.oneof
_RECORD_VALUE.oneofs_by_name['one_value'].fields.append( _RECORD_VALUE.oneofs_by_name['one_value'].fields.append(
_RECORD_VALUE.fields_by_name['pr_curve']) _RECORD_VALUE.fields_by_name['pr_curve'])
_RECORD_VALUE.fields_by_name['pr_curve'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value'] _RECORD_VALUE.fields_by_name['pr_curve'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value']
_RECORD_VALUE.oneofs_by_name['one_value'].fields.append(
_RECORD_VALUE.fields_by_name['meta_data'])
_RECORD_VALUE.fields_by_name['meta_data'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value']
_RECORD.fields_by_name['values'].message_type = _RECORD_VALUE _RECORD.fields_by_name['values'].message_type = _RECORD_VALUE
DESCRIPTOR.message_types_by_name['Record'] = _RECORD DESCRIPTOR.message_types_by_name['Record'] = _RECORD
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
...@@ -524,6 +566,13 @@ Record = _reflection.GeneratedProtocolMessageType('Record', (_message.Message,), ...@@ -524,6 +566,13 @@ Record = _reflection.GeneratedProtocolMessageType('Record', (_message.Message,),
}) })
, ,
'MetaData' : _reflection.GeneratedProtocolMessageType('MetaData', (_message.Message,), {
'DESCRIPTOR' : _RECORD_METADATA,
'__module__' : 'record_pb2'
# @@protoc_insertion_point(class_scope:visualdl.Record.MetaData)
})
,
'Value' : _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), { 'Value' : _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), {
'DESCRIPTOR' : _RECORD_VALUE, 'DESCRIPTOR' : _RECORD_VALUE,
'__module__' : 'record_pb2' '__module__' : 'record_pb2'
...@@ -542,6 +591,7 @@ _sym_db.RegisterMessage(Record.Embeddings) ...@@ -542,6 +591,7 @@ _sym_db.RegisterMessage(Record.Embeddings)
_sym_db.RegisterMessage(Record.bytes_embeddings) _sym_db.RegisterMessage(Record.bytes_embeddings)
_sym_db.RegisterMessage(Record.Histogram) _sym_db.RegisterMessage(Record.Histogram)
_sym_db.RegisterMessage(Record.PRCurve) _sym_db.RegisterMessage(Record.PRCurve)
_sym_db.RegisterMessage(Record.MetaData)
_sym_db.RegisterMessage(Record.Value) _sym_db.RegisterMessage(Record.Value)
......
...@@ -55,6 +55,8 @@ class LogReader(object): ...@@ -55,6 +55,8 @@ class LogReader(object):
self.readers = {} self.readers = {}
self.walks = None self.walks = None
self._tags = {} self._tags = {}
self.name2tags = {}
self.tags2name = {}
self.file_readers = {} self.file_readers = {}
self._environments = components self._environments = components
...@@ -64,6 +66,7 @@ class LogReader(object): ...@@ -64,6 +66,7 @@ class LogReader(object):
self._model = "" self._model = ""
@property @property
def model(self): def model(self):
return self._model return self._model
...@@ -110,12 +113,21 @@ class LogReader(object): ...@@ -110,12 +113,21 @@ class LogReader(object):
component = "histogram" component = "histogram"
elif "pr_curve" == value_type: elif "pr_curve" == value_type:
component = "pr_curve" component = "pr_curve"
elif "meta_data" == value_type:
self.update_meta_data(record)
component = "meta_data"
else: else:
raise TypeError("Invalid value type `%s`." % value_type) raise TypeError("Invalid value type `%s`." % value_type)
self._tags[path] = component self._tags[path] = component
return self._tags[path], self.reader.dir, tag, value return self._tags[path], self.reader.dir, tag, value
def update_meta_data(self, record):
meta = record.values[0].meta_data
if meta.display_name:
self.name2tags[meta.display_name] = self.reader.dir
self.tags2name[self.reader.dir] = meta.display_name
def get_all_walk(self): def get_all_walk(self):
self.walks = {} self.walks = {}
for dir in self.dir: for dir in self.dir:
......
...@@ -24,7 +24,8 @@ DEFAULT_PLUGIN_MAXSIZE = { ...@@ -24,7 +24,8 @@ DEFAULT_PLUGIN_MAXSIZE = {
"histogram": 100, "histogram": 100,
"embeddings": 50000, "embeddings": 50000,
"audio": 10, "audio": 10,
"pr_curve": 300 "pr_curve": 300,
"meta_data": 100
} }
...@@ -277,7 +278,9 @@ class DataManager(object): ...@@ -277,7 +278,9 @@ class DataManager(object):
"audio": "audio":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["audio"]), Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["audio"]),
"pr_curve": "pr_curve":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["pr_curve"]) Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["pr_curve"]),
"meta_data":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["meta_data"])
} }
self._mutex = threading.Lock() self._mutex = threading.Lock()
......
...@@ -29,7 +29,13 @@ def get_components(log_reader): ...@@ -29,7 +29,13 @@ def get_components(log_reader):
def get_runs(log_reader): def get_runs(log_reader):
return log_reader.runs() runs = []
for item in log_reader.runs():
if item in log_reader.tags2name:
runs.append(log_reader.tags2name[item])
else:
runs.append(item)
return runs
def get_tags(log_reader): def get_tags(log_reader):
...@@ -47,7 +53,15 @@ def get_logs(log_reader, component): ...@@ -47,7 +53,15 @@ def get_logs(log_reader, component):
tags[run].append(tag) tags[run].append(tag)
else: else:
tags[run] = [tag] tags[run] = [tag]
return tags fake_tags = {}
for key, value in tags.items():
if key in log_reader.tags2name:
fake_tags[log_reader.tags2name[key]] = value
else:
fake_tags[key] = value
return fake_tags
def get_scalar_tags(log_reader): def get_scalar_tags(log_reader):
...@@ -55,6 +69,7 @@ def get_scalar_tags(log_reader): ...@@ -55,6 +69,7 @@ def get_scalar_tags(log_reader):
def get_scalar(log_reader, run, tag): def get_scalar(log_reader, run, tag):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data() log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("scalar").get_items( records = log_reader.data_manager.get_reservoir("scalar").get_items(
run, decode_tag(tag)) run, decode_tag(tag))
...@@ -67,6 +82,7 @@ def get_image_tags(log_reader): ...@@ -67,6 +82,7 @@ def get_image_tags(log_reader):
def get_image_tag_steps(log_reader, run, tag): def get_image_tag_steps(log_reader, run, tag):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data() log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("image").get_items( records = log_reader.data_manager.get_reservoir("image").get_items(
run, decode_tag(tag)) run, decode_tag(tag))
...@@ -78,6 +94,7 @@ def get_image_tag_steps(log_reader, run, tag): ...@@ -78,6 +94,7 @@ def get_image_tag_steps(log_reader, run, tag):
def get_individual_image(log_reader, run, tag, step_index): def get_individual_image(log_reader, run, tag, step_index):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data() log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("image").get_items( records = log_reader.data_manager.get_reservoir("image").get_items(
run, decode_tag(tag)) run, decode_tag(tag))
...@@ -89,6 +106,7 @@ def get_audio_tags(log_reader): ...@@ -89,6 +106,7 @@ def get_audio_tags(log_reader):
def get_audio_tag_steps(log_reader, run, tag): def get_audio_tag_steps(log_reader, run, tag):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data() log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("audio").get_items( records = log_reader.data_manager.get_reservoir("audio").get_items(
run, decode_tag(tag)) run, decode_tag(tag))
...@@ -100,6 +118,7 @@ def get_audio_tag_steps(log_reader, run, tag): ...@@ -100,6 +118,7 @@ def get_audio_tag_steps(log_reader, run, tag):
def get_individual_audio(log_reader, run, tag, step_index): def get_individual_audio(log_reader, run, tag, step_index):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data() log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("audio").get_items( records = log_reader.data_manager.get_reservoir("audio").get_items(
run, decode_tag(tag)) run, decode_tag(tag))
...@@ -120,6 +139,7 @@ def get_pr_curve_tags(log_reader): ...@@ -120,6 +139,7 @@ def get_pr_curve_tags(log_reader):
def get_pr_curve(log_reader, run, tag): def get_pr_curve(log_reader, run, tag):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data() log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("pr_curve").get_items( records = log_reader.data_manager.get_reservoir("pr_curve").get_items(
run, decode_tag(tag)) run, decode_tag(tag))
...@@ -141,6 +161,7 @@ def get_pr_curve(log_reader, run, tag): ...@@ -141,6 +161,7 @@ def get_pr_curve(log_reader, run, tag):
def get_pr_curve_step(log_reader, run, tag=None): def get_pr_curve_step(log_reader, run, tag=None):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
tag = get_pr_curve_tags(log_reader)[run][0] if tag is None else tag tag = get_pr_curve_tags(log_reader)[run][0] if tag is None else tag
log_reader.load_new_data() log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("pr_curve").get_items( records = log_reader.data_manager.get_reservoir("pr_curve").get_items(
...@@ -150,6 +171,7 @@ def get_pr_curve_step(log_reader, run, tag=None): ...@@ -150,6 +171,7 @@ def get_pr_curve_step(log_reader, run, tag=None):
def get_embeddings(log_reader, run, tag, reduction, dimension=2): def get_embeddings(log_reader, run, tag, reduction, dimension=2):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data() log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("embeddings").get_items( records = log_reader.data_manager.get_reservoir("embeddings").get_items(
run, decode_tag(tag)) run, decode_tag(tag))
...@@ -173,6 +195,7 @@ def get_embeddings(log_reader, run, tag, reduction, dimension=2): ...@@ -173,6 +195,7 @@ def get_embeddings(log_reader, run, tag, reduction, dimension=2):
def get_histogram(log_reader, run, tag): def get_histogram(log_reader, run, tag):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data() log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("histogram").get_items( records = log_reader.data_manager.get_reservoir("histogram").get_items(
run, decode_tag(tag)) run, decode_tag(tag))
......
...@@ -16,7 +16,7 @@ import os ...@@ -16,7 +16,7 @@ import os
import time import time
import numpy as np import numpy as np
from visualdl.writer.record_writer import RecordFileWriter from visualdl.writer.record_writer import RecordFileWriter
from visualdl.component.base_component import scalar, image, embedding, audio, histogram, pr_curve from visualdl.component.base_component import scalar, image, embedding, audio, histogram, pr_curve, meta_data
class DummyFileWriter(object): class DummyFileWriter(object):
...@@ -66,6 +66,7 @@ class LogWriter(object): ...@@ -66,6 +66,7 @@ class LogWriter(object):
flush_secs=120, flush_secs=120,
filename_suffix='', filename_suffix='',
write_to_disk=True, write_to_disk=True,
display_name='',
**kwargs): **kwargs):
"""Create a instance of class `LogWriter` and create a vdl log file with """Create a instance of class `LogWriter` and create a vdl log file with
given args. given args.
...@@ -98,6 +99,7 @@ class LogWriter(object): ...@@ -98,6 +99,7 @@ class LogWriter(object):
self._all_writers = {} self._all_writers = {}
self._get_file_writer() self._get_file_writer()
self.loggers = {} self.loggers = {}
self.add_meta(display_name=display_name)
@property @property
def logdir(self): def logdir(self):
...@@ -118,6 +120,22 @@ class LogWriter(object): ...@@ -118,6 +120,22 @@ class LogWriter(object):
self._all_writers.update({self._logdir: self._file_writer}) self._all_writers.update({self._logdir: self._file_writer})
return self._file_writer return self._file_writer
def add_meta(self, tag='meta_data_tag', display_name='', step=0, walltime=None):
"""Add a meta to vdl record file.
Args:
tag (string): Data identifier
display_name (string): Display name of `runs`.
step (int): Step of meta.
walltime (int): Wall time of scalar
"""
if '%' in tag:
raise RuntimeError("% can't appear in tag!")
walltime = round(time.time()) if walltime is None else walltime
self._get_file_writer().add_record(
meta_data(tag=tag, display_name=display_name, step=step,
walltime=walltime))
def add_scalar(self, tag, value, step, walltime=None): def add_scalar(self, tag, value, step, walltime=None):
"""Add a scalar to vdl record file. """Add a scalar to vdl record file.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册