未验证 提交 363a7bf3 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

support multi-dimentional vectors api (#874)

上级 81240db7
...@@ -211,48 +211,6 @@ After launching the panel by one of the above methods, developers can see the vi ...@@ -211,48 +211,6 @@ After launching the panel by one of the above methods, developers can see the vi
<img src="https://user-images.githubusercontent.com/48054808/90868674-ba321f00-e3c9-11ea-83c1-f03c6dd19187.png" width="70%"/> <img src="https://user-images.githubusercontent.com/48054808/90868674-ba321f00-e3c9-11ea-83c1-f03c6dd19187.png" width="70%"/>
</p> </p>
### 3. Read data in log files using LogReader
VisualDL also provide `LogReader` interface to read raw data from log files.
```python
class LogReader(logdir=None,
file_name='')
```
#### interface parameters
| parameters | type | meaning |
| ---------- | ------ | ------------------------------------ |
| logdir | string | Path to the log directory. Required. |
| file_name | string | File name of the log file. Required. |
#### Example
Suppose there is a log file named `vdlrecords.1605533348.log` in directory `./log`. We can get scalar data in `loss` tag by
```python
from visualdl import LogReader
reader = LogReader(logdir='./log', file_name='vdlrecords.1605533348.log')
data = reader.get_data('scalar', 'loss')
print(data)
```
The result is a list of
```python
...
id: 5
tag: "Metrics/Training(Step): loss"
timestamp: 1605533356039
value: 3.1297709941864014
...
```
For more information of `LogReader`, please refer to [LogReader](./docs/io/LogReader.md).
## Function Preview ## Function Preview
### Scalar ### Scalar
......
...@@ -29,3 +29,22 @@ if __name__ == '__main__': ...@@ -29,3 +29,22 @@ if __name__ == '__main__':
writer.add_embeddings(tag='default', writer.add_embeddings(tag='default',
labels=labels, labels=labels,
hot_vectors=hot_vectors) hot_vectors=hot_vectors)
"""
# You can code as follow if use multi-dimensional labels.
hot_vectors = [
[1.3561076367500755, 1.3116267195134017, 1.6785401875616097],
[1.1039614644440658, 1.8891609992484688, 1.32030488587171],
[1.9924524852447711, 1.9358920727142739, 1.2124401279391606],
[1.4129542689796446, 1.7372166387197474, 1.7317806077076527],
[1.3913371800587777, 1.4684674577930312, 1.5214136352476377]]
labels = [["label_a_1", "label_a_2", "label_a_3", "label_a_4", "label_a_5"],
["label_b_1", "label_b_2", "label_b_3", "label_b_4", "label_b_5"]]
labels_meta = ["label_a", "label_b"]
with LogWriter(logdir="./log/high_dimensional_test/train") as writer:
writer.add_embeddings(tag='default',
labels=labels,
labels_meta=labels_meta,
hot_vectors=hot_vectors)
"""
...@@ -159,12 +159,12 @@ def image(tag, image_array, step, walltime=None, dataformats="HWC"): ...@@ -159,12 +159,12 @@ def image(tag, image_array, step, walltime=None, dataformats="HWC"):
]) ])
def embedding(tag, labels, hot_vectors, step, walltime=None): def embedding(tag, labels, hot_vectors, step, labels_meta=None, walltime=None):
"""Package data to one embedding. """Package data to one embedding.
Args: Args:
tag (string): Data identifier tag (string): Data identifier
labels (numpy.array or list): A list of labels. labels (list): A list of labels.
hot_vectors (numpy.array or list): A matrix which each row is hot_vectors (numpy.array or list): A matrix which each row is
feature of labels. feature of labels.
step (int): Step of embeddings. step (int): Step of embeddings.
...@@ -175,9 +175,18 @@ def embedding(tag, labels, hot_vectors, step, walltime=None): ...@@ -175,9 +175,18 @@ def embedding(tag, labels, hot_vectors, step, walltime=None):
""" """
embeddings = Record.Embeddings() embeddings = Record.Embeddings()
for index in range(len(hot_vectors)): if labels_meta:
embeddings.embeddings.append( embeddings.label_meta.extend(labels_meta)
Record.Embedding(label=labels[index], vectors=hot_vectors[index]))
if isinstance(labels[0], list):
temp = []
for index in range(len(labels[0])):
temp.append([label[index] for label in labels])
labels = temp
for label, hot_vector in zip(labels, hot_vectors):
if not isinstance(label, list):
label = [label]
embeddings.embeddings.append(Record.Embedding(label=label, vectors=hot_vector))
return Record(values=[ return Record(values=[
Record.Value( Record.Value(
......
...@@ -16,12 +16,13 @@ message Record { ...@@ -16,12 +16,13 @@ message Record {
} }
message Embedding { message Embedding {
string label = 1; repeated string label = 1;
repeated float vectors = 2; repeated float vectors = 2;
} }
message Embeddings { message Embeddings {
repeated Embedding embeddings = 1; repeated Embedding embeddings = 1;
repeated string label_meta = 2;
} }
message bytes_embeddings { message bytes_embeddings {
......
...@@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='visualdl', package='visualdl',
syntax='proto3', syntax='proto3',
serialized_options=None, serialized_options=None,
serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xb4\x07\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1a<\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1al\n\x07PRCurve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x11\n\tprecision\x18\x05 \x03(\x01\x12\x0e\n\x06recall\x18\x06 \x03(\x01\x1a \n\x08MetaData\x12\x14\n\x0c\x64isplay_name\x18\x01 \x01(\t\x1a\xe5\x02\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x12,\n\x08pr_curve\x18\t \x01(\x0b\x32\x18.visualdl.Record.PRCurveH\x00\x12.\n\tmeta_data\x18\n \x01(\x0b\x32\x19.visualdl.Record.MetaDataH\x00\x42\x0b\n\tone_valueb\x06proto3' serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xc8\x07\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x03(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1aP\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x12\x12\n\nlabel_meta\x18\x02 \x03(\t\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1al\n\x07PRCurve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x11\n\tprecision\x18\x05 \x03(\x01\x12\x0e\n\x06recall\x18\x06 \x03(\x01\x1a \n\x08MetaData\x12\x14\n\x0c\x64isplay_name\x18\x01 \x01(\t\x1a\xe5\x02\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x12,\n\x08pr_curve\x18\t \x01(\x0b\x32\x18.visualdl.Record.PRCurveH\x00\x12.\n\tmeta_data\x18\n \x01(\x0b\x32\x19.visualdl.Record.MetaDataH\x00\x42\x0b\n\tone_valueb\x06proto3'
) )
...@@ -121,8 +121,8 @@ _RECORD_EMBEDDING = _descriptor.Descriptor( ...@@ -121,8 +121,8 @@ _RECORD_EMBEDDING = _descriptor.Descriptor(
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='label', full_name='visualdl.Record.Embedding.label', index=0, name='label', full_name='visualdl.Record.Embedding.label', index=0,
number=1, type=9, cpp_type=9, label=1, number=1, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=b"".decode('utf-8'), has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
...@@ -163,6 +163,13 @@ _RECORD_EMBEDDINGS = _descriptor.Descriptor( ...@@ -163,6 +163,13 @@ _RECORD_EMBEDDINGS = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='label_meta', full_name='visualdl.Record.Embeddings.label_meta', index=1,
number=2, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
], ],
extensions=[ extensions=[
], ],
...@@ -176,7 +183,7 @@ _RECORD_EMBEDDINGS = _descriptor.Descriptor( ...@@ -176,7 +183,7 @@ _RECORD_EMBEDDINGS = _descriptor.Descriptor(
oneofs=[ oneofs=[
], ],
serialized_start=288, serialized_start=288,
serialized_end=348, serialized_end=368,
) )
_RECORD_BYTES_EMBEDDINGS = _descriptor.Descriptor( _RECORD_BYTES_EMBEDDINGS = _descriptor.Descriptor(
...@@ -212,8 +219,8 @@ _RECORD_BYTES_EMBEDDINGS = _descriptor.Descriptor( ...@@ -212,8 +219,8 @@ _RECORD_BYTES_EMBEDDINGS = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=350, serialized_start=370,
serialized_end=417, serialized_end=437,
) )
_RECORD_HISTOGRAM = _descriptor.Descriptor( _RECORD_HISTOGRAM = _descriptor.Descriptor(
...@@ -249,8 +256,8 @@ _RECORD_HISTOGRAM = _descriptor.Descriptor( ...@@ -249,8 +256,8 @@ _RECORD_HISTOGRAM = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=419, serialized_start=439,
serialized_end=471, serialized_end=491,
) )
_RECORD_PRCURVE = _descriptor.Descriptor( _RECORD_PRCURVE = _descriptor.Descriptor(
...@@ -314,8 +321,8 @@ _RECORD_PRCURVE = _descriptor.Descriptor( ...@@ -314,8 +321,8 @@ _RECORD_PRCURVE = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=473, serialized_start=493,
serialized_end=581, serialized_end=601,
) )
_RECORD_METADATA = _descriptor.Descriptor( _RECORD_METADATA = _descriptor.Descriptor(
...@@ -344,8 +351,8 @@ _RECORD_METADATA = _descriptor.Descriptor( ...@@ -344,8 +351,8 @@ _RECORD_METADATA = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=583, serialized_start=603,
serialized_end=615, serialized_end=635,
) )
_RECORD_VALUE = _descriptor.Descriptor( _RECORD_VALUE = _descriptor.Descriptor(
...@@ -440,8 +447,8 @@ _RECORD_VALUE = _descriptor.Descriptor( ...@@ -440,8 +447,8 @@ _RECORD_VALUE = _descriptor.Descriptor(
name='one_value', full_name='visualdl.Record.Value.one_value', name='one_value', full_name='visualdl.Record.Value.one_value',
index=0, containing_type=None, fields=[]), index=0, containing_type=None, fields=[]),
], ],
serialized_start=618, serialized_start=638,
serialized_end=975, serialized_end=995,
) )
_RECORD = _descriptor.Descriptor( _RECORD = _descriptor.Descriptor(
...@@ -471,7 +478,7 @@ _RECORD = _descriptor.Descriptor( ...@@ -471,7 +478,7 @@ _RECORD = _descriptor.Descriptor(
oneofs=[ oneofs=[
], ],
serialized_start=27, serialized_start=27,
serialized_end=975, serialized_end=995,
) )
_RECORD_IMAGE.containing_type = _RECORD _RECORD_IMAGE.containing_type = _RECORD
......
...@@ -226,7 +226,11 @@ def get_embedding_labels(log_reader, name): ...@@ -226,7 +226,11 @@ def get_embedding_labels(log_reader, name):
run, decode_tag(tag)) run, decode_tag(tag))
labels = [] labels = []
for item in records[0].embeddings.embeddings: for item in records[0].embeddings.embeddings:
labels.append([item.label]) labels.append(item.label)
label_meta = records[0].embeddings.label_meta
if label_meta:
labels = [label_meta] + labels
with io.StringIO() as fp: with io.StringIO() as fp:
csv_writer = csv.writer(fp, delimiter='\t') csv_writer = csv.writer(fp, delimiter='\t')
......
...@@ -188,17 +188,18 @@ class LogWriter(object): ...@@ -188,17 +188,18 @@ class LogWriter(object):
image(tag=tag, image_array=img, step=step, walltime=walltime, image(tag=tag, image_array=img, step=step, walltime=walltime,
dataformats=dataformats)) dataformats=dataformats))
def add_embeddings(self, tag, labels, hot_vectors, walltime=None): def add_embeddings(self, tag, labels, hot_vectors, labels_meta=None, walltime=None):
"""Add embeddings to vdl record file. """Add embeddings to vdl record file.
Args: Args:
tag (string): Data identifier tag (string): Data identifier
labels (numpy.array or list): A list of labels. labels (numpy.array or list): A 1D or 2D matrix of labels
hot_vectors (numpy.array or list): A matrix which each row is hot_vectors (numpy.array or list): A matrix which each row is
feature of labels. feature of labels.
labels_meta (numpy.array or list): Meta data of labels.
walltime (int): Wall time of embeddings. walltime (int): Wall time of embeddings.
Example: Example 1:
hot_vectors = [ hot_vectors = [
[1.3561076367500755, 1.3116267195134017, 1.6785401875616097], [1.3561076367500755, 1.3116267195134017, 1.6785401875616097],
[1.1039614644440658, 1.8891609992484688, 1.32030488587171], [1.1039614644440658, 1.8891609992484688, 1.32030488587171],
...@@ -207,9 +208,32 @@ class LogWriter(object): ...@@ -207,9 +208,32 @@ class LogWriter(object):
[1.3913371800587777, 1.4684674577930312, 1.5214136352476377]] [1.3913371800587777, 1.4684674577930312, 1.5214136352476377]]
labels = ["label_1", "label_2", "label_3", "label_4", "label_5"] labels = ["label_1", "label_2", "label_3", "label_4", "label_5"]
# or like this
# labels = [["label_1", "label_2", "label_3", "label_4", "label_5"]]
writer.add_embeddings(tag='default',
labels=labels,
vectors=hot_vectors,
walltime=round(time.time() * 1000))
writer.add_embedding(labels=labels, vectors=hot_vectors, Example 2:
walltime=round(time.time() * 1000)) hot_vectors = [
[1.3561076367500755, 1.3116267195134017, 1.6785401875616097],
[1.1039614644440658, 1.8891609992484688, 1.32030488587171],
[1.9924524852447711, 1.9358920727142739, 1.2124401279391606],
[1.4129542689796446, 1.7372166387197474, 1.7317806077076527],
[1.3913371800587777, 1.4684674577930312, 1.5214136352476377]]
labels = [["label_a_1", "label_a_2", "label_a_3", "label_a_4", "label_a_5"],
["label_b_1", "label_b_2", "label_b_3", "label_b_4", "label_b_5"]]
labels_meta = ["label_a", "label_2"]
writer.add_embeddings(tag='default',
labels=labels,
labels_meta=labels_meta,
vectors=hot_vectors,
walltime=round(time.time() * 1000))
""" """
if '%' in tag: if '%' in tag:
raise RuntimeError("% can't appear in tag!") raise RuntimeError("% can't appear in tag!")
...@@ -217,12 +241,17 @@ class LogWriter(object): ...@@ -217,12 +241,17 @@ class LogWriter(object):
hot_vectors = hot_vectors.tolist() hot_vectors = hot_vectors.tolist()
if isinstance(labels, np.ndarray): if isinstance(labels, np.ndarray):
labels = labels.tolist() labels = labels.tolist()
if isinstance(labels[0], list) and not labels_meta:
labels_meta = ["label_%d" % i for i in range(len(labels))]
step = 0 step = 0
walltime = round(time.time() * 1000) if walltime is None else walltime walltime = round(time.time() * 1000) if walltime is None else walltime
self._get_file_writer().add_record( self._get_file_writer().add_record(
embedding( embedding(
tag=tag, tag=tag,
labels=labels, labels=labels,
labels_meta=labels_meta,
hot_vectors=hot_vectors, hot_vectors=hot_vectors,
step=step, step=step,
walltime=walltime)) walltime=walltime))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册