未验证 提交 2b9e8c65 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

Add histogram, add hdfs. (#657)

* Add histogram, add hdfs.
上级 b61595ea
# Copyright (c) 2020 VisualDL Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
# coding=utf-8
from visualdl import LogWriter
import numpy as np
if __name__ == '__main__':
    # Demo: write five histograms under one tag, each built from the
    # integers 0..999 shifted by the step number.
    base_values = np.arange(0, 1000)
    with LogWriter(logdir="./log/histogram_test/train") as writer:
        for step in range(5):
            shifted_values = base_values + step
            writer.add_histogram(
                tag='default1',
                values=shifted_values,
                step=step,
                buckets=10)
...@@ -8,3 +8,4 @@ Flask-Babel >= 1.0.0 ...@@ -8,3 +8,4 @@ Flask-Babel >= 1.0.0
six >= 1.14.0 six >= 1.14.0
protobuf >= 3.1.0 protobuf >= 3.1.0
opencv-python opencv-python
hdfs
...@@ -24,5 +24,11 @@ components = { ...@@ -24,5 +24,11 @@ components = {
}, },
"audio": { "audio": {
"enabled": False "enabled": False
},
"histogram": {
"enabled": False
},
"graph": {
"enabled": False
} }
} }
...@@ -131,3 +131,11 @@ def audio(tag, audio_array, sample_rate, step, walltime): ...@@ -131,3 +131,11 @@ def audio(tag, audio_array, sample_rate, step, walltime):
return Record(values=[ return Record(values=[
Record.Value(id=step, tag=tag, timestamp=walltime, audio=audio_data) Record.Value(id=step, tag=tag, timestamp=walltime, audio=audio_data)
]) ])
def histogram(tag, hist, bin_edges, step, walltime):
    """Package a histogram into a record.

    Args:
        tag (string): Data identifier, stored as the value's `tag`.
        hist: Per-bucket values, forwarded to `Record.Histogram.hist`.
        bin_edges: Bucket boundaries, forwarded to
            `Record.Histogram.bin_edges`.
        step (int): Step of histogram, stored as the value's `id`.
        walltime (int): Wall time of histogram, stored as the value's
            `timestamp`.

    Return:
        A `Record` whose `values` list holds exactly one `Record.Value`
        carrying the histogram.
    """
    # Use a local name that does not shadow this function.
    histogram_data = Record.Histogram(hist=hist, bin_edges=bin_edges)
    value = Record.Value(
        id=step, tag=tag, timestamp=walltime, histogram=histogram_data)
    return Record(values=[value])
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import os import os
import tempfile import tempfile
import hdfs
# Note: Some codes here refer to TensorBoardX. # Note: Some codes here refer to TensorBoardX.
# A good default block size depends on the system in question. # A good default block size depends on the system in question.
...@@ -90,6 +91,51 @@ class LocalFileSystem(object): ...@@ -90,6 +91,51 @@ class LocalFileSystem(object):
default_file_factory.register_filesystem("", LocalFileSystem()) default_file_factory.register_filesystem("", LocalFileSystem())
class HDFileSystem(object):
    """File system backend that serves `hdfs://...` paths via an HdfsCLI client.

    Every public method receives the full path including the `hdfs://`
    scheme; the scheme is removed before the path is handed to the client.
    """

    # Scheme prefix carried by every path routed to this backend.
    _SCHEME = 'hdfs://'

    def __init__(self):
        # The client is looked up under the 'dev' alias of the HdfsCLI
        # configuration.
        self.cli = hdfs.config.Config().get_client('dev')

    @classmethod
    def _strip_scheme(cls, path):
        """Drop the leading `hdfs://` (first 7 characters) from `path`."""
        return path[len(cls._SCHEME):]

    def exists(self, path):
        """Return True if the client reports a status for `path`."""
        # With strict=False the client returns None for a missing path
        # instead of raising.
        status = self.cli.status(
            hdfs_path=self._strip_scheme(path), strict=False)
        return status is not None

    def makedirs(self, path):
        """Create the directory `path` on HDFS."""
        self.cli.makedirs(hdfs_path=self._strip_scheme(path))

    @staticmethod
    def join(path, *paths):
        """Join path components with `os.path.join`."""
        return os.path.join(path, *paths)

    def read(self, filename, binary_mode=False, size=0, continue_from=None):
        """Read `filename` from a given offset up to its current end.

        Args:
            filename (string): Full `hdfs://` path of the file.
            binary_mode (bool): Return `bytes` when True, otherwise text
                decoded as UTF-8.
            size (int): Not used by this implementation.
            continue_from (dict): Token returned by a previous call; its
                `last_offset` entry is the byte offset to resume from.

        Return:
            Tuple `(data, continue_from_token)` where the token is
            `{"last_offset": <byte offset after the returned data>}`.
        """
        offset = 0
        if continue_from is not None:
            offset = continue_from.get("last_offset", 0)

        # Always fetch raw bytes: the resume offset is a byte position, so
        # it must not be derived from the length of decoded text (the two
        # differ as soon as the content holds a multi-byte character).
        with self.cli.read(hdfs_path=self._strip_scheme(filename),
                           offset=offset,
                           encoding=None) as reader:
            raw = reader.read()

        if binary_mode:
            data = raw
            consumed = len(raw)
        else:
            import codecs

            # Incremental decoding tolerates a character that is cut in
            # half at the end of the data; those bytes are left unconsumed
            # so the next call reads them again.
            decoder = codecs.getincrementaldecoder("utf-8")()
            data = decoder.decode(raw)
            pending, _ = decoder.getstate()
            consumed = len(raw) - len(pending)

        continue_from_token = {"last_offset": offset + consumed}
        return data, continue_from_token

    def append(self, filename, file_content, binary_mode=False):
        """Append `file_content` to `filename` (`binary_mode` is unused)."""
        self.cli.write(hdfs_path=self._strip_scheme(filename),
                       data=file_content,
                       append=True)

    def write(self, filename, file_content, binary_mode=False):
        """Write `file_content` to `filename` (`binary_mode` is unused)."""
        self.cli.write(hdfs_path=self._strip_scheme(filename),
                       data=file_content)

    def walk(self, dir):
        """Walk `dir`, yielding `[root, dirs, files]` lists.

        The `hdfs://` scheme is put back in front of every yielded root.
        """
        walks = self.cli.walk(hdfs_path=self._strip_scheme(dir))
        return ([self._SCHEME + root, dirs, files]
                for root, dirs, files in walks)
# Register the HDFS backend for the "hdfs" scheme at import time.
# Constructing HDFileSystem looks up an HdfsCLI client, which raises
# HdfsError when the configuration is unusable; in that case only a hint is
# printed so that local-file usage keeps working.
try:
    default_file_factory.register_filesystem("hdfs", HDFileSystem())
except hdfs.util.HdfsError:
    print("HDFS initialization failed, please check if .hdfscli.cfg exists.")
class BFile(object): class BFile(object):
def __init__(self, filename, mode): def __init__(self, filename, mode):
if mode not in ('r', 'rb', 'br', 'w', 'wb', 'bw'): if mode not in ('r', 'rb', 'br', 'w', 'wb', 'bw'):
......
...@@ -29,6 +29,11 @@ message Record { ...@@ -29,6 +29,11 @@ message Record {
bytes encoded_vectors = 2; bytes encoded_vectors = 2;
} }
// Histogram payload carried by Record.Value: bucket values and the bucket
// boundaries they belong to.
message Histogram {
// One value per bucket.
repeated double hist = 1 [packed = true];
// Bucket boundaries. NOTE(review): readers index bin_edges[i + 1] for every
// hist[i], so this presumably holds one more entry than hist - confirm.
repeated double bin_edges = 2 [packed = true];
};
message Value { message Value {
int64 id = 1; int64 id = 1;
string tag = 2; string tag = 2;
...@@ -38,6 +43,7 @@ message Record { ...@@ -38,6 +43,7 @@ message Record {
Image image = 5; Image image = 5;
Audio audio = 6; Audio audio = 6;
Embeddings embeddings = 7; Embeddings embeddings = 7;
Histogram histogram = 8;
} }
} }
......
...@@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='visualdl', package='visualdl',
syntax='proto3', syntax='proto3',
serialized_options=None, serialized_options=None,
serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xdf\x04\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1a<\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\xd6\x01\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x42\x0b\n\tone_valueb\x06proto3' serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xc6\x05\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1a<\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 
\x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1a\x87\x02\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x42\x0b\n\tone_valueb\x06proto3'
) )
...@@ -216,6 +216,43 @@ _RECORD_BYTES_EMBEDDINGS = _descriptor.Descriptor( ...@@ -216,6 +216,43 @@ _RECORD_BYTES_EMBEDDINGS = _descriptor.Descriptor(
serialized_end=417, serialized_end=417,
) )
# Descriptor for the nested message visualdl.Record.Histogram.
# NOTE(review): this looks like protoc-generated output (see the
# protoc_insertion_point markers in this module); regenerate from the .proto
# definition rather than editing by hand.
_RECORD_HISTOGRAM = _descriptor.Descriptor(
name='Histogram',
full_name='visualdl.Record.Histogram',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
# Field 1 `hist`: declared as `repeated double ... [packed = true]` in the
# .proto definition.
_descriptor.FieldDescriptor(
name='hist', full_name='visualdl.Record.Histogram.hist', index=0,
number=1, type=1, cpp_type=5, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=b'\020\001', file=DESCRIPTOR),
# Field 2 `bin_edges`: same declaration shape as `hist`.
_descriptor.FieldDescriptor(
name='bin_edges', full_name='visualdl.Record.Histogram.bin_edges', index=1,
number=2, type=1, cpp_type=5, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=b'\020\001', file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
# Byte range of this message inside DESCRIPTOR's serialized_pb.
serialized_start=419,
serialized_end=471,
)
_RECORD_VALUE = _descriptor.Descriptor( _RECORD_VALUE = _descriptor.Descriptor(
name='Value', name='Value',
full_name='visualdl.Record.Value', full_name='visualdl.Record.Value',
...@@ -272,6 +309,13 @@ _RECORD_VALUE = _descriptor.Descriptor( ...@@ -272,6 +309,13 @@ _RECORD_VALUE = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='histogram', full_name='visualdl.Record.Value.histogram', index=7,
number=8, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
], ],
extensions=[ extensions=[
], ],
...@@ -287,8 +331,8 @@ _RECORD_VALUE = _descriptor.Descriptor( ...@@ -287,8 +331,8 @@ _RECORD_VALUE = _descriptor.Descriptor(
name='one_value', full_name='visualdl.Record.Value.one_value', name='one_value', full_name='visualdl.Record.Value.one_value',
index=0, containing_type=None, fields=[]), index=0, containing_type=None, fields=[]),
], ],
serialized_start=420, serialized_start=474,
serialized_end=634, serialized_end=737,
) )
_RECORD = _descriptor.Descriptor( _RECORD = _descriptor.Descriptor(
...@@ -308,7 +352,7 @@ _RECORD = _descriptor.Descriptor( ...@@ -308,7 +352,7 @@ _RECORD = _descriptor.Descriptor(
], ],
extensions=[ extensions=[
], ],
nested_types=[_RECORD_IMAGE, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_VALUE, ], nested_types=[_RECORD_IMAGE, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_HISTOGRAM, _RECORD_VALUE, ],
enum_types=[ enum_types=[
], ],
serialized_options=None, serialized_options=None,
...@@ -318,7 +362,7 @@ _RECORD = _descriptor.Descriptor( ...@@ -318,7 +362,7 @@ _RECORD = _descriptor.Descriptor(
oneofs=[ oneofs=[
], ],
serialized_start=27, serialized_start=27,
serialized_end=634, serialized_end=737,
) )
_RECORD_IMAGE.containing_type = _RECORD _RECORD_IMAGE.containing_type = _RECORD
...@@ -327,9 +371,11 @@ _RECORD_EMBEDDING.containing_type = _RECORD ...@@ -327,9 +371,11 @@ _RECORD_EMBEDDING.containing_type = _RECORD
_RECORD_EMBEDDINGS.fields_by_name['embeddings'].message_type = _RECORD_EMBEDDING _RECORD_EMBEDDINGS.fields_by_name['embeddings'].message_type = _RECORD_EMBEDDING
_RECORD_EMBEDDINGS.containing_type = _RECORD _RECORD_EMBEDDINGS.containing_type = _RECORD
_RECORD_BYTES_EMBEDDINGS.containing_type = _RECORD _RECORD_BYTES_EMBEDDINGS.containing_type = _RECORD
_RECORD_HISTOGRAM.containing_type = _RECORD
_RECORD_VALUE.fields_by_name['image'].message_type = _RECORD_IMAGE _RECORD_VALUE.fields_by_name['image'].message_type = _RECORD_IMAGE
_RECORD_VALUE.fields_by_name['audio'].message_type = _RECORD_AUDIO _RECORD_VALUE.fields_by_name['audio'].message_type = _RECORD_AUDIO
_RECORD_VALUE.fields_by_name['embeddings'].message_type = _RECORD_EMBEDDINGS _RECORD_VALUE.fields_by_name['embeddings'].message_type = _RECORD_EMBEDDINGS
_RECORD_VALUE.fields_by_name['histogram'].message_type = _RECORD_HISTOGRAM
_RECORD_VALUE.containing_type = _RECORD _RECORD_VALUE.containing_type = _RECORD
_RECORD_VALUE.oneofs_by_name['one_value'].fields.append( _RECORD_VALUE.oneofs_by_name['one_value'].fields.append(
_RECORD_VALUE.fields_by_name['value']) _RECORD_VALUE.fields_by_name['value'])
...@@ -343,6 +389,9 @@ _RECORD_VALUE.fields_by_name['audio'].containing_oneof = _RECORD_VALUE.oneofs_by ...@@ -343,6 +389,9 @@ _RECORD_VALUE.fields_by_name['audio'].containing_oneof = _RECORD_VALUE.oneofs_by
_RECORD_VALUE.oneofs_by_name['one_value'].fields.append( _RECORD_VALUE.oneofs_by_name['one_value'].fields.append(
_RECORD_VALUE.fields_by_name['embeddings']) _RECORD_VALUE.fields_by_name['embeddings'])
_RECORD_VALUE.fields_by_name['embeddings'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value'] _RECORD_VALUE.fields_by_name['embeddings'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value']
_RECORD_VALUE.oneofs_by_name['one_value'].fields.append(
_RECORD_VALUE.fields_by_name['histogram'])
_RECORD_VALUE.fields_by_name['histogram'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value']
_RECORD.fields_by_name['values'].message_type = _RECORD_VALUE _RECORD.fields_by_name['values'].message_type = _RECORD_VALUE
DESCRIPTOR.message_types_by_name['Record'] = _RECORD DESCRIPTOR.message_types_by_name['Record'] = _RECORD
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
...@@ -384,6 +433,13 @@ Record = _reflection.GeneratedProtocolMessageType('Record', (_message.Message,), ...@@ -384,6 +433,13 @@ Record = _reflection.GeneratedProtocolMessageType('Record', (_message.Message,),
}) })
, ,
'Histogram' : _reflection.GeneratedProtocolMessageType('Histogram', (_message.Message,), {
'DESCRIPTOR' : _RECORD_HISTOGRAM,
'__module__' : 'record_pb2'
# @@protoc_insertion_point(class_scope:visualdl.Record.Histogram)
})
,
'Value' : _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), { 'Value' : _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), {
'DESCRIPTOR' : _RECORD_VALUE, 'DESCRIPTOR' : _RECORD_VALUE,
'__module__' : 'record_pb2' '__module__' : 'record_pb2'
...@@ -400,7 +456,10 @@ _sym_db.RegisterMessage(Record.Audio) ...@@ -400,7 +456,10 @@ _sym_db.RegisterMessage(Record.Audio)
_sym_db.RegisterMessage(Record.Embedding) _sym_db.RegisterMessage(Record.Embedding)
_sym_db.RegisterMessage(Record.Embeddings) _sym_db.RegisterMessage(Record.Embeddings)
_sym_db.RegisterMessage(Record.bytes_embeddings) _sym_db.RegisterMessage(Record.bytes_embeddings)
_sym_db.RegisterMessage(Record.Histogram)
_sym_db.RegisterMessage(Record.Value) _sym_db.RegisterMessage(Record.Value)
_RECORD_HISTOGRAM.fields_by_name['hist']._options = None
_RECORD_HISTOGRAM.fields_by_name['bin_edges']._options = None
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)
...@@ -61,6 +61,10 @@ class LogReader(object): ...@@ -61,6 +61,10 @@ class LogReader(object):
self.load_new_data(update=True) self.load_new_data(update=True)
self._a_tags = {} self._a_tags = {}
@property
def logdir(self):
    """Log directory of this reader (read-only view of `self.dir`)."""
    return self.dir
def parse_from_bin(self, record_bin): def parse_from_bin(self, record_bin):
"""Register to self._tags by component type. """Register to self._tags by component type.
...@@ -83,6 +87,8 @@ class LogReader(object): ...@@ -83,6 +87,8 @@ class LogReader(object):
component = "embeddings" component = "embeddings"
elif "audio" == value_type: elif "audio" == value_type:
component = "audio" component = "audio"
elif "histogram" == value_type:
component = "histogram"
else: else:
raise TypeError("Invalid value type `%s`." % value_type) raise TypeError("Invalid value type `%s`." % value_type)
self._tags[path] = component self._tags[path] = component
...@@ -202,9 +208,14 @@ class LogReader(object): ...@@ -202,9 +208,14 @@ class LogReader(object):
def components(self, update=False):
    """Get components type used by vdl.

    Args:
        update (bool): When exactly True, load new data before collecting
            the component types.

    Return:
        Set of component names taken from the registered tags. An empty
        set when no log directory is configured, and `{'scalar'}` when a
        log directory exists but no tag has been registered.
    """
    if self.logdir is None:
        return set()
    if update is True:
        self.load_new_data(update=update)
    found = set(self._tags.values())
    if not found:
        return {'scalar'}
    return found
def load_new_data(self, update=True): def load_new_data(self, update=True):
"""Load remain data. """Load remain data.
...@@ -212,5 +223,6 @@ class LogReader(object): ...@@ -212,5 +223,6 @@ class LogReader(object):
Make sure all readers for every vdl log file are registered, load all Make sure all readers for every vdl log file are registered, load all
remain data. remain data.
""" """
self.register_readers(update=update) if self.logdir is not None:
self.add_remain() self.register_readers(update=update)
self.add_remain()
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ======================================================================= # =======================================================================
from visualdl.io import bfile from visualdl.io import bfile
import struct import struct
from hdfs.util import HdfsError
class _RecordReader(object): class _RecordReader(object):
...@@ -30,10 +31,13 @@ class _RecordReader(object): ...@@ -30,10 +31,13 @@ class _RecordReader(object):
def get_next(self): def get_next(self):
# Read the header # Read the header
self._curr_event = None self._curr_event = None
header_str = self.file_handle.read(8) try:
header_str = self.file_handle.read(8)
except HdfsError:
raise EOFError('No more events to read on HDFS.')
if len(header_str) != 8: if len(header_str) != 8:
# Hit EOF so raise and exit # Hit EOF so raise and exit
raise EOFError('No more events to read') raise EOFError('No more events to read on LFS.')
header = struct.unpack('Q', header_str) header = struct.unpack('Q', header_str)
header_len = int(header[0]) header_len = int(header[0])
event_str = self.file_handle.read(header_len) event_str = self.file_handle.read(header_len)
......
...@@ -123,6 +123,7 @@ class Api(object): ...@@ -123,6 +123,7 @@ class Api(object):
key = os.path.join('data/plugin/audio/audio', run, tag) key = os.path.join('data/plugin/audio/audio', run, tag)
return self._get_with_retry(key, lib.get_audio_tag_steps, run, tag) return self._get_with_retry(key, lib.get_audio_tag_steps, run, tag)
@result()
def audio_audio(self, run, tag, index=0): def audio_audio(self, run, tag, index=0):
index = int(index) index = int(index)
key = os.path.join('data/plugin/audio/individualAudio', run, tag, str(index)) key = os.path.join('data/plugin/audio/individualAudio', run, tag, str(index))
...@@ -134,6 +135,15 @@ class Api(object): ...@@ -134,6 +135,15 @@ class Api(object):
key = os.path.join('data/plugin/embeddings/embeddings', run, str(dimension), reduction) key = os.path.join('data/plugin/embeddings/embeddings', run, str(dimension), reduction)
return self._get_with_retry(key, lib.get_embeddings, run, tag, reduction, dimension) return self._get_with_retry(key, lib.get_embeddings, run, tag, reduction, dimension)
@result()
def histogram_tags(self):
    """Fetch histogram tags through `lib.get_histogram_tags`.

    The lookup goes through `self._get_with_retry` under the key
    'data/plugin/histogram/tags'.
    """
    key = 'data/plugin/histogram/tags'
    return self._get_with_retry(key, lib.get_histogram_tags)
@result()
def histogram_histogram(self, run, tag):
    """Fetch the histogram records of one tag in one run.

    Args:
        run (string): Run the tag belongs to.
        tag (string): Tag whose histograms are requested.

    Return:
        Whatever `lib.get_histogram` yields for `(run, tag)`, obtained
        through `self._get_with_retry`.
    """
    # Use a histogram-specific key and loader. The embeddings key and
    # `lib.get_embeddings` (which also needs a `reduction` argument) do not
    # belong to this endpoint.
    key = os.path.join('data/plugin/histogram/histogram', run, tag)
    return self._get_with_retry(key, lib.get_histogram, run, tag)
def create_api_call(logdir, cache_timeout): def create_api_call(logdir, cache_timeout):
api = Api(logdir, cache_timeout) api = Api(logdir, cache_timeout)
...@@ -146,12 +156,14 @@ def create_api_call(logdir, cache_timeout): ...@@ -146,12 +156,14 @@ def create_api_call(logdir, cache_timeout):
'images/tags': (api.images_tags, []), 'images/tags': (api.images_tags, []),
'audio/tags': (api.audio_tags, []), 'audio/tags': (api.audio_tags, []),
'embeddings/tags': (api.embeddings_tags, []), 'embeddings/tags': (api.embeddings_tags, []),
'histogram/tags': (api.histogram_tags, []),
'scalars/list': (api.scalars_list, ['run', 'tag']), 'scalars/list': (api.scalars_list, ['run', 'tag']),
'images/list': (api.images_list, ['run', 'tag']), 'images/list': (api.images_list, ['run', 'tag']),
'images/image': (api.images_image, ['run', 'tag', 'index']), 'images/image': (api.images_image, ['run', 'tag', 'index']),
'audio/list': (api.audio_list, ['run', 'tag']), 'audio/list': (api.audio_list, ['run', 'tag']),
'audio/audio': (api.audio_audio, ['run', 'tag', 'index']), 'audio/audio': (api.audio_audio, ['run', 'tag', 'index']),
'embeddings/embedding': (api.embeddings_embedding, ['run', 'tag', 'reduction', 'dimension']) 'embeddings/embedding': (api.embeddings_embedding, ['run', 'tag', 'reduction', 'dimension']),
'histogram/histogram': (api.histogram_histogram, ['run', 'tag'])
} }
def call(path: str, args): def call(path: str, args):
......
...@@ -89,6 +89,7 @@ def create_app(args): ...@@ -89,6 +89,7 @@ def create_app(args):
lang = get_locale() lang = get_locale()
if lang == default_language: if lang == default_language:
return redirect(public_path + '/index', code=302) return redirect(public_path + '/index', code=302)
lang = support_language[0] if lang is None else lang
return redirect(public_path + '/' + lang + '/index', code=302) return redirect(public_path + '/' + lang + '/index', code=302)
@app.route(public_path + '/<path:filename>') @app.route(public_path + '/<path:filename>')
......
...@@ -38,11 +38,6 @@ class DefaultArgs(object): ...@@ -38,11 +38,6 @@ class DefaultArgs(object):
def validate_args(args): def validate_args(args):
# exit if no logdir specified
if not args.logdir or args.logdir is None:
logger.error('Log directory is not specified.')
sys.exit(-1)
# if not in API mode, public path cannot be set to root path # if not in API mode, public path cannot be set to root path
if not args.api_only and args.public_path == '/': if not args.api_only and args.public_path == '/':
logger.error('Public path cannot be set to root path.') logger.error('Public path cannot be set to root path.')
...@@ -91,7 +86,6 @@ def parse_args(): ...@@ -91,7 +86,6 @@ def parse_args():
parser = ArgumentParser(description="VisualDL, a tool to visualize deep learning.") parser = ArgumentParser(description="VisualDL, a tool to visualize deep learning.")
parser.add_argument( parser.add_argument(
"--logdir", "--logdir",
required=True,
action="store", action="store",
dest="logdir", dest="logdir",
nargs="+", nargs="+",
...@@ -142,8 +136,5 @@ def parse_args(): ...@@ -142,8 +136,5 @@ def parse_args():
) )
args = parser.parse_args() args = parser.parse_args()
# print help if no logdir specified
if not args.logdir:
parser.print_help()
return format_args(args) return format_args(args)
...@@ -22,7 +22,9 @@ from visualdl.utils.string_util import encode_tag, decode_tag ...@@ -22,7 +22,9 @@ from visualdl.utils.string_util import encode_tag, decode_tag
def get_components(log_reader):
    """List the component names to expose for `log_reader`.

    The reader's components are refreshed (`update=True`) and 'graph' is
    always added to them.

    Return:
        List of component names, in no particular order.
    """
    available = log_reader.components(update=True)
    available.add('graph')
    return list(available)
def get_runs(log_reader): def get_runs(log_reader):
...@@ -108,6 +110,10 @@ def get_embeddings_tags(log_reader): ...@@ -108,6 +110,10 @@ def get_embeddings_tags(log_reader):
return get_logs(log_reader, "embeddings") return get_logs(log_reader, "embeddings")
def get_histogram_tags(log_reader):
    """Return what `get_logs` reports for the "histogram" component."""
    component = "histogram"
    return get_logs(log_reader, component)
def get_embeddings(log_reader, run, tag, reduction, dimension=2): def get_embeddings(log_reader, run, tag, reduction, dimension=2):
log_reader.load_new_data() log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("embeddings").get_items( records = log_reader.data_manager.get_reservoir("embeddings").get_items(
...@@ -131,6 +137,24 @@ def get_embeddings(log_reader, run, tag, reduction, dimension=2): ...@@ -131,6 +137,24 @@ def get_embeddings(log_reader, run, tag, reduction, dimension=2):
return {"embedding": low_dim_embs.tolist(), "labels": labels} return {"embedding": low_dim_embs.tolist(), "labels": labels}
def get_histogram(log_reader, run, tag):
    """Collect the stored histograms of `tag` in `run`.

    Args:
        log_reader: Reader whose data manager holds the "histogram"
            reservoir.
        run (string): Run to look up.
        tag (string): Encoded tag; it is decoded with `decode_tag` before
            the lookup.

    Return:
        List with one `[timestamp, id, buckets]` entry per record, where
        `buckets` is a list of `[left_edge, right_edge, value]` triples.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("histogram")
    records = reservoir.get_items(run, decode_tag(tag))

    results = []
    for record in records:
        hist = record.histogram.hist
        bin_edges = record.histogram.bin_edges
        # Bucket i spans bin_edges[i]..bin_edges[i + 1].
        buckets = [[bin_edges[i], bin_edges[i + 1], count]
                   for i, count in enumerate(hist)]
        results.append([record.timestamp, record.id, buckets])
    return results
def retry(ntimes, function, time2sleep, *args, **kwargs): def retry(ntimes, function, time2sleep, *args, **kwargs):
''' '''
try to execute `function` `ntimes`, if exception catched, the thread will try to execute `function` `ntimes`, if exception catched, the thread will
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import os import os
import time import time
from visualdl.writer.record_writer import RecordFileWriter from visualdl.writer.record_writer import RecordFileWriter
from visualdl.component.base_component import scalar, image, embedding, audio from visualdl.component.base_component import scalar, image, embedding, audio, histogram
import numpy as np import numpy as np
...@@ -99,6 +99,10 @@ class LogWriter(object): ...@@ -99,6 +99,10 @@ class LogWriter(object):
self._get_file_writer() self._get_file_writer()
self.loggers = {} self.loggers = {}
@property
def logdir(self):
    """Log directory of this writer (read-only view of `self._logdir`)."""
    return self._logdir
def _get_file_writer(self): def _get_file_writer(self):
if not self._write_to_disk: if not self._write_to_disk:
self._file_writer = DummyFileWriter(logdir=self._logdir) self._file_writer = DummyFileWriter(logdir=self._logdir)
...@@ -242,6 +246,41 @@ class LogWriter(object): ...@@ -242,6 +246,41 @@ class LogWriter(object):
step=step, step=step,
walltime=walltime)) walltime=walltime))
def add_histogram(self,
                  tag,
                  values,
                  step,
                  walltime=None,
                  buckets=10):
    """Add a histogram to vdl record file.

    The bucket counts and edges are computed with `numpy.histogram`.

    Args:
        tag (string): Data identifier; must not contain '%'
        values (numpy.ndarray or list): Values the histogram is built from
        step (int): Step of histogram
        walltime (int): Wall time of histogram; when None, the current
            time rounded to whole seconds is used
        buckets (int): Number of buckets, default is 10

    Raises:
        RuntimeError: If `tag` contains '%'.

    Example:
        values = np.arange(0, 1000)
        with LogWriter(logdir="./log/histogram_test/train") as writer:
            for index in range(5):
                writer.add_histogram(tag='default',
                                     values=values+index,
                                     step=index)
    """
    if '%' in tag:
        raise RuntimeError("% can't appear in tag!")
    counts, edges = np.histogram(values, bins=buckets)
    if walltime is None:
        walltime = round(time.time())
    add_record = self._get_file_writer().add_record
    add_record(
        histogram(
            tag=tag,
            hist=counts,
            bin_edges=edges,
            step=step,
            walltime=walltime))
def flush(self): def flush(self):
"""Flush all data in cache to disk. """Flush all data in cache to disk.
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册