diff --git a/demo/components/histogram_test.py b/demo/components/histogram_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4b3ed85688fa79db2c8f895012e8a1938155bf00 --- /dev/null +++ b/demo/components/histogram_test.py @@ -0,0 +1,27 @@ +# Copyright (c) 2020 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= +# coding=utf-8 +from visualdl import LogWriter +import numpy as np + + +if __name__ == '__main__': + values = np.arange(0, 1000) + with LogWriter(logdir="./log/histogram_test/train") as writer: + for index in range(5): + writer.add_histogram(tag='default1', + values=values+index, + step=index, + buckets=10) diff --git a/requirements.txt b/requirements.txt index 8afc5e046f08adeb0d2d1dd15e711b1b58934b36..afafa10a896fbe5888683e5efd45437b243b3d0e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ Flask-Babel >= 1.0.0 six >= 1.14.0 protobuf >= 3.1.0 opencv-python +hdfs diff --git a/visualdl/component/__init__.py b/visualdl/component/__init__.py index eaf6de06da38ad1f33ac9a988ab52bff43303444..249ab41f6241014915b56c266e8a62a808c436b5 100644 --- a/visualdl/component/__init__.py +++ b/visualdl/component/__init__.py @@ -24,5 +24,11 @@ components = { }, "audio": { "enabled": False + }, + "histogram": { + "enabled": False + }, + "graph": { + "enabled": False } } diff --git a/visualdl/component/base_component.py 
b/visualdl/component/base_component.py index 72038bee2efb951cdb3b8914159d8123b499eae0..ece24e3a6b5131e644035f2043580b2a9a39c0fa 100644 --- a/visualdl/component/base_component.py +++ b/visualdl/component/base_component.py @@ -131,3 +131,11 @@ def audio(tag, audio_array, sample_rate, step, walltime): return Record(values=[ Record.Value(id=step, tag=tag, timestamp=walltime, audio=audio_data) ]) + + +def histogram(tag, hist, bin_edges, step, walltime): + histogram = Record.Histogram(hist=hist, bin_edges=bin_edges) + return Record(values=[ + Record.Value( + id=step, tag=tag, timestamp=walltime, histogram=histogram) + ]) diff --git a/visualdl/io/bfile.py b/visualdl/io/bfile.py index d074492cd0cb4a79cfeea3d9596c1ddc3608e42b..e7a7bdbb35ad20f97dc25133b483e0a5be2e56a2 100644 --- a/visualdl/io/bfile.py +++ b/visualdl/io/bfile.py @@ -15,6 +15,7 @@ import os import tempfile +import hdfs # Note: Some codes here refer to TensorBoardX. # A good default block size depends on the system in question. @@ -90,6 +91,51 @@ class LocalFileSystem(object): default_file_factory.register_filesystem("", LocalFileSystem()) +class HDFileSystem(object): + def __init__(self): + self.cli = hdfs.config.Config().get_client('dev') + + def exists(self, path): + if self.cli.status(hdfs_path=path[7:], strict=False) is None: + return False + else: + return True + + def makedirs(self, path): + self.cli.makedirs(hdfs_path=path[7:]) + + @staticmethod + def join(path, *paths): + return os.path.join(path, *paths) + + def read(self, filename, binary_mode=False, size=0, continue_from=None): + offset = 0 + if continue_from is not None: + offset = continue_from.get("last_offset", 0) + + encoding = None if binary_mode else "utf-8" + with self.cli.read(hdfs_path=filename[7:], offset=offset, encoding=encoding) as reader: + data = reader.read() + continue_from_token = {"last_offset": offset + len(data)} + return data, continue_from_token + + def append(self, filename, file_content, binary_mode=False): + 
self.cli.write(hdfs_path=filename[7:], data=file_content, append=True) + + def write(self, filename, file_content, binary_mode=False): + self.cli.write(hdfs_path=filename[7:], data=file_content) + + def walk(self, dir): + walks = self.cli.walk(hdfs_path=dir[7:]) + return (['hdfs://'+root, dirs, files] for root, dirs, files in walks) + + +try: + default_file_factory.register_filesystem("hdfs", HDFileSystem()) +except hdfs.util.HdfsError: + print("HDFS initialization failed, please check if .hdfscli.cfg exists.") + + class BFile(object): def __init__(self, filename, mode): if mode not in ('r', 'rb', 'br', 'w', 'wb', 'bw'): diff --git a/visualdl/proto/record.proto b/visualdl/proto/record.proto index c7f23929ba39fb52b886d6757e1f3e7fbd7ad08c..627c2118bfced1105a66a4da5e33d790c1d2cfc8 100644 --- a/visualdl/proto/record.proto +++ b/visualdl/proto/record.proto @@ -29,6 +29,11 @@ message Record { bytes encoded_vectors = 2; } +message Histogram { + repeated double hist = 1 [packed = true]; + repeated double bin_edges = 2 [packed = true]; +}; + message Value { int64 id = 1; string tag = 2; @@ -38,6 +43,7 @@ message Record { Image image = 5; Audio audio = 6; Embeddings embeddings = 7; + Histogram histogram = 8; } } diff --git a/visualdl/proto/record_pb2.py b/visualdl/proto/record_pb2.py index 7d213a6bb20abb8f5a1d2e7f2ec0b631596ab3d2..d44ecf85348ad351903f779ddc99fa88d0668660 100644 --- a/visualdl/proto/record_pb2.py +++ b/visualdl/proto/record_pb2.py @@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( package='visualdl', syntax='proto3', serialized_options=None, - serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xdf\x04\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 
\x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1a<\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\xd6\x01\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x42\x0b\n\tone_valueb\x06proto3' + serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xc6\x05\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1a<\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1a\x87\x02\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 
\x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x42\x0b\n\tone_valueb\x06proto3' ) @@ -216,6 +216,43 @@ _RECORD_BYTES_EMBEDDINGS = _descriptor.Descriptor( serialized_end=417, ) +_RECORD_HISTOGRAM = _descriptor.Descriptor( + name='Histogram', + full_name='visualdl.Record.Histogram', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='hist', full_name='visualdl.Record.Histogram.hist', index=0, + number=1, type=1, cpp_type=5, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=b'\020\001', file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='bin_edges', full_name='visualdl.Record.Histogram.bin_edges', index=1, + number=2, type=1, cpp_type=5, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=b'\020\001', file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=419, + serialized_end=471, +) + _RECORD_VALUE = _descriptor.Descriptor( name='Value', full_name='visualdl.Record.Value', @@ -272,6 +309,13 @@ _RECORD_VALUE = _descriptor.Descriptor( message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='histogram', full_name='visualdl.Record.Value.histogram', index=7, + number=8, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, 
extension_scope=None, + serialized_options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -287,8 +331,8 @@ _RECORD_VALUE = _descriptor.Descriptor( name='one_value', full_name='visualdl.Record.Value.one_value', index=0, containing_type=None, fields=[]), ], - serialized_start=420, - serialized_end=634, + serialized_start=474, + serialized_end=737, ) _RECORD = _descriptor.Descriptor( @@ -308,7 +352,7 @@ _RECORD = _descriptor.Descriptor( ], extensions=[ ], - nested_types=[_RECORD_IMAGE, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_VALUE, ], + nested_types=[_RECORD_IMAGE, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_HISTOGRAM, _RECORD_VALUE, ], enum_types=[ ], serialized_options=None, @@ -318,7 +362,7 @@ _RECORD = _descriptor.Descriptor( oneofs=[ ], serialized_start=27, - serialized_end=634, + serialized_end=737, ) _RECORD_IMAGE.containing_type = _RECORD @@ -327,9 +371,11 @@ _RECORD_EMBEDDING.containing_type = _RECORD _RECORD_EMBEDDINGS.fields_by_name['embeddings'].message_type = _RECORD_EMBEDDING _RECORD_EMBEDDINGS.containing_type = _RECORD _RECORD_BYTES_EMBEDDINGS.containing_type = _RECORD +_RECORD_HISTOGRAM.containing_type = _RECORD _RECORD_VALUE.fields_by_name['image'].message_type = _RECORD_IMAGE _RECORD_VALUE.fields_by_name['audio'].message_type = _RECORD_AUDIO _RECORD_VALUE.fields_by_name['embeddings'].message_type = _RECORD_EMBEDDINGS +_RECORD_VALUE.fields_by_name['histogram'].message_type = _RECORD_HISTOGRAM _RECORD_VALUE.containing_type = _RECORD _RECORD_VALUE.oneofs_by_name['one_value'].fields.append( _RECORD_VALUE.fields_by_name['value']) @@ -343,6 +389,9 @@ _RECORD_VALUE.fields_by_name['audio'].containing_oneof = _RECORD_VALUE.oneofs_by _RECORD_VALUE.oneofs_by_name['one_value'].fields.append( _RECORD_VALUE.fields_by_name['embeddings']) _RECORD_VALUE.fields_by_name['embeddings'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value'] 
+_RECORD_VALUE.oneofs_by_name['one_value'].fields.append( + _RECORD_VALUE.fields_by_name['histogram']) +_RECORD_VALUE.fields_by_name['histogram'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value'] _RECORD.fields_by_name['values'].message_type = _RECORD_VALUE DESCRIPTOR.message_types_by_name['Record'] = _RECORD _sym_db.RegisterFileDescriptor(DESCRIPTOR) @@ -384,6 +433,13 @@ Record = _reflection.GeneratedProtocolMessageType('Record', (_message.Message,), }) , + 'Histogram' : _reflection.GeneratedProtocolMessageType('Histogram', (_message.Message,), { + 'DESCRIPTOR' : _RECORD_HISTOGRAM, + '__module__' : 'record_pb2' + # @@protoc_insertion_point(class_scope:visualdl.Record.Histogram) + }) + , + 'Value' : _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), { 'DESCRIPTOR' : _RECORD_VALUE, '__module__' : 'record_pb2' @@ -400,7 +456,10 @@ _sym_db.RegisterMessage(Record.Audio) _sym_db.RegisterMessage(Record.Embedding) _sym_db.RegisterMessage(Record.Embeddings) _sym_db.RegisterMessage(Record.bytes_embeddings) +_sym_db.RegisterMessage(Record.Histogram) _sym_db.RegisterMessage(Record.Value) +_RECORD_HISTOGRAM.fields_by_name['hist']._options = None +_RECORD_HISTOGRAM.fields_by_name['bin_edges']._options = None # @@protoc_insertion_point(module_scope) diff --git a/visualdl/reader/reader.py b/visualdl/reader/reader.py index 9befb7480b8caa40492135abd6f4e7a4a4afd779..c89db3b618322c0041f20db016b15902732b66be 100644 --- a/visualdl/reader/reader.py +++ b/visualdl/reader/reader.py @@ -61,6 +61,10 @@ class LogReader(object): self.load_new_data(update=True) self._a_tags = {} + @property + def logdir(self): + return self.dir + def parse_from_bin(self, record_bin): """Register to self._tags by component type. @@ -83,6 +87,8 @@ class LogReader(object): component = "embeddings" elif "audio" == value_type: component = "audio" + elif "histogram" == value_type: + component = "histogram" else: raise TypeError("Invalid value type `%s`." 
% value_type) self._tags[path] = component @@ -202,9 +208,14 @@ class LogReader(object): def components(self, update=False): """Get components type used by vdl. """ + if self.logdir is None: + return set() if update is True: self.load_new_data(update=update) - return list(set(self._tags.values())) + components_set = set(self._tags.values()) + if 0 == len(components_set): + return {'scalar'} + return components_set def load_new_data(self, update=True): """Load remain data. @@ -212,5 +223,6 @@ class LogReader(object): Make sure all readers for every vdl log file are registered, load all remain data. """ - self.register_readers(update=update) - self.add_remain() + if self.logdir is not None: + self.register_readers(update=update) + self.add_remain() diff --git a/visualdl/reader/record_reader.py b/visualdl/reader/record_reader.py index 0936af064a5c4390f85f73cd1cc11e347e495ddd..65f97484fde319cb862ae5013521e2b4bee1a4a5 100644 --- a/visualdl/reader/record_reader.py +++ b/visualdl/reader/record_reader.py @@ -14,6 +14,7 @@ # ======================================================================= from visualdl.io import bfile import struct +from hdfs.util import HdfsError class _RecordReader(object): @@ -30,10 +31,13 @@ class _RecordReader(object): def get_next(self): # Read the header self._curr_event = None - header_str = self.file_handle.read(8) + try: + header_str = self.file_handle.read(8) + except HdfsError: + raise EOFError('No more events to read on HDFS.') if len(header_str) != 8: # Hit EOF so raise and exit - raise EOFError('No more events to read') + raise EOFError('No more events to read on LFS.') header = struct.unpack('Q', header_str) header_len = int(header[0]) event_str = self.file_handle.read(header_len) diff --git a/visualdl/server/api.py b/visualdl/server/api.py index 355ddbedb30338ee9e6bd852973020672c9e7a74..463bcaf141d4d046d112a08838e242598d1e1263 100644 --- a/visualdl/server/api.py +++ b/visualdl/server/api.py @@ -123,6 +123,7 @@ class Api(object): key 
= os.path.join('data/plugin/audio/audio', run, tag) return self._get_with_retry(key, lib.get_audio_tag_steps, run, tag) + @result() def audio_audio(self, run, tag, index=0): index = int(index) key = os.path.join('data/plugin/audio/individualAudio', run, tag, str(index)) @@ -134,6 +135,15 @@ class Api(object): key = os.path.join('data/plugin/embeddings/embeddings', run, str(dimension), reduction) return self._get_with_retry(key, lib.get_embeddings, run, tag, reduction, dimension) + @result() + def histogram_tags(self): + return self._get_with_retry('data/plugin/histogram/tags', lib.get_histogram_tags) + + @result() + def histogram_histogram(self, run, tag): + key = os.path.join('data/plugin/histogram/histogram', run, tag) + return self._get_with_retry(key, lib.get_histogram, run, tag) + def create_api_call(logdir, cache_timeout): api = Api(logdir, cache_timeout) @@ -146,12 +156,14 @@ def create_api_call(logdir, cache_timeout): 'images/tags': (api.images_tags, []), 'audio/tags': (api.audio_tags, []), 'embeddings/tags': (api.embeddings_tags, []), + 'histogram/tags': (api.histogram_tags, []), 'scalars/list': (api.scalars_list, ['run', 'tag']), 'images/list': (api.images_list, ['run', 'tag']), 'images/image': (api.images_image, ['run', 'tag', 'index']), 'audio/list': (api.audio_list, ['run', 'tag']), 'audio/audio': (api.audio_audio, ['run', 'tag', 'index']), - 'embeddings/embedding': (api.embeddings_embedding, ['run', 'tag', 'reduction', 'dimension']) + 'embeddings/embedding': (api.embeddings_embedding, ['run', 'tag', 'reduction', 'dimension']), + 'histogram/histogram': (api.histogram_histogram, ['run', 'tag']) } def call(path: str, args): diff --git a/visualdl/server/app.py b/visualdl/server/app.py index a36610122a87c4d0f70c1ecdd6689ae918a91d1d..db1f6ae52b79f5f02a4a594e43f1368c4c6ea47e 100644 --- a/visualdl/server/app.py +++ b/visualdl/server/app.py @@ -89,6 +89,7 @@ def create_app(args): lang = get_locale() if lang == default_language: return redirect(public_path + 
'/index', code=302) + lang = support_language[0] if lang is None else lang return redirect(public_path + '/' + lang + '/index', code=302) @app.route(public_path + '/') diff --git a/visualdl/server/args.py b/visualdl/server/args.py index 293bb94df0208e3c28582a7411def7bbcca2532c..ff42c227b5d4fddaee88f10ddccd209e9e56fa88 100644 --- a/visualdl/server/args.py +++ b/visualdl/server/args.py @@ -38,11 +38,6 @@ class DefaultArgs(object): def validate_args(args): - # exit if no logdir specified - if not args.logdir or args.logdir is None: - logger.error('Log directory is not specified.') - sys.exit(-1) - # if not in API mode, public path cannot be set to root path if not args.api_only and args.public_path == '/': logger.error('Public path cannot be set to root path.') @@ -91,7 +86,6 @@ def parse_args(): parser = ArgumentParser(description="VisualDL, a tool to visualize deep learning.") parser.add_argument( "--logdir", - required=True, action="store", dest="logdir", nargs="+", @@ -142,8 +136,5 @@ def parse_args(): ) args = parser.parse_args() - # print help if no logdir specified - if not args.logdir: - parser.print_help() return format_args(args) diff --git a/visualdl/server/lib.py b/visualdl/server/lib.py index 2d864aab9b9c2cd53184bdd7c0e2df20627ea7fd..371fce306fff5e56cd6febd6c182242312fae5ce 100644 --- a/visualdl/server/lib.py +++ b/visualdl/server/lib.py @@ -22,7 +22,9 @@ from visualdl.utils.string_util import encode_tag, decode_tag def get_components(log_reader): - return log_reader.components(update=True) + components = log_reader.components(update=True) + components.add('graph') + return list(components) def get_runs(log_reader): @@ -108,6 +110,10 @@ def get_embeddings_tags(log_reader): return get_logs(log_reader, "embeddings") +def get_histogram_tags(log_reader): + return get_logs(log_reader, "histogram") + + def get_embeddings(log_reader, run, tag, reduction, dimension=2): log_reader.load_new_data() records = 
log_reader.data_manager.get_reservoir("embeddings").get_items( @@ -131,6 +137,24 @@ def get_embeddings(log_reader, run, tag, reduction, dimension=2): return {"embedding": low_dim_embs.tolist(), "labels": labels} +def get_histogram(log_reader, run, tag): + log_reader.load_new_data() + records = log_reader.data_manager.get_reservoir("histogram").get_items( + run, decode_tag(tag)) + + results = [] + for item in records: + histogram = item.histogram + hist = histogram.hist + bin_edges = histogram.bin_edges + histogram_data = [] + for index in range(len(hist)): + histogram_data.append([bin_edges[index], bin_edges[index+1], hist[index]]) + results.append([item.timestamp, item.id, histogram_data]) + + return results + + def retry(ntimes, function, time2sleep, *args, **kwargs): ''' try to execute `function` `ntimes`, if exception catched, the thread will diff --git a/visualdl/writer/writer.py b/visualdl/writer/writer.py index d05e3a6cd78a65341b2b7770b45b9f899163c92f..fb4585726b3d98440cc43883909ce3fa1e13fe85 100644 --- a/visualdl/writer/writer.py +++ b/visualdl/writer/writer.py @@ -15,7 +15,7 @@ import os import time from visualdl.writer.record_writer import RecordFileWriter -from visualdl.component.base_component import scalar, image, embedding, audio +from visualdl.component.base_component import scalar, image, embedding, audio, histogram import numpy as np @@ -99,6 +99,10 @@ class LogWriter(object): self._get_file_writer() self.loggers = {} + @property + def logdir(self): + return self._logdir + def _get_file_writer(self): if not self._write_to_disk: self._file_writer = DummyFileWriter(logdir=self._logdir) @@ -242,6 +246,41 @@ class LogWriter(object): step=step, walltime=walltime)) + def add_histogram(self, + tag, + values, + step, + walltime=None, + buckets=10): + """Add a histogram to vdl record file. 
+ + Args: + tag (string): Data identifier + values (numpy.ndarray or list): Values represented by a numpy.ndarray or list + step (int): Step of histogram + walltime (int): Wall time of histogram + buckets (int): Number of buckets, default is 10 + + Example: + values = np.arange(0, 1000) + with LogWriter(logdir="./log/histogram_test/train") as writer: + for index in range(5): + writer.add_histogram(tag='default', + values=values+index, + step=index) + """ + if '%' in tag: + raise RuntimeError("% can't appear in tag!") + hist, bin_edges = np.histogram(values, bins=buckets) + walltime = round(time.time()) if walltime is None else walltime + self._get_file_writer().add_record( + histogram( + tag=tag, + hist=hist, + bin_edges=bin_edges, + step=step, + walltime=walltime)) + def flush(self): """Flush all data in cache to disk. """