未验证 提交 1d65b86d 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

add backend api for hidi (#868)

* add backend code for hidi v2
上级 53a31e3b
...@@ -146,6 +146,20 @@ class Api(object): ...@@ -146,6 +146,20 @@ class Api(object):
key = os.path.join('data/plugin/embeddings/embeddings', run, str(dimension), reduction) key = os.path.join('data/plugin/embeddings/embeddings', run, str(dimension), reduction)
return self._get_with_retry(key, lib.get_embeddings, run, tag, reduction, dimension) return self._get_with_retry(key, lib.get_embeddings, run, tag, reduction, dimension)
@result()
def embedding_list(self):
return self._get_with_retry('data/plugin/embeddings/list', lib.get_embeddings_list)
@result('text/tab-separated-values')
def embedding_metadata(self, name):
key = os.path.join('data/plugin/embeddings/metadata', name)
return self._get_with_retry(key, lib.get_embedding_labels, name)
@result('application/octet-stream')
def embedding_tensor(self, name):
key = os.path.join('data/plugin/embeddings/tensor', name)
return self._get_with_retry(key, lib.get_embedding_tensors, name)
@result() @result()
def histogram_tags(self): def histogram_tags(self):
return self._get_with_retry('data/plugin/histogram/tags', lib.get_histogram_tags) return self._get_with_retry('data/plugin/histogram/tags', lib.get_histogram_tags)
...@@ -190,6 +204,9 @@ def create_api_call(logdir, model, cache_timeout): ...@@ -190,6 +204,9 @@ def create_api_call(logdir, model, cache_timeout):
'audio/list': (api.audio_list, ['run', 'tag']), 'audio/list': (api.audio_list, ['run', 'tag']),
'audio/audio': (api.audio_audio, ['run', 'tag', 'index']), 'audio/audio': (api.audio_audio, ['run', 'tag', 'index']),
'embedding/embedding': (api.embedding_embedding, ['run', 'tag', 'reduction', 'dimension']), 'embedding/embedding': (api.embedding_embedding, ['run', 'tag', 'reduction', 'dimension']),
'embedding/list': (api.embedding_list, []),
'embedding/tensor': (api.embedding_tensor, ['name']),
'embedding/metadata': (api.embedding_metadata, ['name']),
'histogram/list': (api.histogram_list, ['run', 'tag']), 'histogram/list': (api.histogram_list, ['run', 'tag']),
'graph/graph': (api.graph_graph, []), 'graph/graph': (api.graph_graph, []),
'pr-curve/list': (api.pr_curves_pr_curve, ['run', 'tag']), 'pr-curve/list': (api.pr_curves_pr_curve, ['run', 'tag']),
......
...@@ -21,7 +21,7 @@ DEFAULT_PLUGIN_MAXSIZE = { ...@@ -21,7 +21,7 @@ DEFAULT_PLUGIN_MAXSIZE = {
"scalar": 1000, "scalar": 1000,
"image": 10, "image": 10,
"histogram": 100, "histogram": 100,
"embeddings": 50000, "embeddings": 50000000,
"audio": 10, "audio": 10,
"pr_curve": 300, "pr_curve": 300,
"meta_data": 100 "meta_data": 100
......
...@@ -17,6 +17,8 @@ from __future__ import absolute_import ...@@ -17,6 +17,8 @@ from __future__ import absolute_import
import sys import sys
import time import time
import os import os
import io
import csv
from functools import partial from functools import partial
import numpy as np import numpy as np
from visualdl.server.log import logger from visualdl.server.log import logger
...@@ -27,6 +29,8 @@ from visualdl.component import components ...@@ -27,6 +29,8 @@ from visualdl.component import components
MODIFY_PREFIX = {} MODIFY_PREFIX = {}
MODIFIED_RUNS = [] MODIFIED_RUNS = []
EMBEDDING_NAME = {}
embedding_names = []
def s2ms(timestamp): def s2ms(timestamp):
...@@ -196,6 +200,56 @@ def get_pr_curve_step(log_reader, run, tag=None): ...@@ -196,6 +200,56 @@ def get_pr_curve_step(log_reader, run, tag=None):
return results return results
def get_embeddings_list(log_reader):
run2tag = get_logs(log_reader, 'embeddings')
for run, _tags in zip(run2tag['runs'], run2tag['tags']):
for tag in _tags:
name = path = os.path.join(run, tag)
if name in EMBEDDING_NAME:
return embedding_names
EMBEDDING_NAME.update({name: {'run': run, 'tag': tag}})
records = log_reader.data_manager.get_reservoir("embeddings").get_items(
run, decode_tag(tag))
row_len = len(records[0].embeddings.embeddings)
col_len = len(records[0].embeddings.embeddings[0].vectors)
shape = [row_len, col_len]
embedding_names.append({'name': name, 'shape': shape, 'path': path})
return embedding_names
def get_embedding_labels(log_reader, name):
run = EMBEDDING_NAME[name]['run']
tag = EMBEDDING_NAME[name]['tag']
log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("embeddings").get_items(
run, decode_tag(tag))
labels = []
for item in records[0].embeddings.embeddings:
labels.append([item.label])
with io.StringIO() as fp:
csv_writer = csv.writer(fp, delimiter='\t')
csv_writer.writerows(labels)
labels = fp.getvalue()
# labels = "\n".join(str(i) for i in labels)
return labels
def get_embedding_tensors(log_reader, name):
run = EMBEDDING_NAME[name]['run']
tag = EMBEDDING_NAME[name]['tag']
log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("embeddings").get_items(
run, decode_tag(tag))
vectors = []
for item in records[0].embeddings.embeddings:
vectors.append(item.vectors)
vectors = np.array(vectors).flatten().astype(np.float32).tobytes()
return vectors
def get_embeddings(log_reader, run, tag, reduction, dimension=2): def get_embeddings(log_reader, run, tag, reduction, dimension=2):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data() log_reader.load_new_data()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册