未验证 提交 d946813b 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

Add hyper parameter api.

上级 92eb7c1d
......@@ -293,6 +293,51 @@ def histogram(tag, hist, bin_edges, step, walltime):
])
def hparam(name, hparam_dict, metric_list, walltime):
"""Package data to one histogram.
Args:
name (str): Name of hparam.
hparam_dict (dictionary): Each key-value pair in the dictionary is the
name of the hyper parameter and it's corresponding value. The type of the value
can be one of `bool`, `string`, `float`, `int`, or `None`.
metric_list (list): Name of all metrics.
walltime (int): Wall time of hparam.
Return:
Package with format of record_pb2.Record
"""
hm = Record.HParam()
hm.name = name
for k, v in hparam_dict.items():
if v is None:
continue
hparamInfo = Record.HParam.HparamInfo()
hparamInfo.name = k
if isinstance(v, int):
hparamInfo.int_value = v
hm.hparamInfos.append(hparamInfo)
elif isinstance(v, float):
hparamInfo.float_value = v
hm.hparamInfos.append(hparamInfo)
elif isinstance(v, str):
hparamInfo.string_value = v
hm.hparamInfos.append(hparamInfo)
else:
print("The value of %s must be int, float or str, not %s" % (k, str(type(v))))
for metric in metric_list:
metricInfo = Record.HParam.HparamInfo()
metricInfo.name = metric
metricInfo.float_value = 0
hm.metricInfos.append(metricInfo)
return Record(values=[
Record.Value(
id=1, tag="hparam", timestamp=walltime, hparam=hm)
])
def compute_curve(labels, predictions, num_thresholds=None, weights=None):
""" Compute precision-recall curve data by labels and predictions.
......
......@@ -57,6 +57,20 @@ message Record {
repeated double fpr = 6;
}
message HParam {
message HparamInfo {
oneof type {
int64 int_value = 1;
double float_value = 2;
string string_value = 3;
};
string name = 4;
}
repeated HparamInfo hparamInfos = 1;
repeated HparamInfo metricInfos = 2;
string name = 3;
}
message MetaData {
string display_name = 1;
}
......@@ -75,6 +89,7 @@ message Record {
MetaData meta_data = 10;
ROC_Curve roc_curve = 11;
Text text = 12;
HParam hparam = 13;
}
}
......
......@@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='visualdl',
syntax='proto3',
serialized_options=None,
serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xac\t\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a#\n\x04Text\x12\x1b\n\x13\x65ncoded_text_string\x18\x01 \x01(\t\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x03(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1aP\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x12\x12\n\nlabel_meta\x18\x02 \x03(\t\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1al\n\x07PRCurve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x11\n\tprecision\x18\x05 \x03(\x01\x12\x0e\n\x06recall\x18\x06 \x03(\x01\x1a\x65\n\tROC_Curve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x0b\n\x03tpr\x18\x05 \x03(\x01\x12\x0b\n\x03\x66pr\x18\x06 \x03(\x01\x1a \n\x08MetaData\x12\x14\n\x0c\x64isplay_name\x18\x01 \x01(\t\x1a\xbd\x03\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x12,\n\x08pr_curve\x18\t \x01(\x0b\x32\x18.visualdl.Record.PRCurveH\x00\x12.\n\tmeta_data\x18\n \x01(\x0b\x32\x19.visualdl.Record.MetaDataH\x00\x12/\n\troc_curve\x18\x0b \x01(\x0b\x32\x1a.visualdl.Record.ROC_CurveH\x00\x12%\n\x04text\x18\x0c \x01(\x0b\x32\x15.visualdl.Record.TextH\x00\x42\x0b\n\tone_valueb\x06proto3'
serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xca\x0b\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a#\n\x04Text\x12\x1b\n\x13\x65ncoded_text_string\x18\x01 \x01(\t\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x03(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1aP\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x12\x12\n\nlabel_meta\x18\x02 \x03(\t\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1al\n\x07PRCurve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x11\n\tprecision\x18\x05 \x03(\x01\x12\x0e\n\x06recall\x18\x06 \x03(\x01\x1a\x65\n\tROC_Curve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x0b\n\x03tpr\x18\x05 \x03(\x01\x12\x0b\n\x03\x66pr\x18\x06 \x03(\x01\x1a\xf0\x01\n\x06HParam\x12\x37\n\x0bhparamInfos\x18\x01 \x03(\x0b\x32\".visualdl.Record.HParam.HparamInfo\x12\x37\n\x0bmetricInfos\x18\x02 \x03(\x0b\x32\".visualdl.Record.HParam.HparamInfo\x12\x0c\n\x04name\x18\x03 \x01(\t\x1a\x66\n\nHparamInfo\x12\x13\n\tint_value\x18\x01 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x01H\x00\x12\x16\n\x0cstring_value\x18\x03 \x01(\tH\x00\x12\x0c\n\x04name\x18\x04 \x01(\tB\x06\n\x04type\x1a \n\x08MetaData\x12\x14\n\x0c\x64isplay_name\x18\x01 \x01(\t\x1a\xe8\x03\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x12,\n\x08pr_curve\x18\t \x01(\x0b\x32\x18.visualdl.Record.PRCurveH\x00\x12.\n\tmeta_data\x18\n \x01(\x0b\x32\x19.visualdl.Record.MetaDataH\x00\x12/\n\troc_curve\x18\x0b \x01(\x0b\x32\x1a.visualdl.Record.ROC_CurveH\x00\x12%\n\x04text\x18\x0c \x01(\x0b\x32\x15.visualdl.Record.TextH\x00\x12)\n\x06hparam\x18\r \x01(\x0b\x32\x17.visualdl.Record.HParamH\x00\x42\x0b\n\tone_valueb\x06proto3'
)
......@@ -420,6 +420,104 @@ _RECORD_ROC_CURVE = _descriptor.Descriptor(
serialized_end=741,
)
_RECORD_HPARAM_HPARAMINFO = _descriptor.Descriptor(
name='HparamInfo',
full_name='visualdl.Record.HParam.HparamInfo',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='int_value', full_name='visualdl.Record.HParam.HparamInfo.int_value', index=0,
number=1, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='float_value', full_name='visualdl.Record.HParam.HparamInfo.float_value', index=1,
number=2, type=1, cpp_type=5, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='string_value', full_name='visualdl.Record.HParam.HparamInfo.string_value', index=2,
number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='name', full_name='visualdl.Record.HParam.HparamInfo.name', index=3,
number=4, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
_descriptor.OneofDescriptor(
name='type', full_name='visualdl.Record.HParam.HparamInfo.type',
index=0, containing_type=None, fields=[]),
],
serialized_start=882,
serialized_end=984,
)
_RECORD_HPARAM = _descriptor.Descriptor(
name='HParam',
full_name='visualdl.Record.HParam',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='hparamInfos', full_name='visualdl.Record.HParam.hparamInfos', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='metricInfos', full_name='visualdl.Record.HParam.metricInfos', index=1,
number=2, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='name', full_name='visualdl.Record.HParam.name', index=2,
number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[_RECORD_HPARAM_HPARAMINFO, ],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=744,
serialized_end=984,
)
_RECORD_METADATA = _descriptor.Descriptor(
name='MetaData',
full_name='visualdl.Record.MetaData',
......@@ -446,8 +544,8 @@ _RECORD_METADATA = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=743,
serialized_end=775,
serialized_start=986,
serialized_end=1018,
)
_RECORD_VALUE = _descriptor.Descriptor(
......@@ -541,6 +639,13 @@ _RECORD_VALUE = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='hparam', full_name='visualdl.Record.Value.hparam', index=12,
number=13, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
......@@ -556,8 +661,8 @@ _RECORD_VALUE = _descriptor.Descriptor(
name='one_value', full_name='visualdl.Record.Value.one_value',
index=0, containing_type=None, fields=[]),
],
serialized_start=778,
serialized_end=1223,
serialized_start=1021,
serialized_end=1509,
)
_RECORD = _descriptor.Descriptor(
......@@ -577,7 +682,7 @@ _RECORD = _descriptor.Descriptor(
],
extensions=[
],
nested_types=[_RECORD_IMAGE, _RECORD_TEXT, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_HISTOGRAM, _RECORD_PRCURVE, _RECORD_ROC_CURVE, _RECORD_METADATA, _RECORD_VALUE, ],
nested_types=[_RECORD_IMAGE, _RECORD_TEXT, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_HISTOGRAM, _RECORD_PRCURVE, _RECORD_ROC_CURVE, _RECORD_HPARAM, _RECORD_METADATA, _RECORD_VALUE, ],
enum_types=[
],
serialized_options=None,
......@@ -587,7 +692,7 @@ _RECORD = _descriptor.Descriptor(
oneofs=[
],
serialized_start=27,
serialized_end=1223,
serialized_end=1509,
)
_RECORD_IMAGE.containing_type = _RECORD
......@@ -600,6 +705,19 @@ _RECORD_BYTES_EMBEDDINGS.containing_type = _RECORD
_RECORD_HISTOGRAM.containing_type = _RECORD
_RECORD_PRCURVE.containing_type = _RECORD
_RECORD_ROC_CURVE.containing_type = _RECORD
_RECORD_HPARAM_HPARAMINFO.containing_type = _RECORD_HPARAM
_RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'].fields.append(
_RECORD_HPARAM_HPARAMINFO.fields_by_name['int_value'])
_RECORD_HPARAM_HPARAMINFO.fields_by_name['int_value'].containing_oneof = _RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type']
_RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'].fields.append(
_RECORD_HPARAM_HPARAMINFO.fields_by_name['float_value'])
_RECORD_HPARAM_HPARAMINFO.fields_by_name['float_value'].containing_oneof = _RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type']
_RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'].fields.append(
_RECORD_HPARAM_HPARAMINFO.fields_by_name['string_value'])
_RECORD_HPARAM_HPARAMINFO.fields_by_name['string_value'].containing_oneof = _RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type']
_RECORD_HPARAM.fields_by_name['hparamInfos'].message_type = _RECORD_HPARAM_HPARAMINFO
_RECORD_HPARAM.fields_by_name['metricInfos'].message_type = _RECORD_HPARAM_HPARAMINFO
_RECORD_HPARAM.containing_type = _RECORD
_RECORD_METADATA.containing_type = _RECORD
_RECORD_VALUE.fields_by_name['image'].message_type = _RECORD_IMAGE
_RECORD_VALUE.fields_by_name['audio'].message_type = _RECORD_AUDIO
......@@ -609,6 +727,7 @@ _RECORD_VALUE.fields_by_name['pr_curve'].message_type = _RECORD_PRCURVE
_RECORD_VALUE.fields_by_name['meta_data'].message_type = _RECORD_METADATA
_RECORD_VALUE.fields_by_name['roc_curve'].message_type = _RECORD_ROC_CURVE
_RECORD_VALUE.fields_by_name['text'].message_type = _RECORD_TEXT
_RECORD_VALUE.fields_by_name['hparam'].message_type = _RECORD_HPARAM
_RECORD_VALUE.containing_type = _RECORD
_RECORD_VALUE.oneofs_by_name['one_value'].fields.append(
_RECORD_VALUE.fields_by_name['value'])
......@@ -637,6 +756,9 @@ _RECORD_VALUE.fields_by_name['roc_curve'].containing_oneof = _RECORD_VALUE.oneof
_RECORD_VALUE.oneofs_by_name['one_value'].fields.append(
_RECORD_VALUE.fields_by_name['text'])
_RECORD_VALUE.fields_by_name['text'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value']
_RECORD_VALUE.oneofs_by_name['one_value'].fields.append(
_RECORD_VALUE.fields_by_name['hparam'])
_RECORD_VALUE.fields_by_name['hparam'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value']
_RECORD.fields_by_name['values'].message_type = _RECORD_VALUE
DESCRIPTOR.message_types_by_name['Record'] = _RECORD
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
......@@ -706,6 +828,20 @@ Record = _reflection.GeneratedProtocolMessageType('Record', (_message.Message,),
})
,
'HParam' : _reflection.GeneratedProtocolMessageType('HParam', (_message.Message,), {
'HparamInfo' : _reflection.GeneratedProtocolMessageType('HparamInfo', (_message.Message,), {
'DESCRIPTOR' : _RECORD_HPARAM_HPARAMINFO,
'__module__' : 'record_pb2'
# @@protoc_insertion_point(class_scope:visualdl.Record.HParam.HparamInfo)
})
,
'DESCRIPTOR' : _RECORD_HPARAM,
'__module__' : 'record_pb2'
# @@protoc_insertion_point(class_scope:visualdl.Record.HParam)
})
,
'MetaData' : _reflection.GeneratedProtocolMessageType('MetaData', (_message.Message,), {
'DESCRIPTOR' : _RECORD_METADATA,
'__module__' : 'record_pb2'
......@@ -733,6 +869,8 @@ _sym_db.RegisterMessage(Record.bytes_embeddings)
_sym_db.RegisterMessage(Record.Histogram)
_sym_db.RegisterMessage(Record.PRCurve)
_sym_db.RegisterMessage(Record.ROC_Curve)
_sym_db.RegisterMessage(Record.HParam)
_sym_db.RegisterMessage(Record.HParam.HparamInfo)
_sym_db.RegisterMessage(Record.MetaData)
_sym_db.RegisterMessage(Record.Value)
......
......@@ -182,6 +182,8 @@ class LogReader(object):
component = "meta_data"
elif "text" == value_type:
component = "text"
elif "hparam" == value_type:
component = "hyper_parameters"
else:
raise TypeError("Invalid value type `%s`." % value_type)
self._tags[path] = component
......
......@@ -121,6 +121,28 @@ class Api(object):
def roc_curve_tags(self):
return self._get_with_retry('data/plugin/roc_curves/tags', lib.get_roc_curve_tags)
@result()
def hparam_importance(self):
return self._get_with_retry('data/plugin/hparams/importance', lib.get_hparam_importance)
@result()
def hparam_indicator(self):
return self._get_with_retry('data/plugin/hparams/indicators', lib.get_hparam_indicator)
@result()
def hparam_list(self):
return self._get_with_retry('data/plugin/hparams/list', lib.get_hparam_list)
@result()
def hparam_metric(self, run, metric):
key = os.path.join('data/plugin/hparams/metric', run, metric)
return self._get_with_retry(key, lib.get_hparam_metric, run, metric)
@result('text/csv')
def hparam_data(self, type='tsv'):
key = os.path.join('data/plugin/hparams/data', type)
return self._get_with_retry(key, lib.get_hparam_data, type)
@result()
def scalar_list(self, run, tag):
key = os.path.join('data/plugin/scalars/scalars', run, tag)
......@@ -254,7 +276,12 @@ def create_api_call(logdir, model, cache_timeout):
'pr-curve/list': (api.pr_curves_pr_curve, ['run', 'tag']),
'roc-curve/list': (api.roc_curves_roc_curve, ['run', 'tag']),
'pr-curve/steps': (api.pr_curves_steps, ['run']),
'roc-curve/steps': (api.roc_curves_steps, ['run'])
'roc-curve/steps': (api.roc_curves_steps, ['run']),
'hparams/importance': (api.hparam_importance, []),
'hparams/data': (api.hparam_data, ['type']),
'hparams/indicators': (api.hparam_indicator, []),
'hparams/list': (api.hparam_list, []),
'hparams/metric': (api.hparam_metric, ['run', 'metric'])
}
def call(path: str, args):
......
......@@ -26,7 +26,8 @@ DEFAULT_PLUGIN_MAXSIZE = {
"pr_curve": 300,
"roc_curve": 300,
"meta_data": 100,
"text": 10
"text": 10,
"hyper_parameters": 10000
}
......@@ -353,7 +354,9 @@ class DataManager(object):
"meta_data":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["meta_data"]),
"text":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["text"])
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["text"]),
"hyper_parameters":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["hyper_parameters"])
}
self._mutex = threading.Lock()
......
......@@ -24,6 +24,8 @@ import numpy as np
from visualdl.server.log import logger
from visualdl.io import bfile
from visualdl.utils.string_util import encode_tag, decode_tag
from visualdl.utils.importance import calc_all_hyper_param_importance
from visualdl.utils.list_util import duplicate_removal
from visualdl.component import components
......@@ -115,6 +117,172 @@ for name in components.keys():
exec("get_%s_tags=partial(get_logs, component='%s')" % (name, name))
def get_hparam_data(log_reader, type='tsv'):
result = get_hparam_list(log_reader)
delimeter = '\t' if 'tsv' == type else ','
header = ['Trial ID']
hparams_header = []
metrics_header = []
for item in result:
hparams_header += item['hparams'].keys()
metrics_header += item['metrics'].keys()
name_set = set()
h_header = []
for hparam in hparams_header:
if hparam in name_set:
continue
name_set.add(hparam)
h_header.append(hparam)
name_set = set()
m_header = []
for metric in metrics_header:
if metric in name_set:
continue
name_set.add(metric)
m_header.append(metric)
trans_result = []
for item in result:
temp = {'Trial ID': item.get('name', '')}
temp.update(item.get('hparams', {}))
temp.update(item.get('metrics', {}))
trans_result.append(temp)
header = header + h_header + m_header
with io.StringIO() as fp:
csv_writer = csv.writer(fp, delimiter=delimeter)
csv_writer.writerow(header)
for item in trans_result:
row = []
for col_name in header:
row.append(item.get(col_name, ''))
csv_writer.writerow(row)
result = fp.getvalue()
return result
def get_hparam_importance(log_reader):
indicator = get_hparam_indicator(log_reader)
hparams = [item for item in indicator['hparams'] if (item['type'] != 'string')]
metrics = [item for item in indicator['metrics'] if (item['type'] != 'string')]
result = calc_all_hyper_param_importance(hparams, metrics)
return result
# flake8: noqa: C901
def get_hparam_indicator(log_reader):
run2tag = get_logs(log_reader, 'hyper_parameters')
runs = run2tag['runs']
hparams = {}
metrics = {}
records_list = []
for run in runs:
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("hyper_parameters").get_items(
run, decode_tag('hparam'))
records_list.append([records, run])
records_list.sort(key=lambda x: x[0][0].timestamp)
runs = [run for r, run in records_list]
for records, run in records_list:
for hparamInfo in records[0].hparam.hparamInfos:
type = hparamInfo.WhichOneof("type")
if "float_value" == type:
if hparamInfo.name not in hparams.keys():
hparams[hparamInfo.name] = {'name': hparamInfo.name,
'type': 'continuous',
'values': [hparamInfo.float_value]}
elif hparamInfo.float_value not in hparams[hparamInfo.name]['values']:
hparams[hparamInfo.name]['values'].append(hparamInfo.float_value)
elif "string_value" == type:
if hparamInfo.name not in hparams.keys():
hparams[hparamInfo.name] = {'name': hparamInfo.name,
'type': 'string',
'values': [hparamInfo.string_value]}
elif hparamInfo.string_value not in hparams[hparamInfo.name]['values']:
hparams[hparamInfo.name]['values'].append(hparamInfo.string_value)
elif "int_value" == type:
if hparamInfo.name not in hparams.keys():
hparams[hparamInfo.name] = {'name': hparamInfo.name,
'type': 'numeric',
'values': [hparamInfo.int_value]}
elif hparamInfo.int_value not in hparams[hparamInfo.name]['values']:
hparams[hparamInfo.name]['values'].append(hparamInfo.int_value)
else:
raise TypeError("Invalid hparams param value type `%s`." % type)
for metricInfo in records[0].hparam.metricInfos:
metrics[metricInfo.name] = {'name': metricInfo.name,
'type': 'continuous',
'values': []}
for run in runs:
try:
metrics_data = get_hparam_metric(log_reader, run, metricInfo.name)
metrics[metricInfo.name]['values'].append(metrics_data[-1][-1])
break
except:
logger.error('Missing data of metrics! Please make sure use add_scalar to log metrics data.')
if len(metrics[metricInfo.name]['values']) == 0:
metrics.pop(metricInfo.name)
else:
metrics[metricInfo.name].pop('values')
results = {'hparams': [value for key, value in hparams.items()],
'metrics': [value for key, value in metrics.items()]}
return results
def get_hparam_metric(log_reader, run, tag):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("scalar").get_items(
run, decode_tag(tag))
results = [[s2ms(item.timestamp), item.id, item.value] for item in records]
return results
def get_hparam_list(log_reader):
run2tag = get_logs(log_reader, 'hyper_parameters')
runs = run2tag['runs']
results = []
records_list = []
for run in runs:
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("hyper_parameters").get_items(
run, decode_tag('hparam'))
records_list.append([records, run])
records_list.sort(key=lambda x: x[0][0].timestamp)
for records, run in records_list:
hparams = {}
for hparamInfo in records[0].hparam.hparamInfos:
hparam_type = hparamInfo.WhichOneof("type")
if "float_value" == hparam_type:
hparams[hparamInfo.name] = hparamInfo.float_value
elif "string_value" == hparam_type:
hparams[hparamInfo.name] = hparamInfo.string_value
elif "int_value" == hparam_type:
hparams[hparamInfo.name] = hparamInfo.int_value
else:
raise TypeError("Invalid hparams param value type `%s`." % hparam_type)
metrics = {}
for metricInfo in records[0].hparam.metricInfos:
try:
metrics_data = get_hparam_metric(log_reader, run, metricInfo.name)
metrics[metricInfo.name] = metrics_data[-1][-1]
except:
logger.error('Missing data of metrics! Please make sure use add_scalar to log metrics data.')
metrics[metricInfo.name] = None
results.append({'name': run,
'hparams': hparams,
'metrics': metrics})
return results
def get_scalar(log_reader, run, tag):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data()
......
# Copyright (c) 2020 VisualDL Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
from functools import reduce
import numpy as np
import pandas as pd
from visualdl.server.log import logger
def calc_hyper_param_importance(df, hyper_param, target):
new_df = df[[hyper_param, target]]
no_missing_value_df = new_df.dropna()
# Can not calc pearson correlation coefficient when number of samples is less or equal than 2
if len(no_missing_value_df) <= 2:
logger.error("Number of samples is less or equal than 2.")
return 0
correlation = no_missing_value_df[target].corr(no_missing_value_df[hyper_param])
if np.isnan(correlation):
logger.warning("Correlation is nan!")
return 0
return abs(correlation)
def calc_all_hyper_param_importance(hparams, metrics):
results = {}
for metric in metrics:
for hparam in hparams:
flattened_lineage = {hparam['name']: hparam['values'], metric['name']: metric['values']}
result = calc_hyper_param_importance(pd.DataFrame(flattened_lineage), hparam['name'], metric['name'])
# print('%s - %s : result=' % (hparam, metric), result)
if hparam['name'] not in results.keys():
results[hparam['name']] = result
else:
results[hparam['name']] += result
sum_score = reduce(lambda x, y: x+y, results.values())
for key, value in results.items():
results[key] = value/sum_score
result = [{'name': key, 'value': value} for key, value in results.items()]
return result
# Copyright (c) 2021 VisualDL Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
def duplicate_removal(src_list):
name_scope = set()
dest_list = []
for item in src_list:
if item in name_scope:
continue
name_scope.add(item)
dest_list.append(item)
return dest_list
......@@ -20,8 +20,9 @@ from visualdl.writer.record_writer import RecordFileWriter
from visualdl.server.log import logger
from visualdl.utils.img_util import merge_images
from visualdl.utils.figure_util import figure_to_image
from visualdl.utils.md5_util import md5
from visualdl.component.base_component import scalar, image, embedding, audio, \
histogram, pr_curve, roc_curve, meta_data, text
histogram, pr_curve, roc_curve, meta_data, text, hparam
class DummyFileWriter(object):
......@@ -441,6 +442,45 @@ class LogWriter(object):
step=step,
walltime=walltime))
def add_hparams(self, hparam_dict, metric_list, walltime=None):
"""Add an histogram to vdl record file.
Args:
hparam_dict (dictionary): Each key-value pair in the dictionary is the
name of the hyper parameter and it's corresponding value. The type of the value
can be one of `bool`, `string`, `float`, `int`, or `None`.
metric_list (list): Name of all metrics.
walltime (int): Wall time of hparams.
Examples::
from visualdl import LogWriter
# Remember use add_scalar to log your metrics data!
with LogWriter('./log/hparams_test/train/run1') as writer:
writer.add_hparams({'lr': 0.1, 'bsize': 1, 'opt': 'sgd'}, ['hparam/accuracy', 'hparam/loss'])
for i in range(10):
writer.add_scalar('hparam/accuracy', i, i)
writer.add_scalar('hparam/loss', 2*i, i)
with LogWriter('./log/hparams_test/train/run2') as writer:
writer.add_hparams({'lr': 0.2, 'bsize': 2, 'opt': 'relu'}, ['hparam/accuracy', 'hparam/loss'])
for i in range(10):
writer.add_scalar('hparam/accuracy', 1.0/(i+1), i)
writer.add_scalar('hparam/loss', 5*i, i)
"""
if type(hparam_dict) is not dict:
raise TypeError('hparam_dict should be dictionary.')
if type(metric_list) is not list:
raise TypeError('metric_list should be list.')
walltime = round(time.time() * 1000) if walltime is None else walltime
self._get_file_writer().add_record(
hparam(
name=md5(self.file_name),
hparam_dict=hparam_dict,
metric_list=metric_list,
walltime=walltime))
def add_pr_curve(self,
tag,
labels,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册