未验证 提交 23d63937 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

change param of add_embeddings

上级 55a8e237
......@@ -19,34 +19,34 @@ from visualdl import LogWriter
if __name__ == '__main__':
hot_vectors = [
mat = [
[1.3561076367500755, 1.3116267195134017, 1.6785401875616097],
[1.1039614644440658, 1.8891609992484688, 1.32030488587171],
[1.9924524852447711, 1.9358920727142739, 1.2124401279391606],
[1.4129542689796446, 1.7372166387197474, 1.7317806077076527],
[1.3913371800587777, 1.4684674577930312, 1.5214136352476377]]
labels = ["label_1", "label_2", "label_3", "label_4", "label_5"]
metadata = ["label_1", "label_2", "label_3", "label_4", "label_5"]
with LogWriter(logdir="./log/high_dimensional_test/train") as writer:
writer.add_embeddings(tag='default',
labels=labels,
hot_vectors=hot_vectors)
mat=mat,
metadata=metadata)
"""
'''
# You can code as follow if use multi-dimensional labels.
hot_vectors = [
mat = [
[1.3561076367500755, 1.3116267195134017, 1.6785401875616097],
[1.1039614644440658, 1.8891609992484688, 1.32030488587171],
[1.9924524852447711, 1.9358920727142739, 1.2124401279391606],
[1.4129542689796446, 1.7372166387197474, 1.7317806077076527],
[1.3913371800587777, 1.4684674577930312, 1.5214136352476377]]
labels = [["label_a_1", "label_a_2", "label_a_3", "label_a_4", "label_a_5"],
metadata = [["label_a_1", "label_a_2", "label_a_3", "label_a_4", "label_a_5"],
["label_b_1", "label_b_2", "label_b_3", "label_b_4", "label_b_5"]]
labels_meta = ["label_a", "label_b"]
metadata_header = ["label_a", "label_b"]
with LogWriter(logdir="./log/high_dimensional_test/train") as writer:
writer.add_embeddings(tag='default',
labels=labels,
labels_meta=labels_meta,
hot_vectors=hot_vectors)
"""
mat=mat,
metadata=metadata,
metadata_header=metadata_header)
'''
......@@ -17,6 +17,7 @@ import os
import time
import numpy as np
from visualdl.writer.record_writer import RecordFileWriter
from visualdl.server.log import logger
from visualdl.utils.img_util import merge_images
from visualdl.component.base_component import scalar, image, embedding, audio, \
histogram, pr_curve, roc_curve, meta_data, text
......@@ -241,71 +242,93 @@ class LogWriter(object):
walltime=walltime,
dataformats=dataformats)
def add_embeddings(self, tag, labels, hot_vectors, labels_meta=None, walltime=None):
def add_embeddings(self, tag, mat=None, metadata=None,
metadata_header=None, walltime=None, labels=None,
hot_vectors=None, labels_meta=None):
"""Add embeddings to vdl record file.
Args:
tag (string): Data identifier
labels (numpy.array or list): A 1D or 2D matrix of labels
hot_vectors (numpy.array or list): A matrix which each row is
mat (numpy.array or list): A matrix which each row is
feature of labels.
labels_meta (numpy.array or list): Meta data of labels.
metadata (numpy.array or list): A 1D or 2D matrix of labels
metadata_header (numpy.array or list): Meta data of labels.
walltime (int): Wall time of embeddings.
labels (numpy.array or list): Obsolete parameter, use `metadata` to
replace it.
hot_vectors (numpy.array or list): Obsolete parameter, use `mat` to
replace it.
labels_meta (numpy.array or list): Obsolete parameter, use
`metadata_header` to replace it.
Example 1:
hot_vectors = [
mat = [
[1.3561076367500755, 1.3116267195134017, 1.6785401875616097],
[1.1039614644440658, 1.8891609992484688, 1.32030488587171],
[1.9924524852447711, 1.9358920727142739, 1.2124401279391606],
[1.4129542689796446, 1.7372166387197474, 1.7317806077076527],
[1.3913371800587777, 1.4684674577930312, 1.5214136352476377]]
labels = ["label_1", "label_2", "label_3", "label_4", "label_5"]
metadata = ["label_1", "label_2", "label_3", "label_4", "label_5"]
# or like this
# labels = [["label_1", "label_2", "label_3", "label_4", "label_5"]]
# metadata = [["label_1", "label_2", "label_3", "label_4", "label_5"]]
writer.add_embeddings(tag='default',
labels=labels,
vectors=hot_vectors,
metadata=metadata,
mat=mat,
walltime=round(time.time() * 1000))
Example 2:
hot_vectors = [
mat = [
[1.3561076367500755, 1.3116267195134017, 1.6785401875616097],
[1.1039614644440658, 1.8891609992484688, 1.32030488587171],
[1.9924524852447711, 1.9358920727142739, 1.2124401279391606],
[1.4129542689796446, 1.7372166387197474, 1.7317806077076527],
[1.3913371800587777, 1.4684674577930312, 1.5214136352476377]]
labels = [["label_a_1", "label_a_2", "label_a_3", "label_a_4", "label_a_5"],
metadata = [["label_a_1", "label_a_2", "label_a_3", "label_a_4", "label_a_5"],
["label_b_1", "label_b_2", "label_b_3", "label_b_4", "label_b_5"]]
labels_meta = ["label_a", "label_2"]
metadata_header = ["label_a", "label_2"]
writer.add_embeddings(tag='default',
labels=labels,
labels_meta=labels_meta,
vectors=hot_vectors,
metadata=metadata,
metadata_header=metadata_header,
mat=mat,
walltime=round(time.time() * 1000))
"""
if '%' in tag:
raise RuntimeError("% can't appear in tag!")
if isinstance(hot_vectors, np.ndarray):
hot_vectors = hot_vectors.tolist()
if isinstance(labels, np.ndarray):
labels = labels.tolist()
if isinstance(labels[0], list) and not labels_meta:
labels_meta = ["label_%d" % i for i in range(len(labels))]
if not mat and hot_vectors:
mat = hot_vectors
logger.warning('Parameter `hot_vectors` in function '
'`add_embeddings` will be deprecated in '
'future, use `mat` instead.')
if not metadata and labels:
metadata = labels
logger.warning(
'Parameter `labels` in function `add_embeddings` will be '
'deprecated in future, use `metadata` instead.')
if not metadata_header and labels_meta:
metadata_header = labels_meta
logger.warning(
'Parameter `labels_meta` in function `add_embeddings` will be'
' deprecated in future, use `metadata_header` instead.')
if isinstance(mat, np.ndarray):
mat = mat.tolist()
if isinstance(metadata, np.ndarray):
metadata = metadata.tolist()
if isinstance(metadata[0], list) and not metadata_header:
metadata_header = ["label_%d" % i for i in range(len(metadata))]
step = 0
walltime = round(time.time() * 1000) if walltime is None else walltime
self._get_file_writer().add_record(
embedding(
tag=tag,
labels=labels,
labels_meta=labels_meta,
hot_vectors=hot_vectors,
labels=metadata,
labels_meta=metadata_header,
hot_vectors=mat,
step=step,
walltime=walltime))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册