From f883251ff7e3f118e1f6f9b04d5154d9186d08c3 Mon Sep 17 00:00:00 2001 From: ShenYuhan Date: Tue, 26 Jan 2021 16:19:35 +0800 Subject: [PATCH] add text --- visualdl/component/__init__.py | 3 + visualdl/component/base_component.py | 16 +++++ visualdl/proto/record.proto | 5 ++ visualdl/proto/record_pb2.py | 92 +++++++++++++++++++++------- visualdl/reader/reader.py | 2 + visualdl/server/api.py | 18 ++++++ visualdl/server/data_manager.py | 7 ++- visualdl/server/lib.py | 20 ++++++ visualdl/writer/writer.py | 23 ++++++- 9 files changed, 162 insertions(+), 24 deletions(-) diff --git a/visualdl/component/__init__.py b/visualdl/component/__init__.py index e3aa580b..591bb651 100644 --- a/visualdl/component/__init__.py +++ b/visualdl/component/__init__.py @@ -20,6 +20,9 @@ components = { "image": { "enabled": False }, + "text": { + "enabled": False + }, "embedding": { "enabled": False }, diff --git a/visualdl/component/base_component.py b/visualdl/component/base_component.py index cd0869fc..d0342c38 100644 --- a/visualdl/component/base_component.py +++ b/visualdl/component/base_component.py @@ -257,6 +257,22 @@ def audio(tag, audio_array, sample_rate, step, walltime): ]) +def text(tag, text_string, step, walltime=None): + """Package data to one image. + Args: + tag (string): Data identifier + text_string (string): Value of text + step (int): Step of text + walltime (int): Wall time of text + Return: + Package with format of record_pb2.Record + """ + _text = Record.Text(encoded_text_string=text_string) + return Record(values=[ + Record.Value(id=step, tag=tag, timestamp=walltime, text=_text) + ]) + + def histogram(tag, hist, bin_edges, step, walltime): """Package data to one histogram. diff --git a/visualdl/proto/record.proto b/visualdl/proto/record.proto index d827cd20..952218b1 100644 --- a/visualdl/proto/record.proto +++ b/visualdl/proto/record.proto @@ -7,6 +7,10 @@ message Record { bytes encoded_image_string = 4; } + message Text { + string encoded_text_string = 1; + } + message Audio { float sample_rate = 1; int64 num_channels = 2; @@ -70,6 +74,7 @@ message Record { PRCurve pr_curve = 9; MetaData meta_data = 10; ROC_Curve roc_curve = 11; + Text text = 12; } } diff --git a/visualdl/proto/record_pb2.py b/visualdl/proto/record_pb2.py index 4efd9102..223dc587 100644 --- a/visualdl/proto/record_pb2.py +++ b/visualdl/proto/record_pb2.py @@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( package='visualdl', syntax='proto3', serialized_options=None, - serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xe0\x08\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x03(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1aP\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x12\x12\n\nlabel_meta\x18\x02 \x03(\t\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1al\n\x07PRCurve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x11\n\tprecision\x18\x05 \x03(\x01\x12\x0e\n\x06recall\x18\x06 \x03(\x01\x1a\x65\n\tROC_Curve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x0b\n\x03tpr\x18\x05 \x03(\x01\x12\x0b\n\x03\x66pr\x18\x06 \x03(\x01\x1a \n\x08MetaData\x12\x14\n\x0c\x64isplay_name\x18\x01 \x01(\t\x1a\x96\x03\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x12,\n\x08pr_curve\x18\t \x01(\x0b\x32\x18.visualdl.Record.PRCurveH\x00\x12.\n\tmeta_data\x18\n \x01(\x0b\x32\x19.visualdl.Record.MetaDataH\x00\x12/\n\troc_curve\x18\x0b \x01(\x0b\x32\x1a.visualdl.Record.ROC_CurveH\x00\x42\x0b\n\tone_valueb\x06proto3' + serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xac\t\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a#\n\x04Text\x12\x1b\n\x13\x65ncoded_text_string\x18\x01 \x01(\t\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x03(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1aP\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x12\x12\n\nlabel_meta\x18\x02 \x03(\t\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1al\n\x07PRCurve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x11\n\tprecision\x18\x05 \x03(\x01\x12\x0e\n\x06recall\x18\x06 \x03(\x01\x1a\x65\n\tROC_Curve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x0b\n\x03tpr\x18\x05 \x03(\x01\x12\x0b\n\x03\x66pr\x18\x06 \x03(\x01\x1a \n\x08MetaData\x12\x14\n\x0c\x64isplay_name\x18\x01 \x01(\t\x1a\xbd\x03\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x12,\n\x08pr_curve\x18\t \x01(\x0b\x32\x18.visualdl.Record.PRCurveH\x00\x12.\n\tmeta_data\x18\n \x01(\x0b\x32\x19.visualdl.Record.MetaDataH\x00\x12/\n\troc_curve\x18\x0b \x01(\x0b\x32\x1a.visualdl.Record.ROC_CurveH\x00\x12%\n\x04text\x18\x0c \x01(\x0b\x32\x15.visualdl.Record.TextH\x00\x42\x0b\n\tone_valueb\x06proto3' ) @@ -54,6 +54,36 @@ _RECORD_IMAGE = _descriptor.Descriptor( serialized_end=114, ) +_RECORD_TEXT = _descriptor.Descriptor( + name='Text', + full_name='visualdl.Record.Text', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='encoded_text_string', full_name='visualdl.Record.Text.encoded_text_string', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=116, + serialized_end=151, +) + _RECORD_AUDIO = _descriptor.Descriptor( name='Audio', full_name='visualdl.Record.Audio', @@ -108,8 +138,8 @@ _RECORD_AUDIO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=116, - serialized_end=241, + serialized_start=153, + serialized_end=278, ) _RECORD_EMBEDDING = _descriptor.Descriptor( @@ -145,8 +175,8 @@ _RECORD_EMBEDDING = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=243, - serialized_end=286, + serialized_start=280, + serialized_end=323, ) _RECORD_EMBEDDINGS = _descriptor.Descriptor( @@ -182,8 +212,8 @@ _RECORD_EMBEDDINGS = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=288, - serialized_end=368, + serialized_start=325, + serialized_end=405, ) _RECORD_BYTES_EMBEDDINGS = _descriptor.Descriptor( @@ -219,8 +249,8 @@ _RECORD_BYTES_EMBEDDINGS = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=370, - serialized_end=437, + serialized_start=407, + serialized_end=474, ) _RECORD_HISTOGRAM = _descriptor.Descriptor( @@ -256,8 +286,8 @@ _RECORD_HISTOGRAM = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=439, - serialized_end=491, + serialized_start=476, + serialized_end=528, ) _RECORD_PRCURVE = _descriptor.Descriptor( @@ -321,8 +351,8 @@ _RECORD_PRCURVE = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=493, - serialized_end=601, + serialized_start=530, + serialized_end=638, ) _RECORD_ROC_CURVE = _descriptor.Descriptor( @@ -386,8 +416,8 @@ _RECORD_ROC_CURVE = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=603, - serialized_end=704, + serialized_start=640, + serialized_end=741, ) _RECORD_METADATA = _descriptor.Descriptor( @@ -416,8 +446,8 @@ _RECORD_METADATA = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=706, - serialized_end=738, + serialized_start=743, + serialized_end=775, ) _RECORD_VALUE = _descriptor.Descriptor( @@ -504,6 +534,13 @@ _RECORD_VALUE = _descriptor.Descriptor( message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='text', full_name='visualdl.Record.Value.text', index=11, + number=12, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -519,8 +556,8 @@ _RECORD_VALUE = _descriptor.Descriptor( name='one_value', full_name='visualdl.Record.Value.one_value', index=0, containing_type=None, fields=[]), ], - serialized_start=741, - serialized_end=1147, + serialized_start=778, + serialized_end=1223, ) _RECORD = _descriptor.Descriptor( @@ -540,7 +577,7 @@ _RECORD = _descriptor.Descriptor( ], extensions=[ ], - nested_types=[_RECORD_IMAGE, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_HISTOGRAM, _RECORD_PRCURVE, _RECORD_ROC_CURVE, _RECORD_METADATA, _RECORD_VALUE, ], + nested_types=[_RECORD_IMAGE, _RECORD_TEXT, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_HISTOGRAM, _RECORD_PRCURVE, _RECORD_ROC_CURVE, _RECORD_METADATA, _RECORD_VALUE, ], enum_types=[ ], serialized_options=None, @@ -550,10 +587,11 @@ _RECORD = _descriptor.Descriptor( oneofs=[ ], serialized_start=27, - serialized_end=1147, + serialized_end=1223, ) _RECORD_IMAGE.containing_type = _RECORD +_RECORD_TEXT.containing_type = _RECORD _RECORD_AUDIO.containing_type = _RECORD _RECORD_EMBEDDING.containing_type = _RECORD _RECORD_EMBEDDINGS.fields_by_name['embeddings'].message_type = _RECORD_EMBEDDING @@ -570,6 +608,7 @@ _RECORD_VALUE.fields_by_name['histogram'].message_type = _RECORD_HISTOGRAM _RECORD_VALUE.fields_by_name['pr_curve'].message_type = _RECORD_PRCURVE _RECORD_VALUE.fields_by_name['meta_data'].message_type = _RECORD_METADATA _RECORD_VALUE.fields_by_name['roc_curve'].message_type = _RECORD_ROC_CURVE +_RECORD_VALUE.fields_by_name['text'].message_type = _RECORD_TEXT _RECORD_VALUE.containing_type = _RECORD _RECORD_VALUE.oneofs_by_name['one_value'].fields.append( _RECORD_VALUE.fields_by_name['value']) @@ -595,6 +634,9 @@ _RECORD_VALUE.fields_by_name['meta_data'].containing_oneof = _RECORD_VALUE.oneof _RECORD_VALUE.oneofs_by_name['one_value'].fields.append( _RECORD_VALUE.fields_by_name['roc_curve']) _RECORD_VALUE.fields_by_name['roc_curve'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value'] +_RECORD_VALUE.oneofs_by_name['one_value'].fields.append( + _RECORD_VALUE.fields_by_name['text']) +_RECORD_VALUE.fields_by_name['text'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value'] _RECORD.fields_by_name['values'].message_type = _RECORD_VALUE DESCRIPTOR.message_types_by_name['Record'] = _RECORD _sym_db.RegisterFileDescriptor(DESCRIPTOR) @@ -608,6 +650,13 @@ Record = _reflection.GeneratedProtocolMessageType('Record', (_message.Message,), }) , + 'Text' : _reflection.GeneratedProtocolMessageType('Text', (_message.Message,), { + 'DESCRIPTOR' : _RECORD_TEXT, + '__module__' : 'record_pb2' + # @@protoc_insertion_point(class_scope:visualdl.Record.Text) + }) + , + 'Audio' : _reflection.GeneratedProtocolMessageType('Audio', (_message.Message,), { 'DESCRIPTOR' : _RECORD_AUDIO, '__module__' : 'record_pb2' @@ -676,6 +725,7 @@ Record = _reflection.GeneratedProtocolMessageType('Record', (_message.Message,), }) _sym_db.RegisterMessage(Record) _sym_db.RegisterMessage(Record.Image) +_sym_db.RegisterMessage(Record.Text) _sym_db.RegisterMessage(Record.Audio) _sym_db.RegisterMessage(Record.Embedding) _sym_db.RegisterMessage(Record.Embeddings) diff --git a/visualdl/reader/reader.py b/visualdl/reader/reader.py index c1b2c613..723f0dde 100644 --- a/visualdl/reader/reader.py +++ b/visualdl/reader/reader.py @@ -180,6 +180,8 @@ class LogReader(object): elif "meta_data" == value_type: self.update_meta_data(record) component = "meta_data" + elif "text" == value_type: + component = "text" else: raise TypeError("Invalid value type `%s`." % value_type) self._tags[path] = component diff --git a/visualdl/server/api.py b/visualdl/server/api.py index a0abe6ef..89a38820 100644 --- a/visualdl/server/api.py +++ b/visualdl/server/api.py @@ -101,6 +101,10 @@ class Api(object): def image_tags(self): return self._get_with_retry('data/plugin/images/tags', lib.get_image_tags) + @result() + def text_tags(self): + return self._get_with_retry('data/plugin/text/tags', lib.get_text_tags) + @result() def audio_tags(self): return self._get_with_retry('data/plugin/audio/tags', lib.get_audio_tags) @@ -138,6 +142,17 @@ class Api(object): key = os.path.join('data/plugin/images/individualImage', mode, tag, str(index)) return self._get_with_retry(key, lib.get_individual_image, mode, tag, index) + @result() + def text_list(self, mode, tag): + key = os.path.join('data/plugin/text/text', mode, tag) + return self._get_with_retry(key, lib.get_text_tag_steps, mode, tag) + + @result('text/plain') + def text_text(self, mode, tag, index=0): + index = int(index) + key = os.path.join('data/plugin/text/individualText', mode, tag, str(index)) + return self._get_with_retry(key, lib.get_individual_text, mode, tag, index) + @result() def audio_list(self, run, tag): key = os.path.join('data/plugin/audio/audio', run, tag) @@ -216,6 +231,7 @@ def create_api_call(logdir, model, cache_timeout): 'logs': (api.logs, []), 'scalar/tags': (api.scalar_tags, []), 'image/tags': (api.image_tags, []), + 'text/tags': (api.text_tags, []), 'audio/tags': (api.audio_tags, []), 'embedding/tags': (api.embedding_tags, []), 'histogram/tags': (api.histogram_tags, []), @@ -225,6 +241,8 @@ def create_api_call(logdir, model, cache_timeout): 'scalar/data': (api.scalar_data, ['run', 'tag', 'type']), 'image/list': (api.image_list, ['run', 'tag']), 'image/image': (api.image_image, ['run', 'tag', 'index']), + 'text/list': (api.text_list, ['run', 'tag']), + 'text/text': (api.text_text, ['run', 'tag', 'index']), 'audio/list': (api.audio_list, ['run', 'tag']), 'audio/audio': (api.audio_audio, ['run', 'tag', 'index']), 'embedding/embedding': (api.embedding_embedding, ['run', 'tag', 'reduction', 'dimension']), diff --git a/visualdl/server/data_manager.py b/visualdl/server/data_manager.py index 9acadcdf..8fe27ccb 100644 --- a/visualdl/server/data_manager.py +++ b/visualdl/server/data_manager.py @@ -25,7 +25,8 @@ DEFAULT_PLUGIN_MAXSIZE = { "audio": 10, "pr_curve": 300, "roc_curve": 300, - "meta_data": 100 + "meta_data": 100, + "text": 10 } @@ -350,7 +351,9 @@ class DataManager(object): "roc_curve": Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["roc_curve"]), "meta_data": - Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["meta_data"]) + Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["meta_data"]), + "text": + Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["text"]) } self._mutex = threading.Lock() diff --git a/visualdl/server/lib.py b/visualdl/server/lib.py index 22a19c01..b9a42dbf 100644 --- a/visualdl/server/lib.py +++ b/visualdl/server/lib.py @@ -158,6 +158,26 @@ def get_individual_image(log_reader, run, tag, step_index): return records[step_index].image.encoded_image_string +def get_text_tag_steps(log_reader, run, tag): + run = log_reader.name2tags[run] if run in log_reader.name2tags else run + log_reader.load_new_data() + records = log_reader.data_manager.get_reservoir("text").get_items( + run, decode_tag(tag)) + result = [{ + "step": item.id, + "wallTime": s2ms(item.timestamp) + } for item in records] + return result + + +def get_individual_text(log_reader, run, tag, step_index): + run = log_reader.name2tags[run] if run in log_reader.name2tags else run + log_reader.load_new_data() + records = log_reader.data_manager.get_reservoir("text").get_items( + run, decode_tag(tag)) + return records[step_index].text.encoded_text_string + + def get_audio_tag_steps(log_reader, run, tag): run = log_reader.name2tags[run] if run in log_reader.name2tags else run log_reader.load_new_data() diff --git a/visualdl/writer/writer.py b/visualdl/writer/writer.py index 1714d8e1..278c7a72 100644 --- a/visualdl/writer/writer.py +++ b/visualdl/writer/writer.py @@ -18,7 +18,8 @@ import time import numpy as np from visualdl.writer.record_writer import RecordFileWriter from visualdl.utils.img_util import merge_images -from visualdl.component.base_component import scalar, image, embedding, audio, histogram, pr_curve, roc_curve, meta_data +from visualdl.component.base_component import scalar, image, embedding, audio, \ + histogram, pr_curve, roc_curve, meta_data, text class DummyFileWriter(object): @@ -190,6 +191,26 @@ class LogWriter(object): image(tag=tag, image_array=img, step=step, walltime=walltime, dataformats=dataformats)) + def add_text(self, tag, text_string, step=None, walltime=None): + """Add an text to vdl record file. + Args: + tag (string): Data identifier + text_string (string): Value of text + step (int): Step of text + walltime (int): Wall time of text + Example: + for index in range(1, 101): + writer.add_text(tag="train/loss", text_string=str(index) + 'text', step=index) + """ + if '%' in tag: + raise RuntimeError("% can't appear in tag!") + walltime = round( + time.time() * 1000) if walltime is None else walltime + self._get_file_writer().add_record( + text( + tag=tag, text_string=text_string, step=step, + walltime=walltime)) + def add_image_matrix(self, tag, imgs, step, rows=-1, scale=1.0, walltime=None, dataformats="HWC"): """Add an image to vdl record file. -- GitLab