Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
VisualDL
提交
c6d4cdd1
V
VisualDL
项目概览
PaddlePaddle
/
VisualDL
大约 1 年 前同步成功
通知
88
Star
4655
Fork
642
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
2
Wiki
5
Wiki
分析
仓库
DevOps
项目成员
Pages
V
VisualDL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
2
合并请求
2
Pages
分析
分析
仓库分析
DevOps
Wiki
5
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c6d4cdd1
编写于
6月 29, 2020
作者:
走神的阿圆
提交者:
GitHub
6月 29, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add pr curve. (#688)
* Add pr curve.
上级
e698eda6
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
378 addition
and
14 deletion
+378
-14
demo/components/pr_curve_test.py
demo/components/pr_curve_test.py
+27
-0
visualdl/component/base_component.py
visualdl/component/base_component.py
+138
-0
visualdl/proto/record.proto
visualdl/proto/record.proto
+14
-4
visualdl/proto/record_pb2.py
visualdl/proto/record_pb2.py
+94
-5
visualdl/reader/reader.py
visualdl/reader/reader.py
+2
-0
visualdl/server/api.py
visualdl/server/api.py
+18
-1
visualdl/server/data_manager.py
visualdl/server/data_manager.py
+5
-2
visualdl/server/lib.py
visualdl/server/lib.py
+34
-0
visualdl/writer/writer.py
visualdl/writer/writer.py
+46
-2
未找到文件。
demo/components/pr_curve_test.py
0 → 100644
浏览文件 @
c6d4cdd1
# Copyright (c) 2020 VisualDL Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
# coding=utf-8
from
visualdl
import
LogWriter
import
numpy
as
np
# Demo: log three steps of a precision-recall curve built from random data.
# Each step draws 100 random binary labels and 100 random scores in [0, 1).
with LogWriter("./log/pr_curve_test/train") as log_writer:
    for step_index in range(3):
        random_labels = np.random.randint(2, size=100)
        random_scores = np.random.rand(100)
        log_writer.add_pr_curve(
            tag='pr_curve',
            labels=random_labels,
            predictions=random_scores,
            step=step_index,
            num_thresholds=5,
        )
visualdl/component/base_component.py
浏览文件 @
c6d4cdd1
...
...
@@ -134,8 +134,146 @@ def audio(tag, audio_array, sample_rate, step, walltime):
def histogram(tag, hist, bin_edges, step, walltime):
    """Wrap one histogram sample in a Record message.

    Args:
        tag (string): Data identifier
        hist (numpy.ndarray or list): The values of the histogram
        bin_edges (numpy.ndarray or list): The bin edges
        step (int): Step of histogram
        walltime (int): Wall time of histogram

    Return:
        A record_pb2.Record holding a single Value whose ``histogram``
        field carries the data.
    """
    payload = Record.Histogram(hist=hist, bin_edges=bin_edges)
    value = Record.Value(
        id=step,
        tag=tag,
        timestamp=walltime,
        histogram=payload,
    )
    return Record(values=[value])
def compute_curve(labels, predictions, num_thresholds=None, weights=None):
    """Compute precision-recall curve data from labels and predictions.

    Each prediction is assigned to one of ``num_thresholds`` buckets via
    ``floor(prediction * (num_thresholds - 1))``. Weighted label counts are
    accumulated per bucket, and a reverse cumulative sum turns them into
    "at or above this bucket" counts.

    Args:
        labels (numpy.ndarray or list): Binary labels for each element.
        predictions (numpy.ndarray or list): The probability that an element
            be classified as true.
        num_thresholds (int): Number of thresholds used to draw the curve.
            When None, 127 is used (the default and cap applied by
            ``pr_curve``).
        weights (float): Multiple of data to display on the curve.
            Defaults to 1.0 when None.

    Return:
        dict with keys 'tp', 'fp', 'tn', 'fn' (lists of int) and
        'precision', 'recall' (lists of float), each of length
        ``num_thresholds``.
    """
    # Lower bound for denominators so empty buckets yield 0 instead of NaN.
    _MINIMUM_COUNT = 1e-7
    _DEFAULT_NUM_THRESHOLDS = 127

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS
    if weights is None:
        weights = 1.0

    # Accept plain lists as documented; ndarrays pass through unchanged.
    predictions = np.asarray(predictions)
    # The builtin ``float`` replaces the removed ``np.float`` alias
    # (both mean float64).
    float_labels = np.asarray(labels).astype(float)

    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
    histogram_range = (0, num_thresholds - 1)

    # Positives contribute to tp buckets, negatives to fp buckets.
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights)
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights)

    # Obtain the reverse cumulative sum.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    # Index 0 holds the totals, so the complement is total minus cumulative.
    tn = fp[0] - fp
    fn = tp[0] - tp

    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)

    return {
        'tp': tp.astype(int).tolist(),
        'fp': fp.astype(int).tolist(),
        'tn': tn.astype(int).tolist(),
        'fn': fn.astype(int).tolist(),
        'precision': precision.astype(float).tolist(),
        'recall': recall.astype(float).tolist(),
    }
def pr_curve(tag,
             labels,
             predictions,
             step,
             walltime,
             num_thresholds=127,
             weights=None):
    """Build a pr_curve record from raw labels and predictions.

    The curve statistics are computed with ``compute_curve`` and then packed
    by ``pr_curve_raw``.

    Args:
        tag (string): Data identifier
        labels (numpy.ndarray or list): Binary labels for each element.
        predictions (numpy.ndarray or list): The probability that an element
            be classified as true.
        step (int): Step of pr_curve
        walltime (int): Wall time of pr_curve
        num_thresholds (int): Number of thresholds used to draw the curve;
            values above 127 are reduced to 127.
        weights (float): Multiple of data to display on the curve.

    Return:
        Package with format of record_pb2.Record
    """
    # Never use more than 127 thresholds.
    capped_thresholds = min(num_thresholds, 127)
    curve = compute_curve(labels, predictions, capped_thresholds, weights)
    return pr_curve_raw(
        tag=tag,
        tp=curve['tp'],
        fp=curve['fp'],
        tn=curve['tn'],
        fn=curve['fn'],
        precision=curve['precision'],
        recall=curve['recall'],
        step=step,
        walltime=walltime,
    )
def pr_curve_raw(tag, tp, fp, tn, fn, precision, recall, step, walltime):
    """Package raw data to one pr_curve.

    The inputs are passed to ``Record.PRCurve`` unchanged; no type
    conversion is performed here.

    Args:
        tag (string): Data identifier
        tp (list): True Positive.
        fp (list): False Positive.
        tn (list): True Negative.
        fn (list): False Negative.
        precision (list): The fraction of retrieved documents that are
            relevant to the query.
        recall (list): The fraction of the relevant documents that are
            successfully retrieved.
        step (int): Step of pr_curve
        walltime (int): Wall time of pr_curve

    Return:
        Package with format of record_pb2.Record
    """
    prcurve = Record.PRCurve(
        TP=tp,
        FP=fp,
        TN=tn,
        FN=fn,
        precision=precision,
        recall=recall)
    return Record(values=[
        Record.Value(id=step, tag=tag, timestamp=walltime, pr_curve=prcurve)
    ])
visualdl/proto/record.proto
浏览文件 @
c6d4cdd1
...
...
@@ -29,10 +29,19 @@ message Record {
bytes
encoded_vectors
=
2
;
}
message
Histogram
{
repeated
double
hist
=
1
[
packed
=
true
];
repeated
double
bin_edges
=
2
[
packed
=
true
];
};
message
Histogram
{
repeated
double
hist
=
1
[
packed
=
true
];
repeated
double
bin_edges
=
2
[
packed
=
true
];
}
// Precision-recall curve data. The six repeated fields are filled by the
// Python writer from lists produced together, so they are expected to be
// parallel (same length, one entry per threshold) -- TODO confirm for
// callers that build the message directly.
message PRCurve {
  // True-positive counts.
  repeated int64 TP = 1 [packed = true];
  // False-positive counts.
  repeated int64 FP = 2 [packed = true];
  // True-negative counts.
  repeated int64 TN = 3 [packed = true];
  // False-negative counts.
  repeated int64 FN = 4 [packed = true];
  // Precision values. No explicit packed option is set on this field.
  repeated double precision = 5;
  // Recall values. No explicit packed option is set on this field.
  repeated double recall = 6;
}
message
Value
{
int64
id
=
1
;
...
...
@@ -44,6 +53,7 @@ message Histogram {
Audio
audio
=
6
;
Embeddings
embeddings
=
7
;
Histogram
histogram
=
8
;
PRCurve
pr_curve
=
9
;
}
}
...
...
visualdl/proto/record_pb2.py
浏览文件 @
c6d4cdd1
...
...
@@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package
=
'visualdl'
,
syntax
=
'proto3'
,
serialized_options
=
None
,
serialized_pb
=
b
'
\n\x0c
record.proto
\x12\x08
visualdl
\"\x
c6\x05\n\x06
Record
\x12
&
\n\x06
values
\x18\x01
\x03
(
\x0b\x32\x16
.visualdl.Record.Value
\x1a
%
\n\x05
Image
\x12\x1c\n\x14\x65
ncoded_image_string
\x18\x04
\x01
(
\x0c\x1a
}
\n\x05\x41
udio
\x12\x13\n\x0b
sample_rate
\x18\x01
\x01
(
\x02\x12\x14\n\x0c
num_channels
\x18\x02
\x01
(
\x03\x12\x15\n\r
length_frames
\x18\x03
\x01
(
\x03\x12\x1c\n\x14\x65
ncoded_audio_string
\x18\x04
\x01
(
\x0c\x12\x14\n\x0c\x63
ontent_type
\x18\x05
\x01
(
\t\x1a
+
\n\t
Embedding
\x12\r\n\x05
label
\x18\x01
\x01
(
\t\x12\x0f\n\x07
vectors
\x18\x02
\x03
(
\x02\x1a
<
\n\n
Embeddings
\x12
.
\n\n
embeddings
\x18\x01
\x03
(
\x0b\x32\x1a
.visualdl.Record.Embedding
\x1a\x43\n\x10\x62
ytes_embeddings
\x12\x16\n\x0e\x65
ncoded_labels
\x18\x01
\x01
(
\x0c\x12\x17\n\x0f\x65
ncoded_vectors
\x18\x02
\x01
(
\x0c\x1a\x34\n\t
Histogram
\x12\x10\n\x04
hist
\x18\x01
\x03
(
\x01\x42\x02\x10\x01\x12\x15\n\t
bin_edges
\x18\x02
\x03
(
\x01\x42\x02\x10\x01\x1a\x87\x02\n\x05
Value
\x12\n\n\x02
id
\x18\x01
\x01
(
\x03\x12\x0b\n\x03
tag
\x18\x02
\x01
(
\t\x12\x11\n\t
timestamp
\x18\x03
\x01
(
\x03\x12\x0f\n\x05
value
\x18\x04
\x01
(
\x02
H
\x00\x12\'\n\x05
image
\x18\x05
\x01
(
\x0b\x32\x16
.visualdl.Record.ImageH
\x00\x12\'\n\x05\x61
udio
\x18\x06
\x01
(
\x0b\x32\x16
.visualdl.Record.AudioH
\x00\x12\x31\n\n
embeddings
\x18\x07
\x01
(
\x0b\x32\x1b
.visualdl.Record.EmbeddingsH
\x00\x12
/
\n\t
histogram
\x18\x08
\x01
(
\x0b\x32\x1a
.visualdl.Record.Histogram
H
\x00\x42\x0b\n\t
one_valueb
\x06
proto3'
serialized_pb
=
b
'
\n\x0c
record.proto
\x12\x08
visualdl
\"\x
e2\x06\n\x06
Record
\x12
&
\n\x06
values
\x18\x01
\x03
(
\x0b\x32\x16
.visualdl.Record.Value
\x1a
%
\n\x05
Image
\x12\x1c\n\x14\x65
ncoded_image_string
\x18\x04
\x01
(
\x0c\x1a
}
\n\x05\x41
udio
\x12\x13\n\x0b
sample_rate
\x18\x01
\x01
(
\x02\x12\x14\n\x0c
num_channels
\x18\x02
\x01
(
\x03\x12\x15\n\r
length_frames
\x18\x03
\x01
(
\x03\x12\x1c\n\x14\x65
ncoded_audio_string
\x18\x04
\x01
(
\x0c\x12\x14\n\x0c\x63
ontent_type
\x18\x05
\x01
(
\t\x1a
+
\n\t
Embedding
\x12\r\n\x05
label
\x18\x01
\x01
(
\t\x12\x0f\n\x07
vectors
\x18\x02
\x03
(
\x02\x1a
<
\n\n
Embeddings
\x12
.
\n\n
embeddings
\x18\x01
\x03
(
\x0b\x32\x1a
.visualdl.Record.Embedding
\x1a\x43\n\x10\x62
ytes_embeddings
\x12\x16\n\x0e\x65
ncoded_labels
\x18\x01
\x01
(
\x0c\x12\x17\n\x0f\x65
ncoded_vectors
\x18\x02
\x01
(
\x0c\x1a\x34\n\t
Histogram
\x12\x10\n\x04
hist
\x18\x01
\x03
(
\x01\x42\x02\x10\x01\x12\x15\n\t
bin_edges
\x18\x02
\x03
(
\x01\x42\x02\x10\x01\x1a
l
\n\x07
PRCurve
\x12\x0e\n\x02
TP
\x18\x01
\x03
(
\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46
P
\x18\x02
\x03
(
\x03\x42\x02\x10\x01\x12\x0e\n\x02
TN
\x18\x03
\x03
(
\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46
N
\x18\x04
\x03
(
\x03\x42\x02\x10\x01\x12\x11\n\t
precision
\x18\x05
\x03
(
\x01\x12\x0e\n\x06
recall
\x18\x06
\x03
(
\x01\x1a\xb5\x02\n\x05
Value
\x12\n\n\x02
id
\x18\x01
\x01
(
\x03\x12\x0b\n\x03
tag
\x18\x02
\x01
(
\t\x12\x11\n\t
timestamp
\x18\x03
\x01
(
\x03\x12\x0f\n\x05
value
\x18\x04
\x01
(
\x02
H
\x00\x12\'\n\x05
image
\x18\x05
\x01
(
\x0b\x32\x16
.visualdl.Record.ImageH
\x00\x12\'\n\x05\x61
udio
\x18\x06
\x01
(
\x0b\x32\x16
.visualdl.Record.AudioH
\x00\x12\x31\n\n
embeddings
\x18\x07
\x01
(
\x0b\x32\x1b
.visualdl.Record.EmbeddingsH
\x00\x12
/
\n\t
histogram
\x18\x08
\x01
(
\x0b\x32\x1a
.visualdl.Record.HistogramH
\x00\x12
,
\n\x08
pr_curve
\x18\t
\x01
(
\x0b\x32\x18
.visualdl.Record.PRCurve
H
\x00\x42\x0b\n\t
one_valueb
\x06
proto3'
)
...
...
@@ -253,6 +253,71 @@ _RECORD_HISTOGRAM = _descriptor.Descriptor(
serialized_end
=
471
,
)
_RECORD_PRCURVE
=
_descriptor
.
Descriptor
(
name
=
'PRCurve'
,
full_name
=
'visualdl.Record.PRCurve'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'TP'
,
full_name
=
'visualdl.Record.PRCurve.TP'
,
index
=
0
,
number
=
1
,
type
=
3
,
cpp_type
=
2
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
b
'
\020\001
'
,
file
=
DESCRIPTOR
),
_descriptor
.
FieldDescriptor
(
name
=
'FP'
,
full_name
=
'visualdl.Record.PRCurve.FP'
,
index
=
1
,
number
=
2
,
type
=
3
,
cpp_type
=
2
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
b
'
\020\001
'
,
file
=
DESCRIPTOR
),
_descriptor
.
FieldDescriptor
(
name
=
'TN'
,
full_name
=
'visualdl.Record.PRCurve.TN'
,
index
=
2
,
number
=
3
,
type
=
3
,
cpp_type
=
2
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
b
'
\020\001
'
,
file
=
DESCRIPTOR
),
_descriptor
.
FieldDescriptor
(
name
=
'FN'
,
full_name
=
'visualdl.Record.PRCurve.FN'
,
index
=
3
,
number
=
4
,
type
=
3
,
cpp_type
=
2
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
b
'
\020\001
'
,
file
=
DESCRIPTOR
),
_descriptor
.
FieldDescriptor
(
name
=
'precision'
,
full_name
=
'visualdl.Record.PRCurve.precision'
,
index
=
4
,
number
=
5
,
type
=
1
,
cpp_type
=
5
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
None
,
file
=
DESCRIPTOR
),
_descriptor
.
FieldDescriptor
(
name
=
'recall'
,
full_name
=
'visualdl.Record.PRCurve.recall'
,
index
=
5
,
number
=
6
,
type
=
1
,
cpp_type
=
5
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
None
,
file
=
DESCRIPTOR
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
serialized_options
=
None
,
is_extendable
=
False
,
syntax
=
'proto3'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
473
,
serialized_end
=
581
,
)
_RECORD_VALUE
=
_descriptor
.
Descriptor
(
name
=
'Value'
,
full_name
=
'visualdl.Record.Value'
,
...
...
@@ -316,6 +381,13 @@ _RECORD_VALUE = _descriptor.Descriptor(
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
None
,
file
=
DESCRIPTOR
),
_descriptor
.
FieldDescriptor
(
name
=
'pr_curve'
,
full_name
=
'visualdl.Record.Value.pr_curve'
,
index
=
8
,
number
=
9
,
type
=
11
,
cpp_type
=
10
,
label
=
1
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
None
,
file
=
DESCRIPTOR
),
],
extensions
=
[
],
...
...
@@ -331,8 +403,8 @@ _RECORD_VALUE = _descriptor.Descriptor(
name
=
'one_value'
,
full_name
=
'visualdl.Record.Value.one_value'
,
index
=
0
,
containing_type
=
None
,
fields
=
[]),
],
serialized_start
=
47
4
,
serialized_end
=
737
,
serialized_start
=
58
4
,
serialized_end
=
893
,
)
_RECORD
=
_descriptor
.
Descriptor
(
...
...
@@ -352,7 +424,7 @@ _RECORD = _descriptor.Descriptor(
],
extensions
=
[
],
nested_types
=
[
_RECORD_IMAGE
,
_RECORD_AUDIO
,
_RECORD_EMBEDDING
,
_RECORD_EMBEDDINGS
,
_RECORD_BYTES_EMBEDDINGS
,
_RECORD_HISTOGRAM
,
_RECORD_VALUE
,
],
nested_types
=
[
_RECORD_IMAGE
,
_RECORD_AUDIO
,
_RECORD_EMBEDDING
,
_RECORD_EMBEDDINGS
,
_RECORD_BYTES_EMBEDDINGS
,
_RECORD_HISTOGRAM
,
_RECORD_
PRCURVE
,
_RECORD_
VALUE
,
],
enum_types
=
[
],
serialized_options
=
None
,
...
...
@@ -362,7 +434,7 @@ _RECORD = _descriptor.Descriptor(
oneofs
=
[
],
serialized_start
=
27
,
serialized_end
=
737
,
serialized_end
=
893
,
)
_RECORD_IMAGE
.
containing_type
=
_RECORD
...
...
@@ -372,10 +444,12 @@ _RECORD_EMBEDDINGS.fields_by_name['embeddings'].message_type = _RECORD_EMBEDDING
_RECORD_EMBEDDINGS
.
containing_type
=
_RECORD
_RECORD_BYTES_EMBEDDINGS
.
containing_type
=
_RECORD
_RECORD_HISTOGRAM
.
containing_type
=
_RECORD
_RECORD_PRCURVE
.
containing_type
=
_RECORD
_RECORD_VALUE
.
fields_by_name
[
'image'
].
message_type
=
_RECORD_IMAGE
_RECORD_VALUE
.
fields_by_name
[
'audio'
].
message_type
=
_RECORD_AUDIO
_RECORD_VALUE
.
fields_by_name
[
'embeddings'
].
message_type
=
_RECORD_EMBEDDINGS
_RECORD_VALUE
.
fields_by_name
[
'histogram'
].
message_type
=
_RECORD_HISTOGRAM
_RECORD_VALUE
.
fields_by_name
[
'pr_curve'
].
message_type
=
_RECORD_PRCURVE
_RECORD_VALUE
.
containing_type
=
_RECORD
_RECORD_VALUE
.
oneofs_by_name
[
'one_value'
].
fields
.
append
(
_RECORD_VALUE
.
fields_by_name
[
'value'
])
...
...
@@ -392,6 +466,9 @@ _RECORD_VALUE.fields_by_name['embeddings'].containing_oneof = _RECORD_VALUE.oneo
_RECORD_VALUE
.
oneofs_by_name
[
'one_value'
].
fields
.
append
(
_RECORD_VALUE
.
fields_by_name
[
'histogram'
])
_RECORD_VALUE
.
fields_by_name
[
'histogram'
].
containing_oneof
=
_RECORD_VALUE
.
oneofs_by_name
[
'one_value'
]
_RECORD_VALUE
.
oneofs_by_name
[
'one_value'
].
fields
.
append
(
_RECORD_VALUE
.
fields_by_name
[
'pr_curve'
])
_RECORD_VALUE
.
fields_by_name
[
'pr_curve'
].
containing_oneof
=
_RECORD_VALUE
.
oneofs_by_name
[
'one_value'
]
_RECORD
.
fields_by_name
[
'values'
].
message_type
=
_RECORD_VALUE
DESCRIPTOR
.
message_types_by_name
[
'Record'
]
=
_RECORD
_sym_db
.
RegisterFileDescriptor
(
DESCRIPTOR
)
...
...
@@ -440,6 +517,13 @@ Record = _reflection.GeneratedProtocolMessageType('Record', (_message.Message,),
})
,
'PRCurve'
:
_reflection
.
GeneratedProtocolMessageType
(
'PRCurve'
,
(
_message
.
Message
,),
{
'DESCRIPTOR'
:
_RECORD_PRCURVE
,
'__module__'
:
'record_pb2'
# @@protoc_insertion_point(class_scope:visualdl.Record.PRCurve)
})
,
'Value'
:
_reflection
.
GeneratedProtocolMessageType
(
'Value'
,
(
_message
.
Message
,),
{
'DESCRIPTOR'
:
_RECORD_VALUE
,
'__module__'
:
'record_pb2'
...
...
@@ -457,9 +541,14 @@ _sym_db.RegisterMessage(Record.Embedding)
_sym_db
.
RegisterMessage
(
Record
.
Embeddings
)
_sym_db
.
RegisterMessage
(
Record
.
bytes_embeddings
)
_sym_db
.
RegisterMessage
(
Record
.
Histogram
)
_sym_db
.
RegisterMessage
(
Record
.
PRCurve
)
_sym_db
.
RegisterMessage
(
Record
.
Value
)
_RECORD_HISTOGRAM
.
fields_by_name
[
'hist'
].
_options
=
None
_RECORD_HISTOGRAM
.
fields_by_name
[
'bin_edges'
].
_options
=
None
_RECORD_PRCURVE
.
fields_by_name
[
'TP'
].
_options
=
None
_RECORD_PRCURVE
.
fields_by_name
[
'FP'
].
_options
=
None
_RECORD_PRCURVE
.
fields_by_name
[
'TN'
].
_options
=
None
_RECORD_PRCURVE
.
fields_by_name
[
'FN'
].
_options
=
None
# @@protoc_insertion_point(module_scope)
visualdl/reader/reader.py
浏览文件 @
c6d4cdd1
...
...
@@ -106,6 +106,8 @@ class LogReader(object):
component
=
"audio"
elif
"histogram"
==
value_type
:
component
=
"histogram"
elif
"pr_curve"
==
value_type
:
component
=
"pr_curve"
else
:
raise
TypeError
(
"Invalid value type `%s`."
%
value_type
)
self
.
_tags
[
path
]
=
component
...
...
visualdl/server/api.py
浏览文件 @
c6d4cdd1
...
...
@@ -109,6 +109,10 @@ class Api(object):
def
embeddings_tags
(
self
):
return
self
.
_get_with_retry
(
'data/plugin/embeddings/tags'
,
lib
.
get_embeddings_tags
)
@result()
def pr_curve_tags(self):
    """Return the pr_curve tags via ``lib.get_pr_curve_tags``.

    The lookup goes through ``self._get_with_retry`` under the key
    'data/plugin/pr_curves/tags'.
    """
    cache_key = 'data/plugin/pr_curves/tags'
    return self._get_with_retry(cache_key, lib.get_pr_curve_tags)
@
result
()
def
scalars_list
(
self
,
run
,
tag
):
key
=
os
.
path
.
join
(
'data/plugin/scalars/scalars'
,
run
,
tag
)
...
...
@@ -151,6 +155,16 @@ class Api(object):
key
=
os
.
path
.
join
(
'data/plugin/histogram/histogram'
,
run
,
tag
)
return
self
.
_get_with_retry
(
key
,
lib
.
get_histogram
,
run
,
tag
)
@result()
def pr_curves_pr_curve(self, run, tag):
    """Return the pr_curve records of ``tag`` in ``run``.

    The lookup goes through ``self._get_with_retry`` with
    ``lib.get_pr_curve``; the key is built from the plugin path, run and
    tag.
    """
    cache_key = os.path.join('data/plugin/pr_curves/pr_curve', run, tag)
    fetch = lib.get_pr_curve
    return self._get_with_retry(cache_key, fetch, run, tag)
@result()
def pr_curves_steps(self, run):
    """Return the recorded pr_curve steps of ``run``.

    The lookup goes through ``self._get_with_retry`` with
    ``lib.get_pr_curve_step``; the key is built from the plugin path and
    run.
    """
    cache_key = os.path.join('data/plugin/pr_curves/steps', run)
    fetch = lib.get_pr_curve_step
    return self._get_with_retry(cache_key, fetch, run)
@
result
(
'application/octet-stream'
,
lambda
s
:
{
"Content-Disposition"
:
'attachment; filename="%s"'
%
s
.
model_name
}
if
len
(
s
.
model_name
)
else
None
)
def
graphs_graph
(
self
):
key
=
os
.
path
.
join
(
'data/plugin/graphs/graph'
)
...
...
@@ -169,6 +183,7 @@ def create_api_call(logdir, model, cache_timeout):
'audio/tags'
:
(
api
.
audio_tags
,
[]),
'embeddings/tags'
:
(
api
.
embeddings_tags
,
[]),
'histogram/tags'
:
(
api
.
histogram_tags
,
[]),
'pr-curve/tags'
:
(
api
.
pr_curve_tags
,
[]),
'scalars/list'
:
(
api
.
scalars_list
,
[
'run'
,
'tag'
]),
'images/list'
:
(
api
.
images_list
,
[
'run'
,
'tag'
]),
'images/image'
:
(
api
.
images_image
,
[
'run'
,
'tag'
,
'index'
]),
...
...
@@ -176,7 +191,9 @@ def create_api_call(logdir, model, cache_timeout):
'audio/audio'
:
(
api
.
audio_audio
,
[
'run'
,
'tag'
,
'index'
]),
'embeddings/embedding'
:
(
api
.
embeddings_embedding
,
[
'run'
,
'tag'
,
'reduction'
,
'dimension'
]),
'histogram/list'
:
(
api
.
histogram_list
,
[
'run'
,
'tag'
]),
'graphs/graph'
:
(
api
.
graphs_graph
,
[])
'graphs/graph'
:
(
api
.
graphs_graph
,
[]),
'pr-curve/list'
:
(
api
.
pr_curves_pr_curve
,
[
'run'
,
'tag'
]),
'pr-curve/steps'
:
(
api
.
pr_curves_steps
,
[
'run'
])
}
def
call
(
path
:
str
,
args
):
...
...
visualdl/server/data_manager.py
浏览文件 @
c6d4cdd1
...
...
@@ -23,7 +23,8 @@ DEFAULT_PLUGIN_MAXSIZE = {
"image"
:
10
,
"histogram"
:
100
,
"embeddings"
:
50000
,
"audio"
:
10
"audio"
:
10
,
"pr_curve"
:
300
}
...
...
@@ -274,7 +275,9 @@ class DataManager(object):
"embeddings"
:
Reservoir
(
max_size
=
DEFAULT_PLUGIN_MAXSIZE
[
"embeddings"
]),
"audio"
:
Reservoir
(
max_size
=
DEFAULT_PLUGIN_MAXSIZE
[
"audio"
])
Reservoir
(
max_size
=
DEFAULT_PLUGIN_MAXSIZE
[
"audio"
]),
"pr_curve"
:
Reservoir
(
max_size
=
DEFAULT_PLUGIN_MAXSIZE
[
"pr_curve"
])
}
self
.
_mutex
=
threading
.
Lock
()
...
...
visualdl/server/lib.py
浏览文件 @
c6d4cdd1
...
...
@@ -114,6 +114,40 @@ def get_histogram_tags(log_reader):
return
get_logs
(
log_reader
,
"histogram"
)
def get_pr_curve_tags(log_reader):
    """Return the result of ``get_logs`` for the "pr_curve" component."""
    component = "pr_curve"
    return get_logs(log_reader, component)
def get_pr_curve(log_reader, run, tag):
    """Collect all stored pr_curve records of ``tag`` in ``run``.

    New data is loaded first, then each item of the "pr_curve" reservoir is
    turned into a row:

        [timestamp, id, precision, recall, TP, FP, TN, FN, thresholds]

    where ``thresholds`` is ``[1/n, 2/n, ..., 1.0]`` and ``n`` is the number
    of precision entries of that item.
    """
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("pr_curve")
    entries = reservoir.get_items(run, decode_tag(tag))

    def _to_row(entry):
        # Flatten one stored record into the list layout described above.
        curve = entry.pr_curve
        count = len(curve.precision)
        thresholds = [float(i) / count for i in range(1, count + 1)]
        return [
            entry.timestamp,
            entry.id,
            list(curve.precision),
            list(curve.recall),
            list(curve.TP),
            list(curve.FP),
            list(curve.TN),
            list(curve.FN),
            thresholds,
        ]

    return [_to_row(entry) for entry in entries]
def get_pr_curve_step(log_reader, run, tag=None):
    """List ``[timestamp, id]`` pairs of the pr_curve records in ``run``.

    When ``tag`` is None, the first tag that ``get_pr_curve_tags`` reports
    for ``run`` is used.
    """
    if tag is None:
        tag = get_pr_curve_tags(log_reader)[run][0]

    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("pr_curve")
    entries = reservoir.get_items(run, decode_tag(tag))
    return [[entry.timestamp, entry.id] for entry in entries]
def
get_embeddings
(
log_reader
,
run
,
tag
,
reduction
,
dimension
=
2
):
log_reader
.
load_new_data
()
records
=
log_reader
.
data_manager
.
get_reservoir
(
"embeddings"
).
get_items
(
...
...
visualdl/writer/writer.py
浏览文件 @
c6d4cdd1
...
...
@@ -14,9 +14,9 @@
# =======================================================================
import
os
import
time
from
visualdl.writer.record_writer
import
RecordFileWriter
from
visualdl.component.base_component
import
scalar
,
image
,
embedding
,
audio
,
histogram
import
numpy
as
np
from
visualdl.writer.record_writer
import
RecordFileWriter
from
visualdl.component.base_component
import
scalar
,
image
,
embedding
,
audio
,
histogram
,
pr_curve
class
DummyFileWriter
(
object
):
...
...
@@ -281,6 +281,50 @@ class LogWriter(object):
step
=
step
,
walltime
=
walltime
))
def add_pr_curve(self,
                 tag,
                 labels,
                 predictions,
                 step,
                 num_thresholds=10,
                 weights=None,
                 walltime=None):
    """Add a precision-recall curve to vdl record file.

    Args:
        tag (string): Data identifier; must not contain '%'.
        labels (numpy.ndarray or list): Binary labels for each element.
        predictions (numpy.ndarray or list): The probability that an element
            be classified as true.
        step (int): Step of pr curve.
        num_thresholds (int): Number of thresholds used to draw the curve.
        weights (float): Multiple of data to display on the curve.
        walltime (int): Wall time of pr curve; the current time rounded to
            whole seconds is used when None.

    Raises:
        RuntimeError: If ``tag`` contains '%'.

    Example:
        with LogWriter(logdir="./log/pr_curve_test/train") as writer:
            for index in range(3):
                labels = np.random.randint(2, size=100)
                predictions = np.random.rand(100)
                writer.add_pr_curve(tag='default',
                                    labels=labels,
                                    predictions=predictions,
                                    step=index)
    """
    if '%' in tag:
        raise RuntimeError("% can't appear in tag!")
    if walltime is None:
        walltime = round(time.time())

    # Fetch the writer before building the record, as the original did.
    file_writer = self._get_file_writer()
    record = pr_curve(
        tag=tag,
        labels=labels,
        predictions=predictions,
        step=step,
        walltime=walltime,
        num_thresholds=num_thresholds,
        weights=weights,
    )
    file_writer.add_record(record)
def
flush
(
self
):
"""Flush all data in cache to disk.
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录