未验证 提交 23d63937 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

change param of add_embeddings

上级 55a8e237
...@@ -19,34 +19,34 @@ from visualdl import LogWriter ...@@ -19,34 +19,34 @@ from visualdl import LogWriter
if __name__ == '__main__': if __name__ == '__main__':
hot_vectors = [ mat = [
[1.3561076367500755, 1.3116267195134017, 1.6785401875616097], [1.3561076367500755, 1.3116267195134017, 1.6785401875616097],
[1.1039614644440658, 1.8891609992484688, 1.32030488587171], [1.1039614644440658, 1.8891609992484688, 1.32030488587171],
[1.9924524852447711, 1.9358920727142739, 1.2124401279391606], [1.9924524852447711, 1.9358920727142739, 1.2124401279391606],
[1.4129542689796446, 1.7372166387197474, 1.7317806077076527], [1.4129542689796446, 1.7372166387197474, 1.7317806077076527],
[1.3913371800587777, 1.4684674577930312, 1.5214136352476377]] [1.3913371800587777, 1.4684674577930312, 1.5214136352476377]]
labels = ["label_1", "label_2", "label_3", "label_4", "label_5"] metadata = ["label_1", "label_2", "label_3", "label_4", "label_5"]
with LogWriter(logdir="./log/high_dimensional_test/train") as writer: with LogWriter(logdir="./log/high_dimensional_test/train") as writer:
writer.add_embeddings(tag='default', writer.add_embeddings(tag='default',
labels=labels, mat=mat,
hot_vectors=hot_vectors) metadata=metadata)
""" '''
# You can code as follow if use multi-dimensional labels. # You can code as follow if use multi-dimensional labels.
hot_vectors = [ mat = [
[1.3561076367500755, 1.3116267195134017, 1.6785401875616097], [1.3561076367500755, 1.3116267195134017, 1.6785401875616097],
[1.1039614644440658, 1.8891609992484688, 1.32030488587171], [1.1039614644440658, 1.8891609992484688, 1.32030488587171],
[1.9924524852447711, 1.9358920727142739, 1.2124401279391606], [1.9924524852447711, 1.9358920727142739, 1.2124401279391606],
[1.4129542689796446, 1.7372166387197474, 1.7317806077076527], [1.4129542689796446, 1.7372166387197474, 1.7317806077076527],
[1.3913371800587777, 1.4684674577930312, 1.5214136352476377]] [1.3913371800587777, 1.4684674577930312, 1.5214136352476377]]
labels = [["label_a_1", "label_a_2", "label_a_3", "label_a_4", "label_a_5"], metadata = [["label_a_1", "label_a_2", "label_a_3", "label_a_4", "label_a_5"],
["label_b_1", "label_b_2", "label_b_3", "label_b_4", "label_b_5"]] ["label_b_1", "label_b_2", "label_b_3", "label_b_4", "label_b_5"]]
labels_meta = ["label_a", "label_b"] metadata_header = ["label_a", "label_b"]
with LogWriter(logdir="./log/high_dimensional_test/train") as writer: with LogWriter(logdir="./log/high_dimensional_test/train") as writer:
writer.add_embeddings(tag='default', writer.add_embeddings(tag='default',
labels=labels, mat=mat,
labels_meta=labels_meta, metadata=metadata,
hot_vectors=hot_vectors) metadata_header=metadata_header)
""" '''
...@@ -17,6 +17,7 @@ import os ...@@ -17,6 +17,7 @@ import os
import time import time
import numpy as np import numpy as np
from visualdl.writer.record_writer import RecordFileWriter from visualdl.writer.record_writer import RecordFileWriter
from visualdl.server.log import logger
from visualdl.utils.img_util import merge_images from visualdl.utils.img_util import merge_images
from visualdl.component.base_component import scalar, image, embedding, audio, \ from visualdl.component.base_component import scalar, image, embedding, audio, \
histogram, pr_curve, roc_curve, meta_data, text histogram, pr_curve, roc_curve, meta_data, text
...@@ -241,71 +242,93 @@ class LogWriter(object): ...@@ -241,71 +242,93 @@ class LogWriter(object):
walltime=walltime, walltime=walltime,
dataformats=dataformats) dataformats=dataformats)
def add_embeddings(self, tag, labels, hot_vectors, labels_meta=None, walltime=None): def add_embeddings(self, tag, mat=None, metadata=None,
metadata_header=None, walltime=None, labels=None,
hot_vectors=None, labels_meta=None):
"""Add embeddings to vdl record file. """Add embeddings to vdl record file.
Args: Args:
tag (string): Data identifier tag (string): Data identifier
labels (numpy.array or list): A 1D or 2D matrix of labels mat (numpy.array or list): A matrix which each row is
hot_vectors (numpy.array or list): A matrix which each row is
feature of labels. feature of labels.
labels_meta (numpy.array or list): Meta data of labels. metadata (numpy.array or list): A 1D or 2D matrix of labels
metadata_header (numpy.array or list): Meta data of labels.
walltime (int): Wall time of embeddings. walltime (int): Wall time of embeddings.
labels (numpy.array or list): Obsolete parameter, use `metadata` to
replace it.
hot_vectors (numpy.array or list): Obsolete parameter, use `mat` to
replace it.
labels_meta (numpy.array or list): Obsolete parameter, use
`metadata_header` to replace it.
Example 1: Example 1:
hot_vectors = [ mat = [
[1.3561076367500755, 1.3116267195134017, 1.6785401875616097], [1.3561076367500755, 1.3116267195134017, 1.6785401875616097],
[1.1039614644440658, 1.8891609992484688, 1.32030488587171], [1.1039614644440658, 1.8891609992484688, 1.32030488587171],
[1.9924524852447711, 1.9358920727142739, 1.2124401279391606], [1.9924524852447711, 1.9358920727142739, 1.2124401279391606],
[1.4129542689796446, 1.7372166387197474, 1.7317806077076527], [1.4129542689796446, 1.7372166387197474, 1.7317806077076527],
[1.3913371800587777, 1.4684674577930312, 1.5214136352476377]] [1.3913371800587777, 1.4684674577930312, 1.5214136352476377]]
labels = ["label_1", "label_2", "label_3", "label_4", "label_5"] metadata = ["label_1", "label_2", "label_3", "label_4", "label_5"]
# or like this # or like this
# labels = [["label_1", "label_2", "label_3", "label_4", "label_5"]] # metadata = [["label_1", "label_2", "label_3", "label_4", "label_5"]]
writer.add_embeddings(tag='default', writer.add_embeddings(tag='default',
labels=labels, metadata=metadata,
vectors=hot_vectors, mat=mat,
walltime=round(time.time() * 1000)) walltime=round(time.time() * 1000))
Example 2: Example 2:
hot_vectors = [ mat = [
[1.3561076367500755, 1.3116267195134017, 1.6785401875616097], [1.3561076367500755, 1.3116267195134017, 1.6785401875616097],
[1.1039614644440658, 1.8891609992484688, 1.32030488587171], [1.1039614644440658, 1.8891609992484688, 1.32030488587171],
[1.9924524852447711, 1.9358920727142739, 1.2124401279391606], [1.9924524852447711, 1.9358920727142739, 1.2124401279391606],
[1.4129542689796446, 1.7372166387197474, 1.7317806077076527], [1.4129542689796446, 1.7372166387197474, 1.7317806077076527],
[1.3913371800587777, 1.4684674577930312, 1.5214136352476377]] [1.3913371800587777, 1.4684674577930312, 1.5214136352476377]]
labels = [["label_a_1", "label_a_2", "label_a_3", "label_a_4", "label_a_5"], metadata = [["label_a_1", "label_a_2", "label_a_3", "label_a_4", "label_a_5"],
["label_b_1", "label_b_2", "label_b_3", "label_b_4", "label_b_5"]] ["label_b_1", "label_b_2", "label_b_3", "label_b_4", "label_b_5"]]
labels_meta = ["label_a", "label_2"] metadata_header = ["label_a", "label_2"]
writer.add_embeddings(tag='default', writer.add_embeddings(tag='default',
labels=labels, metadata=metadata,
labels_meta=labels_meta, metadata_header=metadata_header,
vectors=hot_vectors, mat=mat,
walltime=round(time.time() * 1000)) walltime=round(time.time() * 1000))
""" """
if '%' in tag: if '%' in tag:
raise RuntimeError("% can't appear in tag!") raise RuntimeError("% can't appear in tag!")
if isinstance(hot_vectors, np.ndarray): if not mat and hot_vectors:
hot_vectors = hot_vectors.tolist() mat = hot_vectors
if isinstance(labels, np.ndarray): logger.warning('Parameter `hot_vectors` in function '
labels = labels.tolist() '`add_embeddings` will be deprecated in '
'future, use `mat` instead.')
if isinstance(labels[0], list) and not labels_meta: if not metadata and labels:
labels_meta = ["label_%d" % i for i in range(len(labels))] metadata = labels
logger.warning(
'Parameter `labels` in function `add_embeddings` will be '
'deprecated in future, use `metadata` instead.')
if not metadata_header and labels_meta:
metadata_header = labels_meta
logger.warning(
'Parameter `labels_meta` in function `add_embeddings` will be'
' deprecated in future, use `metadata_header` instead.')
if isinstance(mat, np.ndarray):
mat = mat.tolist()
if isinstance(metadata, np.ndarray):
metadata = metadata.tolist()
if isinstance(metadata[0], list) and not metadata_header:
metadata_header = ["label_%d" % i for i in range(len(metadata))]
step = 0 step = 0
walltime = round(time.time() * 1000) if walltime is None else walltime walltime = round(time.time() * 1000) if walltime is None else walltime
self._get_file_writer().add_record( self._get_file_writer().add_record(
embedding( embedding(
tag=tag, tag=tag,
labels=labels, labels=metadata,
labels_meta=labels_meta, labels_meta=metadata_header,
hot_vectors=hot_vectors, hot_vectors=mat,
step=step, step=step,
walltime=walltime)) walltime=walltime))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册