未验证 提交 c6d4cdd1 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

Add pr curve. (#688)

* Add pr curve.
上级 e698eda6
# Copyright (c) 2020 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
# coding=utf-8
from visualdl import LogWriter
import numpy as np
# Demo: log three precision-recall curves built from random data so that the
# pr_curve component has something to display.
with LogWriter("./log/pr_curve_test/train") as writer:
    for step in range(3):
        # 100 random binary labels and 100 random scores in [0, 1).
        labels = np.random.randint(2, size=100)
        predictions = np.random.rand(100)
        writer.add_pr_curve(
            tag='pr_curve',
            labels=labels,
            predictions=predictions,
            step=step,
            num_thresholds=5,
        )
...@@ -134,8 +134,146 @@ def audio(tag, audio_array, sample_rate, step, walltime): ...@@ -134,8 +134,146 @@ def audio(tag, audio_array, sample_rate, step, walltime):
def histogram(tag, hist, bin_edges, step, walltime):
    """Package data to one histogram.

    Args:
        tag (string): Data identifier
        hist (numpy.ndarray or list): The values of the histogram
        bin_edges (numpy.ndarray or list): The bin edges
        step (int): Step of histogram
        walltime (int): Wall time of histogram

    Return:
        Package with format of record_pb2.Record
    """
    # Use a distinct local name so the message does not shadow this function.
    hist_message = Record.Histogram(hist=hist, bin_edges=bin_edges)
    return Record(values=[
        Record.Value(
            id=step, tag=tag, timestamp=walltime, histogram=hist_message)
    ])
def compute_curve(labels, predictions, num_thresholds=None, weights=None):
    """Compute precision-recall curve data from labels and predictions.

    Each prediction is assigned to one of ``num_thresholds`` buckets; the
    per-bucket (weighted) positive and negative counts are then accumulated
    from the highest bucket downwards to obtain the confusion-matrix counts
    at every threshold.

    Args:
        labels (numpy.ndarray or list): Binary labels for each element.
        predictions (numpy.ndarray or list): The probability that an element be
            classified as true.
        num_thresholds (int): Number of thresholds used to draw the curve.
            Defaults to 127 when None.
        weights (float): Multiple of data to display on the curve. Defaults
            to 1.0 when None.

    Return:
        dict with keys 'tp', 'fp', 'tn', 'fn' (lists of int) and 'precision',
        'recall' (lists of float), each holding one entry per threshold.
    """
    # Lower bound for denominators so empty buckets do not divide by zero.
    _MINIMUM_COUNT = 1e-7

    if num_thresholds is None:
        num_thresholds = 127
    if weights is None:
        weights = 1.0

    # Accept plain lists as documented: a list would otherwise be repeated by
    # `*` and has no `.astype`.
    predictions = np.asarray(predictions)
    # Builtin `float` replaces the `np.float` alias removed from NumPy.
    float_labels = np.asarray(labels).astype(float)

    bucket_indices = np.floor(predictions * (num_thresholds - 1)).astype(
        np.int32)
    histogram_range = (0, num_thresholds - 1)

    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights)
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights)

    # Obtain the reverse cumulative sum: entry i counts everything whose
    # bucket index is >= i.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    # Whatever is not counted as positive at a threshold is the remainder of
    # the totals held in entry 0.
    tn = fp[0] - fp
    fn = tp[0] - tp

    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)

    data = {
        'tp': tp.astype(int).tolist(),
        'fp': fp.astype(int).tolist(),
        'tn': tn.astype(int).tolist(),
        'fn': fn.astype(int).tolist(),
        'precision': precision.astype(float).tolist(),
        'recall': recall.astype(float).tolist()
    }
    return data
def pr_curve(tag, labels, predictions, step, walltime, num_thresholds=127,
             weights=None):
    """Package data to one pr_curve.

    The curve is computed with `compute_curve` and then wrapped into a record
    by `pr_curve_raw`.

    Args:
        tag (string): Data identifier
        labels (numpy.ndarray or list): Binary labels for each element.
        predictions (numpy.ndarray or list): The probability that an element be
            classified as true.
        step (int): Step of pr_curve
        walltime (int): Wall time of pr_curve
        num_thresholds (int): Number of thresholds used to draw the curve.
            Values above 127 are reduced to 127.
        weights (float): Multiple of data to display on the curve.

    Return:
        Package with format of record_pb2.Record
    """
    threshold_count = min(num_thresholds, 127)
    curve = compute_curve(labels, predictions, threshold_count, weights)

    # Forward exactly the six series that pr_curve_raw expects.
    series_names = ('tp', 'fp', 'tn', 'fn', 'precision', 'recall')
    series = {name: curve[name] for name in series_names}
    return pr_curve_raw(tag=tag, step=step, walltime=walltime, **series)
def pr_curve_raw(tag, tp, fp, tn, fn, precision, recall, step, walltime):
    """Package raw data to one pr_curve.

    The values are stored as given; no conversion or validation is applied.

    Args:
        tag (string): Data identifier
        tp (list): True Positive.
        fp (list): False Positive.
        tn (list): True Negative.
        fn (list): False Negative.
        precision (list): The fraction of retrieved documents that are relevant
            to the query.
        recall (list): The fraction of the relevant documents that are
            successfully retrieved.
        step (int): Step of pr_curve
        walltime (int): Wall time of pr_curve

    Return:
        Package with format of record_pb2.Record
    """
    prcurve = Record.PRCurve(TP=tp,
                             FP=fp,
                             TN=tn,
                             FN=fn,
                             precision=precision,
                             recall=recall)
    return Record(values=[
        Record.Value(
            id=step, tag=tag, timestamp=walltime, pr_curve=prcurve)
    ])
...@@ -29,10 +29,19 @@ message Record { ...@@ -29,10 +29,19 @@ message Record {
bytes encoded_vectors = 2; bytes encoded_vectors = 2;
} }
message Histogram { message Histogram {
repeated double hist = 1 [packed = true]; repeated double hist = 1 [packed = true];
repeated double bin_edges = 2 [packed = true]; repeated double bin_edges = 2 [packed = true];
}; }
// Precision-recall curve data. Each field is a parallel series; the Python
// writer fills them with one entry per threshold.
message PRCurve {
// True positive counts.
repeated int64 TP = 1 [packed = true];
// False positive counts.
repeated int64 FP = 2 [packed = true];
// True negative counts.
repeated int64 TN = 3 [packed = true];
// False negative counts.
repeated int64 FN = 4 [packed = true];
// Precision values.
repeated double precision = 5;
// Recall values.
repeated double recall = 6;
}
message Value { message Value {
int64 id = 1; int64 id = 1;
...@@ -44,6 +53,7 @@ message Histogram { ...@@ -44,6 +53,7 @@ message Histogram {
Audio audio = 6; Audio audio = 6;
Embeddings embeddings = 7; Embeddings embeddings = 7;
Histogram histogram = 8; Histogram histogram = 8;
PRCurve pr_curve = 9;
} }
} }
......
...@@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='visualdl', package='visualdl',
syntax='proto3', syntax='proto3',
serialized_options=None, serialized_options=None,
serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xc6\x05\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1a<\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1a\x87\x02\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x42\x0b\n\tone_valueb\x06proto3' serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xe2\x06\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1a<\n\nEmbeddings\x12.\n\nembeddings\x18\x01 
\x03(\x0b\x32\x1a.visualdl.Record.Embedding\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1al\n\x07PRCurve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x11\n\tprecision\x18\x05 \x03(\x01\x12\x0e\n\x06recall\x18\x06 \x03(\x01\x1a\xb5\x02\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x12,\n\x08pr_curve\x18\t \x01(\x0b\x32\x18.visualdl.Record.PRCurveH\x00\x42\x0b\n\tone_valueb\x06proto3'
) )
...@@ -253,6 +253,71 @@ _RECORD_HISTOGRAM = _descriptor.Descriptor( ...@@ -253,6 +253,71 @@ _RECORD_HISTOGRAM = _descriptor.Descriptor(
serialized_end=471, serialized_end=471,
) )
# Descriptor for the nested message visualdl.Record.PRCurve. It declares six
# repeated fields: TP, FP, TN and FN (field numbers 1-4), which carry
# serialized options, followed by precision and recall (field numbers 5-6),
# which carry none.
# NOTE(review): this looks like protoc-generated output; presumably it should
# be regenerated from record.proto rather than edited by hand - confirm.
_RECORD_PRCURVE = _descriptor.Descriptor(
name='PRCurve',
full_name='visualdl.Record.PRCurve',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='TP', full_name='visualdl.Record.PRCurve.TP', index=0,
number=1, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=b'\020\001', file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='FP', full_name='visualdl.Record.PRCurve.FP', index=1,
number=2, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=b'\020\001', file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='TN', full_name='visualdl.Record.PRCurve.TN', index=2,
number=3, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=b'\020\001', file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='FN', full_name='visualdl.Record.PRCurve.FN', index=3,
number=4, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=b'\020\001', file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='precision', full_name='visualdl.Record.PRCurve.precision', index=4,
number=5, type=1, cpp_type=5, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='recall', full_name='visualdl.Record.PRCurve.recall', index=5,
number=6, type=1, cpp_type=5, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=473,
serialized_end=581,
)
_RECORD_VALUE = _descriptor.Descriptor( _RECORD_VALUE = _descriptor.Descriptor(
name='Value', name='Value',
full_name='visualdl.Record.Value', full_name='visualdl.Record.Value',
...@@ -316,6 +381,13 @@ _RECORD_VALUE = _descriptor.Descriptor( ...@@ -316,6 +381,13 @@ _RECORD_VALUE = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='pr_curve', full_name='visualdl.Record.Value.pr_curve', index=8,
number=9, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
], ],
extensions=[ extensions=[
], ],
...@@ -331,8 +403,8 @@ _RECORD_VALUE = _descriptor.Descriptor( ...@@ -331,8 +403,8 @@ _RECORD_VALUE = _descriptor.Descriptor(
name='one_value', full_name='visualdl.Record.Value.one_value', name='one_value', full_name='visualdl.Record.Value.one_value',
index=0, containing_type=None, fields=[]), index=0, containing_type=None, fields=[]),
], ],
serialized_start=474, serialized_start=584,
serialized_end=737, serialized_end=893,
) )
_RECORD = _descriptor.Descriptor( _RECORD = _descriptor.Descriptor(
...@@ -352,7 +424,7 @@ _RECORD = _descriptor.Descriptor( ...@@ -352,7 +424,7 @@ _RECORD = _descriptor.Descriptor(
], ],
extensions=[ extensions=[
], ],
nested_types=[_RECORD_IMAGE, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_HISTOGRAM, _RECORD_VALUE, ], nested_types=[_RECORD_IMAGE, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_HISTOGRAM, _RECORD_PRCURVE, _RECORD_VALUE, ],
enum_types=[ enum_types=[
], ],
serialized_options=None, serialized_options=None,
...@@ -362,7 +434,7 @@ _RECORD = _descriptor.Descriptor( ...@@ -362,7 +434,7 @@ _RECORD = _descriptor.Descriptor(
oneofs=[ oneofs=[
], ],
serialized_start=27, serialized_start=27,
serialized_end=737, serialized_end=893,
) )
_RECORD_IMAGE.containing_type = _RECORD _RECORD_IMAGE.containing_type = _RECORD
...@@ -372,10 +444,12 @@ _RECORD_EMBEDDINGS.fields_by_name['embeddings'].message_type = _RECORD_EMBEDDING ...@@ -372,10 +444,12 @@ _RECORD_EMBEDDINGS.fields_by_name['embeddings'].message_type = _RECORD_EMBEDDING
_RECORD_EMBEDDINGS.containing_type = _RECORD _RECORD_EMBEDDINGS.containing_type = _RECORD
_RECORD_BYTES_EMBEDDINGS.containing_type = _RECORD _RECORD_BYTES_EMBEDDINGS.containing_type = _RECORD
_RECORD_HISTOGRAM.containing_type = _RECORD _RECORD_HISTOGRAM.containing_type = _RECORD
_RECORD_PRCURVE.containing_type = _RECORD
_RECORD_VALUE.fields_by_name['image'].message_type = _RECORD_IMAGE _RECORD_VALUE.fields_by_name['image'].message_type = _RECORD_IMAGE
_RECORD_VALUE.fields_by_name['audio'].message_type = _RECORD_AUDIO _RECORD_VALUE.fields_by_name['audio'].message_type = _RECORD_AUDIO
_RECORD_VALUE.fields_by_name['embeddings'].message_type = _RECORD_EMBEDDINGS _RECORD_VALUE.fields_by_name['embeddings'].message_type = _RECORD_EMBEDDINGS
_RECORD_VALUE.fields_by_name['histogram'].message_type = _RECORD_HISTOGRAM _RECORD_VALUE.fields_by_name['histogram'].message_type = _RECORD_HISTOGRAM
_RECORD_VALUE.fields_by_name['pr_curve'].message_type = _RECORD_PRCURVE
_RECORD_VALUE.containing_type = _RECORD _RECORD_VALUE.containing_type = _RECORD
_RECORD_VALUE.oneofs_by_name['one_value'].fields.append( _RECORD_VALUE.oneofs_by_name['one_value'].fields.append(
_RECORD_VALUE.fields_by_name['value']) _RECORD_VALUE.fields_by_name['value'])
...@@ -392,6 +466,9 @@ _RECORD_VALUE.fields_by_name['embeddings'].containing_oneof = _RECORD_VALUE.oneo ...@@ -392,6 +466,9 @@ _RECORD_VALUE.fields_by_name['embeddings'].containing_oneof = _RECORD_VALUE.oneo
_RECORD_VALUE.oneofs_by_name['one_value'].fields.append( _RECORD_VALUE.oneofs_by_name['one_value'].fields.append(
_RECORD_VALUE.fields_by_name['histogram']) _RECORD_VALUE.fields_by_name['histogram'])
_RECORD_VALUE.fields_by_name['histogram'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value'] _RECORD_VALUE.fields_by_name['histogram'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value']
_RECORD_VALUE.oneofs_by_name['one_value'].fields.append(
_RECORD_VALUE.fields_by_name['pr_curve'])
_RECORD_VALUE.fields_by_name['pr_curve'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value']
_RECORD.fields_by_name['values'].message_type = _RECORD_VALUE _RECORD.fields_by_name['values'].message_type = _RECORD_VALUE
DESCRIPTOR.message_types_by_name['Record'] = _RECORD DESCRIPTOR.message_types_by_name['Record'] = _RECORD
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
...@@ -440,6 +517,13 @@ Record = _reflection.GeneratedProtocolMessageType('Record', (_message.Message,), ...@@ -440,6 +517,13 @@ Record = _reflection.GeneratedProtocolMessageType('Record', (_message.Message,),
}) })
, ,
'PRCurve' : _reflection.GeneratedProtocolMessageType('PRCurve', (_message.Message,), {
'DESCRIPTOR' : _RECORD_PRCURVE,
'__module__' : 'record_pb2'
# @@protoc_insertion_point(class_scope:visualdl.Record.PRCurve)
})
,
'Value' : _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), { 'Value' : _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), {
'DESCRIPTOR' : _RECORD_VALUE, 'DESCRIPTOR' : _RECORD_VALUE,
'__module__' : 'record_pb2' '__module__' : 'record_pb2'
...@@ -457,9 +541,14 @@ _sym_db.RegisterMessage(Record.Embedding) ...@@ -457,9 +541,14 @@ _sym_db.RegisterMessage(Record.Embedding)
_sym_db.RegisterMessage(Record.Embeddings) _sym_db.RegisterMessage(Record.Embeddings)
_sym_db.RegisterMessage(Record.bytes_embeddings) _sym_db.RegisterMessage(Record.bytes_embeddings)
_sym_db.RegisterMessage(Record.Histogram) _sym_db.RegisterMessage(Record.Histogram)
_sym_db.RegisterMessage(Record.PRCurve)
_sym_db.RegisterMessage(Record.Value) _sym_db.RegisterMessage(Record.Value)
_RECORD_HISTOGRAM.fields_by_name['hist']._options = None _RECORD_HISTOGRAM.fields_by_name['hist']._options = None
_RECORD_HISTOGRAM.fields_by_name['bin_edges']._options = None _RECORD_HISTOGRAM.fields_by_name['bin_edges']._options = None
_RECORD_PRCURVE.fields_by_name['TP']._options = None
_RECORD_PRCURVE.fields_by_name['FP']._options = None
_RECORD_PRCURVE.fields_by_name['TN']._options = None
_RECORD_PRCURVE.fields_by_name['FN']._options = None
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)
...@@ -106,6 +106,8 @@ class LogReader(object): ...@@ -106,6 +106,8 @@ class LogReader(object):
component = "audio" component = "audio"
elif "histogram" == value_type: elif "histogram" == value_type:
component = "histogram" component = "histogram"
elif "pr_curve" == value_type:
component = "pr_curve"
else: else:
raise TypeError("Invalid value type `%s`." % value_type) raise TypeError("Invalid value type `%s`." % value_type)
self._tags[path] = component self._tags[path] = component
......
...@@ -109,6 +109,10 @@ class Api(object): ...@@ -109,6 +109,10 @@ class Api(object):
def embeddings_tags(self): def embeddings_tags(self):
return self._get_with_retry('data/plugin/embeddings/tags', lib.get_embeddings_tags) return self._get_with_retry('data/plugin/embeddings/tags', lib.get_embeddings_tags)
@result()
def pr_curve_tags(self):
    """Return the pr_curve tags, fetched through `_get_with_retry`."""
    cache_key = 'data/plugin/pr_curves/tags'
    return self._get_with_retry(cache_key, lib.get_pr_curve_tags)
@result() @result()
def scalars_list(self, run, tag): def scalars_list(self, run, tag):
key = os.path.join('data/plugin/scalars/scalars', run, tag) key = os.path.join('data/plugin/scalars/scalars', run, tag)
...@@ -151,6 +155,16 @@ class Api(object): ...@@ -151,6 +155,16 @@ class Api(object):
key = os.path.join('data/plugin/histogram/histogram', run, tag) key = os.path.join('data/plugin/histogram/histogram', run, tag)
return self._get_with_retry(key, lib.get_histogram, run, tag) return self._get_with_retry(key, lib.get_histogram, run, tag)
@result()
def pr_curves_pr_curve(self, run, tag):
    """Return the pr_curve data of `tag` in `run` via `_get_with_retry`."""
    cache_key = os.path.join('data/plugin/pr_curves/pr_curve', run, tag)
    return self._get_with_retry(cache_key, lib.get_pr_curve, run, tag)
@result()
def pr_curves_steps(self, run):
    """Return the recorded pr_curve steps of `run` via `_get_with_retry`."""
    cache_key = os.path.join('data/plugin/pr_curves/steps', run)
    return self._get_with_retry(cache_key, lib.get_pr_curve_step, run)
@result('application/octet-stream', lambda s: {"Content-Disposition": 'attachment; filename="%s"' % s.model_name} if len(s.model_name) else None) @result('application/octet-stream', lambda s: {"Content-Disposition": 'attachment; filename="%s"' % s.model_name} if len(s.model_name) else None)
def graphs_graph(self): def graphs_graph(self):
key = os.path.join('data/plugin/graphs/graph') key = os.path.join('data/plugin/graphs/graph')
...@@ -169,6 +183,7 @@ def create_api_call(logdir, model, cache_timeout): ...@@ -169,6 +183,7 @@ def create_api_call(logdir, model, cache_timeout):
'audio/tags': (api.audio_tags, []), 'audio/tags': (api.audio_tags, []),
'embeddings/tags': (api.embeddings_tags, []), 'embeddings/tags': (api.embeddings_tags, []),
'histogram/tags': (api.histogram_tags, []), 'histogram/tags': (api.histogram_tags, []),
'pr-curve/tags': (api.pr_curve_tags, []),
'scalars/list': (api.scalars_list, ['run', 'tag']), 'scalars/list': (api.scalars_list, ['run', 'tag']),
'images/list': (api.images_list, ['run', 'tag']), 'images/list': (api.images_list, ['run', 'tag']),
'images/image': (api.images_image, ['run', 'tag', 'index']), 'images/image': (api.images_image, ['run', 'tag', 'index']),
...@@ -176,7 +191,9 @@ def create_api_call(logdir, model, cache_timeout): ...@@ -176,7 +191,9 @@ def create_api_call(logdir, model, cache_timeout):
'audio/audio': (api.audio_audio, ['run', 'tag', 'index']), 'audio/audio': (api.audio_audio, ['run', 'tag', 'index']),
'embeddings/embedding': (api.embeddings_embedding, ['run', 'tag', 'reduction', 'dimension']), 'embeddings/embedding': (api.embeddings_embedding, ['run', 'tag', 'reduction', 'dimension']),
'histogram/list': (api.histogram_list, ['run', 'tag']), 'histogram/list': (api.histogram_list, ['run', 'tag']),
'graphs/graph': (api.graphs_graph, []) 'graphs/graph': (api.graphs_graph, []),
'pr-curve/list': (api.pr_curves_pr_curve, ['run', 'tag']),
'pr-curve/steps': (api.pr_curves_steps, ['run'])
} }
def call(path: str, args): def call(path: str, args):
......
...@@ -23,7 +23,8 @@ DEFAULT_PLUGIN_MAXSIZE = { ...@@ -23,7 +23,8 @@ DEFAULT_PLUGIN_MAXSIZE = {
"image": 10, "image": 10,
"histogram": 100, "histogram": 100,
"embeddings": 50000, "embeddings": 50000,
"audio": 10 "audio": 10,
"pr_curve": 300
} }
...@@ -274,7 +275,9 @@ class DataManager(object): ...@@ -274,7 +275,9 @@ class DataManager(object):
"embeddings": "embeddings":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["embeddings"]), Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["embeddings"]),
"audio": "audio":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["audio"]) Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["audio"]),
"pr_curve":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["pr_curve"])
} }
self._mutex = threading.Lock() self._mutex = threading.Lock()
......
...@@ -114,6 +114,40 @@ def get_histogram_tags(log_reader): ...@@ -114,6 +114,40 @@ def get_histogram_tags(log_reader):
return get_logs(log_reader, "histogram") return get_logs(log_reader, "histogram")
def get_pr_curve_tags(log_reader):
    """Return what `get_logs` reports for the "pr_curve" component."""
    component = "pr_curve"
    return get_logs(log_reader, component)
def get_pr_curve(log_reader, run, tag):
    """Collect the stored pr_curve records of one run/tag as plain lists.

    Args:
        log_reader: Reader whose data is refreshed with `load_new_data()` and
            whose "pr_curve" reservoir is queried.
        run: Run to look up in the reservoir.
        tag: Tag to look up; passed through `decode_tag` first.

    Return:
        list: One row per record, laid out as
        [timestamp, id, precision, recall, TP, FP, TN, FN, thresholds],
        where thresholds is [1/n, 2/n, ..., 1.0] for n precision entries.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("pr_curve")

    def _to_row(record):
        curve = record.pr_curve
        count = len(curve.precision)
        # Evenly spaced values ending at 1.0, one per precision entry.
        thresholds = [float(k) / count for k in range(1, count + 1)]
        return [
            record.timestamp,
            record.id,
            list(curve.precision),
            list(curve.recall),
            list(curve.TP),
            list(curve.FP),
            list(curve.TN),
            list(curve.FN),
            thresholds,
        ]

    return [_to_row(record)
            for record in reservoir.get_items(run, decode_tag(tag))]
def get_pr_curve_step(log_reader, run, tag=None):
    """List [timestamp, id] for every stored pr_curve record of a run/tag.

    Args:
        log_reader: Reader whose data is refreshed with `load_new_data()` and
            whose "pr_curve" reservoir is queried.
        run: Run to look up in the reservoir.
        tag: Tag to look up. When None, the first entry of
            `get_pr_curve_tags(log_reader)[run]` is used.

    Return:
        list: [timestamp, id] pairs, one per record.
    """
    if tag is None:
        tag = get_pr_curve_tags(log_reader)[run][0]
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("pr_curve")
    steps = []
    for record in reservoir.get_items(run, decode_tag(tag)):
        steps.append([record.timestamp, record.id])
    return steps
def get_embeddings(log_reader, run, tag, reduction, dimension=2): def get_embeddings(log_reader, run, tag, reduction, dimension=2):
log_reader.load_new_data() log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("embeddings").get_items( records = log_reader.data_manager.get_reservoir("embeddings").get_items(
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
# ======================================================================= # =======================================================================
import os import os
import time import time
from visualdl.writer.record_writer import RecordFileWriter
from visualdl.component.base_component import scalar, image, embedding, audio, histogram
import numpy as np import numpy as np
from visualdl.writer.record_writer import RecordFileWriter
from visualdl.component.base_component import scalar, image, embedding, audio, histogram, pr_curve
class DummyFileWriter(object): class DummyFileWriter(object):
...@@ -281,6 +281,50 @@ class LogWriter(object): ...@@ -281,6 +281,50 @@ class LogWriter(object):
step=step, step=step,
walltime=walltime)) walltime=walltime))
def add_pr_curve(self,
                 tag,
                 labels,
                 predictions,
                 step,
                 num_thresholds=10,
                 weights=None,
                 walltime=None):
    """Add a precision-recall curve to the vdl record file.

    Args:
        tag (string): Data identifier. Must not contain '%'.
        labels (numpy.ndarray or list): Binary labels for each element.
        predictions (numpy.ndarray or list): The probability that an element
            be classified as true.
        step (int): Step of pr curve.
        num_thresholds (int): Number of thresholds used to draw the curve.
        weights (float): Multiple of data to display on the curve.
        walltime (int): Wall time of pr curve. Defaults to the current time
            rounded to whole seconds.

    Raises:
        RuntimeError: If `tag` contains '%'.

    Example:
        with LogWriter(logdir="./log/pr_curve_test/train") as writer:
            for index in range(3):
                labels = np.random.randint(2, size=100)
                predictions = np.random.rand(100)
                writer.add_pr_curve(tag='default',
                                    labels=labels,
                                    predictions=predictions,
                                    step=index)
    """
    if '%' in tag:
        raise RuntimeError("% can't appear in tag!")
    if walltime is None:
        walltime = round(time.time())

    # Resolve the writer before building the record, as the original did.
    add_record = self._get_file_writer().add_record
    record = pr_curve(
        tag=tag,
        labels=labels,
        predictions=predictions,
        step=step,
        walltime=walltime,
        num_thresholds=num_thresholds,
        weights=weights)
    add_record(record)
def flush(self): def flush(self):
"""Flush all data in cache to disk. """Flush all data in cache to disk.
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册