diff --git a/requirements.txt b/requirements.txt index 5a4e55129a7370a72d1b2677cb7d8b34250ae623..17af24d6161f1091986b515e20e7a6f527a743f1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ requests shellcheck-py six >= 1.14.0 matplotlib +pandas diff --git a/visualdl/component/base_component.py b/visualdl/component/base_component.py index d0342c3857b6338df370a62337e9be31624d5b84..b562c6173458c78a46d4fe0f87d755b02735c99d 100644 --- a/visualdl/component/base_component.py +++ b/visualdl/component/base_component.py @@ -293,6 +293,51 @@ def histogram(tag, hist, bin_edges, step, walltime): ]) +def hparam(name, hparam_dict, metric_list, walltime): + """Package data to one histogram. + + Args: + name (str): Name of hparam. + hparam_dict (dictionary): Each key-value pair in the dictionary is the + name of the hyper parameter and it's corresponding value. The type of the value + can be one of `bool`, `string`, `float`, `int`, or `None`. + metric_list (list): Name of all metrics. + walltime (int): Wall time of hparam. + + Return: + Package with format of record_pb2.Record + """ + + hm = Record.HParam() + hm.name = name + for k, v in hparam_dict.items(): + if v is None: + continue + hparamInfo = Record.HParam.HparamInfo() + hparamInfo.name = k + if isinstance(v, int): + hparamInfo.int_value = v + hm.hparamInfos.append(hparamInfo) + elif isinstance(v, float): + hparamInfo.float_value = v + hm.hparamInfos.append(hparamInfo) + elif isinstance(v, str): + hparamInfo.string_value = v + hm.hparamInfos.append(hparamInfo) + else: + print("The value of %s must be int, float or str, not %s" % (k, str(type(v)))) + for metric in metric_list: + metricInfo = Record.HParam.HparamInfo() + metricInfo.name = metric + metricInfo.float_value = 0 + hm.metricInfos.append(metricInfo) + + return Record(values=[ + Record.Value( + id=1, tag="hparam", timestamp=walltime, hparam=hm) + ]) + + def compute_curve(labels, predictions, num_thresholds=None, weights=None): """ Compute precision-recall curve data by labels and predictions. diff --git a/visualdl/proto/record.proto b/visualdl/proto/record.proto index 952218b181f59888ff46bc5d3162957af6ad73e2..680d0adbbb060ada968b36ad7ecd13eca4176c6a 100644 --- a/visualdl/proto/record.proto +++ b/visualdl/proto/record.proto @@ -57,6 +57,20 @@ message Record { repeated double fpr = 6; } + message HParam { + message HparamInfo { + oneof type { + int64 int_value = 1; + double float_value = 2; + string string_value = 3; + }; + string name = 4; + } + repeated HparamInfo hparamInfos = 1; + repeated HparamInfo metricInfos = 2; + string name = 3; + } + message MetaData { string display_name = 1; } @@ -75,6 +89,7 @@ message Record { MetaData meta_data = 10; ROC_Curve roc_curve = 11; Text text = 12; + HParam hparam = 13; } } diff --git a/visualdl/proto/record_pb2.py b/visualdl/proto/record_pb2.py index 223dc58761355f8c225571955b5a22bd5368ba42..1dc3fd22aa72e1d07588a885630cbd1cb8f1d9b8 100644 --- a/visualdl/proto/record_pb2.py +++ b/visualdl/proto/record_pb2.py @@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( package='visualdl', syntax='proto3', serialized_options=None, - serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xac\t\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a#\n\x04Text\x12\x1b\n\x13\x65ncoded_text_string\x18\x01 \x01(\t\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x03(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1aP\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x12\x12\n\nlabel_meta\x18\x02 \x03(\t\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1al\n\x07PRCurve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x11\n\tprecision\x18\x05 \x03(\x01\x12\x0e\n\x06recall\x18\x06 \x03(\x01\x1a\x65\n\tROC_Curve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x0b\n\x03tpr\x18\x05 \x03(\x01\x12\x0b\n\x03\x66pr\x18\x06 \x03(\x01\x1a \n\x08MetaData\x12\x14\n\x0c\x64isplay_name\x18\x01 \x01(\t\x1a\xbd\x03\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x12,\n\x08pr_curve\x18\t \x01(\x0b\x32\x18.visualdl.Record.PRCurveH\x00\x12.\n\tmeta_data\x18\n \x01(\x0b\x32\x19.visualdl.Record.MetaDataH\x00\x12/\n\troc_curve\x18\x0b \x01(\x0b\x32\x1a.visualdl.Record.ROC_CurveH\x00\x12%\n\x04text\x18\x0c \x01(\x0b\x32\x15.visualdl.Record.TextH\x00\x42\x0b\n\tone_valueb\x06proto3' + serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xca\x0b\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a#\n\x04Text\x12\x1b\n\x13\x65ncoded_text_string\x18\x01 \x01(\t\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x03(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1aP\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x12\x12\n\nlabel_meta\x18\x02 \x03(\t\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1al\n\x07PRCurve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x11\n\tprecision\x18\x05 \x03(\x01\x12\x0e\n\x06recall\x18\x06 \x03(\x01\x1a\x65\n\tROC_Curve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x0b\n\x03tpr\x18\x05 \x03(\x01\x12\x0b\n\x03\x66pr\x18\x06 \x03(\x01\x1a\xf0\x01\n\x06HParam\x12\x37\n\x0bhparamInfos\x18\x01 \x03(\x0b\x32\".visualdl.Record.HParam.HparamInfo\x12\x37\n\x0bmetricInfos\x18\x02 \x03(\x0b\x32\".visualdl.Record.HParam.HparamInfo\x12\x0c\n\x04name\x18\x03 \x01(\t\x1a\x66\n\nHparamInfo\x12\x13\n\tint_value\x18\x01 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x01H\x00\x12\x16\n\x0cstring_value\x18\x03 \x01(\tH\x00\x12\x0c\n\x04name\x18\x04 \x01(\tB\x06\n\x04type\x1a \n\x08MetaData\x12\x14\n\x0c\x64isplay_name\x18\x01 \x01(\t\x1a\xe8\x03\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x12,\n\x08pr_curve\x18\t \x01(\x0b\x32\x18.visualdl.Record.PRCurveH\x00\x12.\n\tmeta_data\x18\n \x01(\x0b\x32\x19.visualdl.Record.MetaDataH\x00\x12/\n\troc_curve\x18\x0b \x01(\x0b\x32\x1a.visualdl.Record.ROC_CurveH\x00\x12%\n\x04text\x18\x0c \x01(\x0b\x32\x15.visualdl.Record.TextH\x00\x12)\n\x06hparam\x18\r \x01(\x0b\x32\x17.visualdl.Record.HParamH\x00\x42\x0b\n\tone_valueb\x06proto3' ) @@ -420,6 +420,104 @@ _RECORD_ROC_CURVE = _descriptor.Descriptor( serialized_end=741, ) +_RECORD_HPARAM_HPARAMINFO = _descriptor.Descriptor( + name='HparamInfo', + full_name='visualdl.Record.HParam.HparamInfo', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='int_value', full_name='visualdl.Record.HParam.HparamInfo.int_value', index=0, + number=1, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='float_value', full_name='visualdl.Record.HParam.HparamInfo.float_value', index=1, + number=2, type=1, cpp_type=5, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='string_value', full_name='visualdl.Record.HParam.HparamInfo.string_value', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='name', full_name='visualdl.Record.HParam.HparamInfo.name', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='type', full_name='visualdl.Record.HParam.HparamInfo.type', + index=0, containing_type=None, fields=[]), + ], + serialized_start=882, + serialized_end=984, +) + +_RECORD_HPARAM = _descriptor.Descriptor( + name='HParam', + full_name='visualdl.Record.HParam', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='hparamInfos', full_name='visualdl.Record.HParam.hparamInfos', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='metricInfos', full_name='visualdl.Record.HParam.metricInfos', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='name', full_name='visualdl.Record.HParam.name', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_RECORD_HPARAM_HPARAMINFO, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=744, + serialized_end=984, +) + _RECORD_METADATA = _descriptor.Descriptor( name='MetaData', full_name='visualdl.Record.MetaData', @@ -446,8 +544,8 @@ _RECORD_METADATA = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=743, - serialized_end=775, + serialized_start=986, + serialized_end=1018, ) _RECORD_VALUE = _descriptor.Descriptor( @@ -541,6 +639,13 @@ _RECORD_VALUE = _descriptor.Descriptor( message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='hparam', full_name='visualdl.Record.Value.hparam', index=12, + number=13, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -556,8 +661,8 @@ _RECORD_VALUE = _descriptor.Descriptor( name='one_value', full_name='visualdl.Record.Value.one_value', index=0, containing_type=None, fields=[]), ], - serialized_start=778, - serialized_end=1223, + serialized_start=1021, + serialized_end=1509, ) _RECORD = _descriptor.Descriptor( @@ -577,7 +682,7 @@ _RECORD = _descriptor.Descriptor( ], extensions=[ ], - nested_types=[_RECORD_IMAGE, _RECORD_TEXT, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_HISTOGRAM, _RECORD_PRCURVE, _RECORD_ROC_CURVE, _RECORD_METADATA, _RECORD_VALUE, ], + nested_types=[_RECORD_IMAGE, _RECORD_TEXT, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_HISTOGRAM, _RECORD_PRCURVE, _RECORD_ROC_CURVE, _RECORD_HPARAM, _RECORD_METADATA, _RECORD_VALUE, ], enum_types=[ ], serialized_options=None, @@ -587,7 +692,7 @@ _RECORD = _descriptor.Descriptor( oneofs=[ ], serialized_start=27, - serialized_end=1223, + serialized_end=1509, ) _RECORD_IMAGE.containing_type = _RECORD @@ -600,6 +705,19 @@ _RECORD_BYTES_EMBEDDINGS.containing_type = _RECORD _RECORD_HISTOGRAM.containing_type = _RECORD _RECORD_PRCURVE.containing_type = _RECORD _RECORD_ROC_CURVE.containing_type = _RECORD +_RECORD_HPARAM_HPARAMINFO.containing_type = _RECORD_HPARAM +_RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'].fields.append( + _RECORD_HPARAM_HPARAMINFO.fields_by_name['int_value']) +_RECORD_HPARAM_HPARAMINFO.fields_by_name['int_value'].containing_oneof = _RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'] +_RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'].fields.append( + _RECORD_HPARAM_HPARAMINFO.fields_by_name['float_value']) +_RECORD_HPARAM_HPARAMINFO.fields_by_name['float_value'].containing_oneof = _RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'] +_RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'].fields.append( + _RECORD_HPARAM_HPARAMINFO.fields_by_name['string_value']) +_RECORD_HPARAM_HPARAMINFO.fields_by_name['string_value'].containing_oneof = _RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'] +_RECORD_HPARAM.fields_by_name['hparamInfos'].message_type = _RECORD_HPARAM_HPARAMINFO +_RECORD_HPARAM.fields_by_name['metricInfos'].message_type = _RECORD_HPARAM_HPARAMINFO +_RECORD_HPARAM.containing_type = _RECORD _RECORD_METADATA.containing_type = _RECORD _RECORD_VALUE.fields_by_name['image'].message_type = _RECORD_IMAGE _RECORD_VALUE.fields_by_name['audio'].message_type = _RECORD_AUDIO @@ -609,6 +727,7 @@ _RECORD_VALUE.fields_by_name['pr_curve'].message_type = _RECORD_PRCURVE _RECORD_VALUE.fields_by_name['meta_data'].message_type = _RECORD_METADATA _RECORD_VALUE.fields_by_name['roc_curve'].message_type = _RECORD_ROC_CURVE _RECORD_VALUE.fields_by_name['text'].message_type = _RECORD_TEXT +_RECORD_VALUE.fields_by_name['hparam'].message_type = _RECORD_HPARAM _RECORD_VALUE.containing_type = _RECORD _RECORD_VALUE.oneofs_by_name['one_value'].fields.append( _RECORD_VALUE.fields_by_name['value']) @@ -637,6 +756,9 @@ _RECORD_VALUE.fields_by_name['roc_curve'].containing_oneof = _RECORD_VALUE.oneof _RECORD_VALUE.oneofs_by_name['one_value'].fields.append( _RECORD_VALUE.fields_by_name['text']) _RECORD_VALUE.fields_by_name['text'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value'] +_RECORD_VALUE.oneofs_by_name['one_value'].fields.append( + _RECORD_VALUE.fields_by_name['hparam']) +_RECORD_VALUE.fields_by_name['hparam'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value'] _RECORD.fields_by_name['values'].message_type = _RECORD_VALUE DESCRIPTOR.message_types_by_name['Record'] = _RECORD _sym_db.RegisterFileDescriptor(DESCRIPTOR) @@ -706,6 +828,20 @@ Record = _reflection.GeneratedProtocolMessageType('Record', (_message.Message,), }) , + 'HParam' : _reflection.GeneratedProtocolMessageType('HParam', (_message.Message,), { + + 'HparamInfo' : _reflection.GeneratedProtocolMessageType('HparamInfo', (_message.Message,), { + 'DESCRIPTOR' : _RECORD_HPARAM_HPARAMINFO, + '__module__' : 'record_pb2' + # @@protoc_insertion_point(class_scope:visualdl.Record.HParam.HparamInfo) + }) + , + 'DESCRIPTOR' : _RECORD_HPARAM, + '__module__' : 'record_pb2' + # @@protoc_insertion_point(class_scope:visualdl.Record.HParam) + }) + , + 'MetaData' : _reflection.GeneratedProtocolMessageType('MetaData', (_message.Message,), { 'DESCRIPTOR' : _RECORD_METADATA, '__module__' : 'record_pb2' @@ -733,6 +869,8 @@ _sym_db.RegisterMessage(Record.bytes_embeddings) _sym_db.RegisterMessage(Record.Histogram) _sym_db.RegisterMessage(Record.PRCurve) _sym_db.RegisterMessage(Record.ROC_Curve) +_sym_db.RegisterMessage(Record.HParam) +_sym_db.RegisterMessage(Record.HParam.HparamInfo) _sym_db.RegisterMessage(Record.MetaData) _sym_db.RegisterMessage(Record.Value) diff --git a/visualdl/reader/reader.py b/visualdl/reader/reader.py index e3bef98b9bb2402dd67cf5f2bcc9ec972dccdb52..5dfcd9f917461c9fb5ee40c67c7b87990c2be08f 100644 --- a/visualdl/reader/reader.py +++ b/visualdl/reader/reader.py @@ -182,6 +182,8 @@ class LogReader(object): component = "meta_data" elif "text" == value_type: component = "text" + elif "hparam" == value_type: + component = "hyper_parameters" else: raise TypeError("Invalid value type `%s`." % value_type) self._tags[path] = component diff --git a/visualdl/server/api.py b/visualdl/server/api.py index 89a388201c96487b6ae6558bca7df928db123dcc..7f9c72199be2bd8d7f5270bef9237a367bc9bfda 100644 --- a/visualdl/server/api.py +++ b/visualdl/server/api.py @@ -121,6 +121,28 @@ class Api(object): def roc_curve_tags(self): return self._get_with_retry('data/plugin/roc_curves/tags', lib.get_roc_curve_tags) + @result() + def hparam_importance(self): + return self._get_with_retry('data/plugin/hparams/importance', lib.get_hparam_importance) + + @result() + def hparam_indicator(self): + return self._get_with_retry('data/plugin/hparams/indicators', lib.get_hparam_indicator) + + @result() + def hparam_list(self): + return self._get_with_retry('data/plugin/hparams/list', lib.get_hparam_list) + + @result() + def hparam_metric(self, run, metric): + key = os.path.join('data/plugin/hparams/metric', run, metric) + return self._get_with_retry(key, lib.get_hparam_metric, run, metric) + + @result('text/csv') + def hparam_data(self, type='tsv'): + key = os.path.join('data/plugin/hparams/data', type) + return self._get_with_retry(key, lib.get_hparam_data, type) + @result() def scalar_list(self, run, tag): key = os.path.join('data/plugin/scalars/scalars', run, tag) @@ -254,7 +276,12 @@ def create_api_call(logdir, model, cache_timeout): 'pr-curve/list': (api.pr_curves_pr_curve, ['run', 'tag']), 'roc-curve/list': (api.roc_curves_roc_curve, ['run', 'tag']), 'pr-curve/steps': (api.pr_curves_steps, ['run']), - 'roc-curve/steps': (api.roc_curves_steps, ['run']) + 'roc-curve/steps': (api.roc_curves_steps, ['run']), + 'hparams/importance': (api.hparam_importance, []), + 'hparams/data': (api.hparam_data, ['type']), + 'hparams/indicators': (api.hparam_indicator, []), + 'hparams/list': (api.hparam_list, []), + 'hparams/metric': (api.hparam_metric, ['run', 'metric']) } def call(path: str, args): diff --git a/visualdl/server/data_manager.py b/visualdl/server/data_manager.py index 8fe27ccb759a69b485e8555cc458a78c35a7c7ae..d4272cc77dedf3dfc2280b8dae13b1130afe5725 100644 --- a/visualdl/server/data_manager.py +++ b/visualdl/server/data_manager.py @@ -26,7 +26,8 @@ DEFAULT_PLUGIN_MAXSIZE = { "pr_curve": 300, "roc_curve": 300, "meta_data": 100, - "text": 10 + "text": 10, + "hyper_parameters": 10000 } @@ -353,7 +354,9 @@ class DataManager(object): "meta_data": Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["meta_data"]), "text": - Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["text"]) + Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["text"]), + "hyper_parameters": + Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["hyper_parameters"]) } self._mutex = threading.Lock() diff --git a/visualdl/server/lib.py b/visualdl/server/lib.py index c7266f3cddcdd70baf82a9fbbb8bcc03e9b23211..b6791853ac9d3701fa77d956f83daff4501669f8 100644 --- a/visualdl/server/lib.py +++ b/visualdl/server/lib.py @@ -24,6 +24,8 @@ import numpy as np from visualdl.server.log import logger from visualdl.io import bfile from visualdl.utils.string_util import encode_tag, decode_tag +from visualdl.utils.importance import calc_all_hyper_param_importance +from visualdl.utils.list_util import duplicate_removal from visualdl.component import components @@ -115,6 +117,172 @@ for name in components.keys(): exec("get_%s_tags=partial(get_logs, component='%s')" % (name, name)) +def get_hparam_data(log_reader, type='tsv'): + result = get_hparam_list(log_reader) + delimeter = '\t' if 'tsv' == type else ',' + header = ['Trial ID'] + hparams_header = [] + metrics_header = [] + for item in result: + hparams_header += item['hparams'].keys() + metrics_header += item['metrics'].keys() + name_set = set() + h_header = [] + for hparam in hparams_header: + if hparam in name_set: + continue + name_set.add(hparam) + h_header.append(hparam) + name_set = set() + m_header = [] + for metric in metrics_header: + if metric in name_set: + continue + name_set.add(metric) + m_header.append(metric) + trans_result = [] + for item in result: + temp = {'Trial ID': item.get('name', '')} + temp.update(item.get('hparams', {})) + temp.update(item.get('metrics', {})) + trans_result.append(temp) + header = header + h_header + m_header + with io.StringIO() as fp: + csv_writer = csv.writer(fp, delimiter=delimeter) + csv_writer.writerow(header) + for item in trans_result: + row = [] + for col_name in header: + row.append(item.get(col_name, '')) + csv_writer.writerow(row) + result = fp.getvalue() + return result + + +def get_hparam_importance(log_reader): + indicator = get_hparam_indicator(log_reader) + hparams = [item for item in indicator['hparams'] if (item['type'] != 'string')] + metrics = [item for item in indicator['metrics'] if (item['type'] != 'string')] + + result = calc_all_hyper_param_importance(hparams, metrics) + + return result + + +# flake8: noqa: C901 +def get_hparam_indicator(log_reader): + run2tag = get_logs(log_reader, 'hyper_parameters') + runs = run2tag['runs'] + hparams = {} + metrics = {} + records_list = [] + for run in runs: + run = log_reader.name2tags[run] if run in log_reader.name2tags else run + log_reader.load_new_data() + records = log_reader.data_manager.get_reservoir("hyper_parameters").get_items( + run, decode_tag('hparam')) + records_list.append([records, run]) + records_list.sort(key=lambda x: x[0][0].timestamp) + runs = [run for r, run in records_list] + for records, run in records_list: + for hparamInfo in records[0].hparam.hparamInfos: + type = hparamInfo.WhichOneof("type") + if "float_value" == type: + if hparamInfo.name not in hparams.keys(): + hparams[hparamInfo.name] = {'name': hparamInfo.name, + 'type': 'continuous', + 'values': [hparamInfo.float_value]} + elif hparamInfo.float_value not in hparams[hparamInfo.name]['values']: + hparams[hparamInfo.name]['values'].append(hparamInfo.float_value) + elif "string_value" == type: + if hparamInfo.name not in hparams.keys(): + hparams[hparamInfo.name] = {'name': hparamInfo.name, + 'type': 'string', + 'values': [hparamInfo.string_value]} + elif hparamInfo.string_value not in hparams[hparamInfo.name]['values']: + hparams[hparamInfo.name]['values'].append(hparamInfo.string_value) + elif "int_value" == type: + if hparamInfo.name not in hparams.keys(): + hparams[hparamInfo.name] = {'name': hparamInfo.name, + 'type': 'numeric', + 'values': [hparamInfo.int_value]} + elif hparamInfo.int_value not in hparams[hparamInfo.name]['values']: + hparams[hparamInfo.name]['values'].append(hparamInfo.int_value) + else: + raise TypeError("Invalid hparams param value type `%s`." % type) + + for metricInfo in records[0].hparam.metricInfos: + metrics[metricInfo.name] = {'name': metricInfo.name, + 'type': 'continuous', + 'values': []} + for run in runs: + try: + metrics_data = get_hparam_metric(log_reader, run, metricInfo.name) + metrics[metricInfo.name]['values'].append(metrics_data[-1][-1]) + break + except: + logger.error('Missing data of metrics! Please make sure use add_scalar to log metrics data.') + if len(metrics[metricInfo.name]['values']) == 0: + metrics.pop(metricInfo.name) + else: + metrics[metricInfo.name].pop('values') + + results = {'hparams': [value for key, value in hparams.items()], + 'metrics': [value for key, value in metrics.items()]} + + return results + + +def get_hparam_metric(log_reader, run, tag): + run = log_reader.name2tags[run] if run in log_reader.name2tags else run + log_reader.load_new_data() + records = log_reader.data_manager.get_reservoir("scalar").get_items( + run, decode_tag(tag)) + results = [[s2ms(item.timestamp), item.id, item.value] for item in records] + return results + + +def get_hparam_list(log_reader): + run2tag = get_logs(log_reader, 'hyper_parameters') + runs = run2tag['runs'] + results = [] + + records_list = [] + for run in runs: + run = log_reader.name2tags[run] if run in log_reader.name2tags else run + log_reader.load_new_data() + records = log_reader.data_manager.get_reservoir("hyper_parameters").get_items( + run, decode_tag('hparam')) + records_list.append([records, run]) + records_list.sort(key=lambda x: x[0][0].timestamp) + for records, run in records_list: + hparams = {} + for hparamInfo in records[0].hparam.hparamInfos: + hparam_type = hparamInfo.WhichOneof("type") + if "float_value" == hparam_type: + hparams[hparamInfo.name] = hparamInfo.float_value + elif "string_value" == hparam_type: + hparams[hparamInfo.name] = hparamInfo.string_value + elif "int_value" == hparam_type: + hparams[hparamInfo.name] = hparamInfo.int_value + else: + raise TypeError("Invalid hparams param value type `%s`." % hparam_type) + + metrics = {} + for metricInfo in records[0].hparam.metricInfos: + try: + metrics_data = get_hparam_metric(log_reader, run, metricInfo.name) + metrics[metricInfo.name] = metrics_data[-1][-1] + except: + logger.error('Missing data of metrics! Please make sure use add_scalar to log metrics data.') + metrics[metricInfo.name] = None + + results.append({'name': run, + 'hparams': hparams, + 'metrics': metrics}) + return results + + def get_scalar(log_reader, run, tag): run = log_reader.name2tags[run] if run in log_reader.name2tags else run log_reader.load_new_data() diff --git a/visualdl/utils/importance.py b/visualdl/utils/importance.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e911feb6cc3231d58b3b3c2a197851b8fc8bed --- /dev/null +++ b/visualdl/utils/importance.py @@ -0,0 +1,55 @@ +# Copyright (c) 2020 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= +from functools import reduce + +import numpy as np +import pandas as pd + +from visualdl.server.log import logger + + +def calc_hyper_param_importance(df, hyper_param, target): + new_df = df[[hyper_param, target]] + no_missing_value_df = new_df.dropna() + + # Can not calc pearson correlation coefficient when number of samples is less or equal than 2 + if len(no_missing_value_df) <= 2: + logger.error("Number of samples is less or equal than 2.") + return 0 + + correlation = no_missing_value_df[target].corr(no_missing_value_df[hyper_param]) + if np.isnan(correlation): + logger.warning("Correlation is nan!") + return 0 + + return abs(correlation) + + +def calc_all_hyper_param_importance(hparams, metrics): + results = {} + for metric in metrics: + for hparam in hparams: + flattened_lineage = {hparam['name']: hparam['values'], metric['name']: metric['values']} + result = calc_hyper_param_importance(pd.DataFrame(flattened_lineage), hparam['name'], metric['name']) + # print('%s - %s : result=' % (hparam, metric), result) + if hparam['name'] not in results.keys(): + results[hparam['name']] = result + else: + results[hparam['name']] += result + sum_score = reduce(lambda x, y: x+y, results.values()) + for key, value in results.items(): + results[key] = value/sum_score + result = [{'name': key, 'value': value} for key, value in results.items()] + return result diff --git a/visualdl/utils/list_util.py b/visualdl/utils/list_util.py new file mode 100644 index 0000000000000000000000000000000000000000..cff1af876df674f8073d4b8245b11a1337ff722b --- /dev/null +++ b/visualdl/utils/list_util.py @@ -0,0 +1,25 @@ +# Copyright (c) 2021 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= + + +def duplicate_removal(src_list): + name_scope = set() + dest_list = [] + for item in src_list: + if item in name_scope: + continue + name_scope.add(item) + dest_list.append(item) + return dest_list diff --git a/visualdl/writer/writer.py b/visualdl/writer/writer.py index 214708fd8e017e736a2d8990f479b1dc1706f5da..ccc79556604277facebcb79de10125e614dcb018 100644 --- a/visualdl/writer/writer.py +++ b/visualdl/writer/writer.py @@ -20,8 +20,9 @@ from visualdl.writer.record_writer import RecordFileWriter from visualdl.server.log import logger from visualdl.utils.img_util import merge_images from visualdl.utils.figure_util import figure_to_image +from visualdl.utils.md5_util import md5 from visualdl.component.base_component import scalar, image, embedding, audio, \ - histogram, pr_curve, roc_curve, meta_data, text + histogram, pr_curve, roc_curve, meta_data, text, hparam class DummyFileWriter(object): @@ -441,6 +442,45 @@ class LogWriter(object): step=step, walltime=walltime)) + def add_hparams(self, hparam_dict, metric_list, walltime=None): + """Add an histogram to vdl record file. + + Args: + hparam_dict (dictionary): Each key-value pair in the dictionary is the + name of the hyper parameter and it's corresponding value. The type of the value + can be one of `bool`, `string`, `float`, `int`, or `None`. + metric_list (list): Name of all metrics. + walltime (int): Wall time of hparams. + + Examples:: + from visualdl import LogWriter + + # Remember use add_scalar to log your metrics data! + with LogWriter('./log/hparams_test/train/run1') as writer: + writer.add_hparams({'lr': 0.1, 'bsize': 1, 'opt': 'sgd'}, ['hparam/accuracy', 'hparam/loss']) + for i in range(10): + writer.add_scalar('hparam/accuracy', i, i) + writer.add_scalar('hparam/loss', 2*i, i) + + with LogWriter('./log/hparams_test/train/run2') as writer: + writer.add_hparams({'lr': 0.2, 'bsize': 2, 'opt': 'relu'}, ['hparam/accuracy', 'hparam/loss']) + for i in range(10): + writer.add_scalar('hparam/accuracy', 1.0/(i+1), i) + writer.add_scalar('hparam/loss', 5*i, i) + """ + if type(hparam_dict) is not dict: + raise TypeError('hparam_dict should be dictionary.') + if type(metric_list) is not list: + raise TypeError('metric_list should be list.') + walltime = round(time.time() * 1000) if walltime is None else walltime + + self._get_file_writer().add_record( + hparam( + name=md5(self.file_name), + hparam_dict=hparam_dict, + metric_list=metric_list, + walltime=walltime)) + def add_pr_curve(self, tag, labels,