Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
96f1edf7
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
96f1edf7
编写于
6月 21, 2018
作者:
X
Xin Pan
提交者:
GitHub
6月 21, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #11629 from reyoung/hotfix/more_api_reference_docs
Cherry Pick the documentation PRs.
上级
49080ac9
35eb0112
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
480 addition
and
117 deletion
+480
-117
python/paddle/fluid/evaluator.py
python/paddle/fluid/evaluator.py
+74
-32
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+94
-19
python/paddle/fluid/layers/__init__.py
python/paddle/fluid/layers/__init__.py
+3
-3
python/paddle/fluid/layers/metric_op.py
python/paddle/fluid/layers/metric_op.py
+1
-1
python/paddle/fluid/metrics.py
python/paddle/fluid/metrics.py
+308
-62
未找到文件。
python/paddle/fluid/evaluator.py
浏览文件 @
96f1edf7
...
...
@@ -41,7 +41,12 @@ def _clone_var_(block, var):
class
Evaluator
(
object
):
"""
Base Class for all evaluators
Warning: better to use the fluid.metrics.* things, more
flexible support via pure Python and Operator, and decoupled
with executor. Short doc are intended to urge new user
start from Metrics.
Base Class for all evaluators.
Args:
name(str): The name of evaluator. such as, "accuracy". Used for generate
...
...
@@ -69,6 +74,10 @@ class Evaluator(object):
def
reset
(
self
,
executor
,
reset_program
=
None
):
"""
reset metric states at the begin of each pass/user specified batch
Args:
executor(Executor|ParallelExecutor): a executor for executing the reset_program
reset_program(Program): a single Program for reset process
"""
if
reset_program
is
None
:
reset_program
=
Program
()
...
...
@@ -85,15 +94,16 @@ class Evaluator(object):
def
eval
(
self
,
executor
,
eval_program
=
None
):
"""
Evaluate the statistics merged by multiple mini-batches.
Args:
executor(Executor|ParallelExecutor): a executor for executing the eval_program
eval_program(Program): a single Program for eval process
"""
raise
NotImplementedError
()
def
create_state
(
self
,
suffix
,
dtype
,
shape
):
def
_
create_state
(
self
,
suffix
,
dtype
,
shape
):
"""
Create state variable.
NOTE: It is not a public API.
Args:
suffix(str): the state suffix.
dtype(str|core.VarDesc.VarType): the state data type
...
...
@@ -113,9 +123,35 @@ class Evaluator(object):
class
ChunkEvaluator
(
Evaluator
):
"""
Warning: This would be deprecated in the future. Please use fluid.metrics.ChunkEvaluator
instead.
Accumulate counter numbers output by chunk_eval from mini-batches and
compute the precision recall and F1-score using the accumulated counter
numbers.
For some basics of chunking, please refer to
'Chunking with Support Vector Machines <https://aclanthology.info/pdf/N/N01/N01-1025.pdf>'.
Args:
input (Variable): prediction output of the network.
label (Variable): label of the test data set.
chunk_scheme (str): can be IOB/IOE/IOBES and IO. See the chunk_eval op for details.
num_chunk_types (int): the number of chunk type.
excluded_chunk_types (list): A list including chunk type ids, indicating chunk types that are not counted.
Returns:
tuple: tuple containing: precision, recall, f1_score
Examples:
.. code-block:: python
exe = fluid.executor(place)
evaluator = fluid.Evaluator.ChunkEvaluator(input, label)
for epoch in PASS_NUM:
evaluator.reset(exe)
for data in batches:
loss = exe.run(fetch_list=[cost])
distance, instance_error = distance_evaluator.eval(exe)
"""
def
__init__
(
...
...
@@ -130,11 +166,11 @@ class ChunkEvaluator(Evaluator):
if
main_program
.
current_block
().
idx
!=
0
:
raise
ValueError
(
"You can only invoke Evaluator in root block"
)
self
.
num_infer_chunks
=
self
.
create_state
(
self
.
num_infer_chunks
=
self
.
_
create_state
(
dtype
=
'int64'
,
shape
=
[
1
],
suffix
=
'num_infer_chunks'
)
self
.
num_label_chunks
=
self
.
create_state
(
self
.
num_label_chunks
=
self
.
_
create_state
(
dtype
=
'int64'
,
shape
=
[
1
],
suffix
=
'num_label_chunks'
)
self
.
num_correct_chunks
=
self
.
create_state
(
self
.
num_correct_chunks
=
self
.
_
create_state
(
dtype
=
'int64'
,
shape
=
[
1
],
suffix
=
'num_correct_chunks'
)
precision
,
recall
,
f1_score
,
num_infer_chunks
,
num_label_chunks
,
num_correct_chunks
=
layers
.
chunk_eval
(
input
=
input
,
...
...
@@ -178,6 +214,8 @@ class ChunkEvaluator(Evaluator):
class
EditDistance
(
Evaluator
):
"""
Warning: This would be deprecated in the future. Please use fluid.metrics.EditDistance
instead.
Accumulate edit distance sum and sequence number from mini-batches and
compute the average edit_distance and instance error of all batches.
...
...
@@ -188,15 +226,16 @@ class EditDistance(Evaluator):
ignored_tokens(list of int): Tokens that should be removed before
calculating edit distance.
Example:
Examples:
.. code-block:: python
exe = fluid.executor(place)
distance_evaluator = fluid.Evaluator.EditDistance(input, label)
for epoch in PASS_NUM:
distance_evaluator.reset(exe)
for data in batches:
loss = exe.run(fetch_list=[cost])
distance, instance_error = distance_evaluator.eval(exe)
exe = fluid.executor(place)
distance_evaluator = fluid.Evaluator.EditDistance(input, label)
for epoch in PASS_NUM:
distance_evaluator.reset(exe)
for data in batches:
loss = exe.run(fetch_list=[cost])
distance, instance_error = distance_evaluator.eval(exe)
In the above example:
'distance' is the average of the edit distance in a pass.
...
...
@@ -210,11 +249,11 @@ class EditDistance(Evaluator):
if
main_program
.
current_block
().
idx
!=
0
:
raise
ValueError
(
"You can only invoke Evaluator in root block"
)
self
.
total_distance
=
self
.
create_state
(
self
.
total_distance
=
self
.
_
create_state
(
dtype
=
'float32'
,
shape
=
[
1
],
suffix
=
'total_distance'
)
self
.
seq_num
=
self
.
create_state
(
self
.
seq_num
=
self
.
_
create_state
(
dtype
=
'int64'
,
shape
=
[
1
],
suffix
=
'seq_num'
)
self
.
instance_error
=
self
.
create_state
(
self
.
instance_error
=
self
.
_
create_state
(
dtype
=
'int64'
,
shape
=
[
1
],
suffix
=
'instance_error'
)
distances
,
seq_num
=
layers
.
edit_distance
(
input
=
input
,
label
=
label
,
ignored_tokens
=
ignored_tokens
)
...
...
@@ -256,9 +295,10 @@ class EditDistance(Evaluator):
class
DetectionMAP
(
Evaluator
):
"""
Warning: This would be deprecated in the future. Please use fluid.metrics.DetectionMAP
instead.
Calculate the detection mean average precision (mAP).
TODO (Dang Qingqing): update the following doc.
The general steps are as follows:
1. calculate the true positive and false positive according to the input
of detection and labels.
...
...
@@ -293,17 +333,18 @@ class DetectionMAP(Evaluator):
- 11point: the 11-point interpolated average precision.
- integral: the natural integral of the precision-recall curve.
Example:
Examples:
.. code-block:: python
exe = fluid.executor(place)
map_evaluator = fluid.Evaluator.DetectionMAP(input,
gt_label, gt_box, gt_difficult)
cur_map, accum_map = map_evaluator.get_map_var()
fetch = [cost, cur_map, accum_map]
for epoch in PASS_NUM:
map_evaluator.reset(exe)
for data in batches:
loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch)
exe = fluid.executor(place)
map_evaluator = fluid.Evaluator.DetectionMAP(input,
gt_label, gt_box, gt_difficult)
cur_map, accum_map = map_evaluator.get_map_var()
fetch = [cost, cur_map, accum_map]
for epoch in PASS_NUM:
map_evaluator.reset(exe)
for data in batches:
loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch)
In the above example:
...
...
@@ -340,9 +381,10 @@ class DetectionMAP(Evaluator):
evaluate_difficult
=
evaluate_difficult
,
ap_version
=
ap_version
)
self
.
create_state
(
dtype
=
'int32'
,
shape
=
None
,
suffix
=
'accum_pos_count'
)
self
.
create_state
(
dtype
=
'float32'
,
shape
=
None
,
suffix
=
'accum_true_pos'
)
self
.
create_state
(
dtype
=
'float32'
,
shape
=
None
,
suffix
=
'accum_false_pos'
)
self
.
_create_state
(
dtype
=
'int32'
,
shape
=
None
,
suffix
=
'accum_pos_count'
)
self
.
_create_state
(
dtype
=
'float32'
,
shape
=
None
,
suffix
=
'accum_true_pos'
)
self
.
_create_state
(
dtype
=
'float32'
,
shape
=
None
,
suffix
=
'accum_false_pos'
)
self
.
has_state
=
None
var
=
self
.
helper
.
create_variable
(
...
...
python/paddle/fluid/executor.py
浏览文件 @
96f1edf7
...
...
@@ -18,7 +18,7 @@ from framework import Program, default_main_program, Variable
from
.
import
core
__all__
=
[
'Executor'
,
'global_scope'
,
'scope_guard'
,
'switch_scope'
,
'fetch_var'
'Executor'
,
'global_scope'
,
'scope_guard'
,
'
_
switch_scope'
,
'fetch_var'
]
g_scope
=
core
.
Scope
()
...
...
@@ -35,7 +35,7 @@ def global_scope():
return
g_scope
def
switch_scope
(
scope
):
def
_
switch_scope
(
scope
):
global
g_scope
ex
=
g_scope
g_scope
=
scope
...
...
@@ -57,12 +57,27 @@ def scope_guard(scope):
Args:
scope: The new global/default scope.
"""
ex
=
switch_scope
(
scope
)
ex
=
_
switch_scope
(
scope
)
yield
switch_scope
(
ex
)
_
switch_scope
(
ex
)
def
as_numpy
(
tensor
):
"""
Convert a Tensor to a numpy.ndarray, its only support Tensor without LoD information.
For higher dimensional sequence data, please use LoDTensor directly.
Examples:
>>> import paddle.fluid as fluid
>>> outs = executor.run(...)
>>> np_outs = map(lambda x: as_numpy(x), outs)
>>> ...
Args:
tensor(Variable): a instance of Tensor
Returns:
numpy.ndarray
"""
if
isinstance
(
tensor
,
list
):
return
[
as_numpy
(
t
)
for
t
in
tensor
]
assert
isinstance
(
tensor
,
core
.
LoDTensor
)
...
...
@@ -186,7 +201,7 @@ def fetch_var(name, scope=None, return_numpy=True):
return
tensor
def
get_program_cache_key
(
feed
,
fetch_list
):
def
_
get_program_cache_key
(
feed
,
fetch_list
):
feed_var_names
=
feed
.
keys
()
def
to_name_str
(
var
):
...
...
@@ -205,6 +220,25 @@ def get_program_cache_key(feed, fetch_list):
class
Executor
(
object
):
"""
An Executor in Python, only support the single-GPU running. For multi-cards, please refer to
ParallelExecutor.
Python executor takes a program, add feed operators and fetch operators to this program according
to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides
the variables(or names) that user want to get after program run. Note: the executor will run all
operators in the program but not only the operators dependent by the fetch_list.
It store the global variables into the global scope, and create a local scope for the temporary
variables. The local scope contents will be discarded after every minibatch forward/backward finished.
But the global scope variables will be persistent through different runs.
All of ops in program will be running in sequence.
Args:
place(core.CPUPlace|core.CUDAPlace(n)): indicate the executor run on which device
Note: For debugging complicated network in parallel-GPUs, you can test it on the executor.
They has the exactly same arguments, and expected the same results.
"""
def
__init__
(
self
,
place
):
self
.
place
=
place
p
=
core
.
Place
()
...
...
@@ -213,6 +247,23 @@ class Executor(object):
self
.
program_caches
=
dict
()
def
as_lodtensor
(
self
,
data
):
"""
Convert numpy.ndarray to Tensor, its only support Tensor without LoD information.
For higher dimensional sequence data, please use LoDTensor directly.
Examples:
>>> import paddle.fluid as fluid
>>> exe = fluid.executor(fluid.CPUPlace())
>>> data = np.array(size=(100, 200, 300))
>>> np_outs = map(lambda x: exe.as_lodtensor(x), data)
>>> ...
Args:
data(numpy.ndarray): a instance of array
Returns:
LoDTensor
"""
if
isinstance
(
data
,
list
):
raise
RuntimeError
(
"Some of your feed data hold LoD information.
\
They can not be completely cast from a list of Python
\
...
...
@@ -304,23 +355,47 @@ class Executor(object):
scope
=
None
,
return_numpy
=
True
,
use_program_cache
=
False
):
"""
Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
"""
Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according
to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides
the variables(or names) that user want to get after program run. Note: the executor will run all
the variables(or names) that user want to get after program run.
Note: the executor will run all
operators in the program but not only the operators dependent by the fetch_list
:param program: the program that need to run, if not provied, then default_main_program will be used.
:param feed: feed variable map, e.g. {"image": ImageData, "label": LableData}
:param fetch_list: a list of variable or variable names that user want to get, run will return them according
to this list.
:param feed_var_name: the name for the input variable of feed Operator.
:param fetch_var_name: the name for the output variable of feed Operator.
:param scope: the scope used to run this program, you can switch it to different scope. default is global_scope
:param return_numpy: if convert the fetched tensor to numpy
:param use_program_cache: set use_program_cache to true if program not changed compare to the last step.
:return: result according to fetch_list.
Args:
program(Program): the program that need to run, if not provied, then default_main_program will be used.
feed(dict): feed variable map, e.g. {"image": ImageData, "label": LableData}
fetch_list(list): a list of variable or variable names that user want to get, run will return them according to this list.
feed_var_name(str): the name for the input variable of feed Operator.
fetch_var_name(str): the name for the output variable of fetch Operator.
scope(Scope): the scope used to run this program, you can switch it to different scope. default is global_scope
return_numpy(bool): if convert the fetched tensor to numpy
use_program_cache(bool): set use_program_cache to true if program not changed compare to the last step.
Returns:
list(numpy.array): fetch result according to fetch_list.
Examples:
>>> data = layers.data(name='X', shape=[1], dtype='float32')
>>> hidden = layers.fc(input=data, size=10)
>>> layers.assign(hidden, out)
>>> loss = layers.mean(out)
>>> adam = fluid.optimizer.Adam()
>>> adam.minimize(loss)
>>> cpu = core.CPUPlace()
>>> exe = Executor(cpu)
>>> exe.run(default_startup_program())
>>> x = numpy.random.random(size=(10, 1)).astype('float32')
>>> outs = exe.run(
>>> feed={'X': x},
>>> fetch_list=[loss.name])
"""
if
feed
is
None
:
feed
=
{}
...
...
@@ -341,7 +416,7 @@ class Executor(object):
if
scope
is
None
:
scope
=
global_scope
()
cache_key
=
get_program_cache_key
(
feed
,
fetch_list
)
cache_key
=
_
get_program_cache_key
(
feed
,
fetch_list
)
if
use_program_cache
:
cached_program
=
self
.
_get_program_cache
(
cache_key
)
if
cached_program
is
None
:
...
...
python/paddle/fluid/layers/__init__.py
浏览文件 @
96f1edf7
...
...
@@ -28,8 +28,8 @@ import math_op_patch
from
math_op_patch
import
*
import
detection
from
detection
import
*
import
metric
from
metric
import
*
import
metric
_op
from
metric
_op
import
*
from
learning_rate_scheduler
import
*
__all__
=
[]
...
...
@@ -41,5 +41,5 @@ __all__ += control_flow.__all__
__all__
+=
ops
.
__all__
__all__
+=
device
.
__all__
__all__
+=
detection
.
__all__
__all__
+=
metric
.
__all__
__all__
+=
metric
_op
.
__all__
__all__
+=
learning_rate_scheduler
.
__all__
python/paddle/fluid/layers/metric.py
→
python/paddle/fluid/layers/metric
_op
.py
浏览文件 @
96f1edf7
...
...
@@ -126,7 +126,7 @@ def auc(input, label, curve='ROC', num_thresholds=200):
topk_out
,
topk_indices
=
nn
.
topk
(
input
,
k
=
k
)
auc_out
=
helper
.
create_tmp_variable
(
dtype
=
"float32"
)
helper
.
append_op
(
type
=
"a
ccuracy
"
,
type
=
"a
uc
"
,
inputs
=
{
"Out"
:
[
topk_out
],
"Indices"
:
[
topk_indices
],
...
...
python/paddle/fluid/metrics.py
浏览文件 @
96f1edf7
...
...
@@ -23,6 +23,8 @@ import warnings
__all__
=
[
'MetricBase'
,
'CompositeMetric'
,
'Precision'
,
'Recall'
,
'Accuracy'
,
'ChunkEvaluator'
,
'EditDistance'
,
...
...
@@ -46,33 +48,34 @@ def _is_number_or_matrix_(var):
class
MetricBase
(
object
):
"""
Base Class for all evaluators
Base Class for all Metrics.
MetricBase define a group of interfaces for the
model evaluation methods. Metrics accumulate metric states between
consecutive minibatches, at every minibatch, use update
interface to add current minibatch value to global states.
Use eval to compute accumative metric value from last reset()
or from scratch on.
If you need to custom a new metric, please inherit from MetricBase and
custom implementation.
Args:
name(str): The name of evaluator. such as, "accuracy". Used for generate
temporary variable name.
Interface:
Note(*) : the states is the attributes who not has _ prefix.
get_config(): print current states and configuration
reset(): clear the states. If the Metrics states type is not (int, float, np.ndarray),
Please override this method.
update(): update states at every minibatch
eval(): get metric evaluation in numpy type.
name(str): The name of metric instance. such as, "accuracy".
It needed if you want to distinct different metrics in a model.
"""
def
__init__
(
self
,
name
,
**
kwargs
):
def
__init__
(
self
,
name
):
self
.
_name
=
str
(
name
)
if
name
!=
None
else
self
.
__class__
.
__name__
self
.
_kwargs
=
kwargs
if
kwargs
!=
None
else
dict
()
self
.
reset
()
def
__str__
(
self
):
return
self
.
_name
def
reset
(
self
):
"""
states is the attributes who not has _ prefix.
reset the states of metrics.
reset clear the states of metrics. By default, the states
are the members who do not has _ prefix, reset set them to inital states.
If you violate the implicit name rule, please also custom the reset
interface.
"""
states
=
{
attr
:
value
...
...
@@ -90,61 +93,231 @@ class MetricBase(object):
setattr
(
self
,
attr
,
None
)
def
get_config
(
self
):
"""
Get the metric and current states.
The states are the members who do not has "_" prefix.
Args:
None
Returns:
dict: a dict of metric and states
"""
states
=
{
attr
:
value
for
attr
,
value
in
self
.
__dict__
.
iteritems
()
if
not
attr
.
startswith
(
"_"
)
}
config
=
copy
.
deepcopy
(
self
.
_kwargs
)
config
=
{}
config
.
update
({
"name"
:
self
.
_name
,
"states"
:
copy
.
deepcopy
(
states
)})
return
config
def
update
(
self
):
raise
NotImplementedError
()
def
update
(
self
,
preds
,
labels
):
"""
Updates the metric states at every minibatch.
One user can compute the minibatch metric via pure Python, or
via a c++ operator.
Args:
preds(numpy.array): the predictions of current minibatch
labels(numpy.array): the labels of current minibatch, if the label is one-hot
or soft-label, should custom the corresponding update rule.
"""
raise
NotImplementedError
(
"Should not use it directly, please extend it."
)
def
eval
(
self
):
raise
NotImplementedError
()
"""
Evalute the current metrics based the accumulated states.
Returns:
float|list(float)|numpy.array: the metrics via Python.
"""
raise
NotImplementedError
(
"Should not use it directly, please extend it."
)
class
CompositeMetric
(
MetricBase
):
"""
Comp
ute multiple metrics in each minibatch
.
Comp
osite multiple metrics in one instance
.
for example, merge F1, accuracy, recall into one Metric.
Examples:
.. code-block:: python
labels = fluid.layers.data(name="data", shape=[1], dtype="int32")
data = fluid.layers.data(name="data", shape=[32, 32], dtype="int32")
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
comp = fluid.metrics.CompositeMetric()
acc = fluid.metrics.Precision()
recall = fluid.metrics.Recall()
comp.add_metric(acc)
comp.add_metric(recall)
for pass in range(PASSES):
comp.reset()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
comp.update(preds=preds, labels=labels)
numpy_acc, numpy_recall = comp.eval()
"""
def
__init__
(
self
,
name
=
None
,
**
kwargs
):
super
(
CompositeMetric
,
self
).
__init__
(
name
,
kwargs
)
def
__init__
(
self
,
name
=
None
):
super
(
CompositeMetric
,
self
).
__init__
(
name
)
self
.
_metrics
=
[]
def
add_metric
(
self
,
metric
):
"""
add one metric instance to CompositeMetric.
Args:
metric: a instance of MetricBase.
"""
if
not
isinstance
(
metric
,
MetricBase
):
raise
ValueError
(
"SubMetric should be inherit from MetricBase."
)
self
.
_metrics
.
append
(
metric
)
def
update
(
self
,
preds
,
labels
):
"""
Update every metrics in sequence.
Args:
preds(numpy.array): the predictions of current minibatch
labels(numpy.array): the labels of current minibatch, if the label is one-hot
or soft-label, should custom the corresponding update rule.
"""
for
m
in
self
.
_metrics
:
ans
.
append
(
m
.
update
(
preds
,
labels
))
def
eval
(
self
):
"""
Evaluate every metrics in sequence.
Returns:
list(float|numpy.array): a list of metrics value in Python.
"""
ans
=
[]
for
m
in
self
.
_metrics
:
ans
.
append
(
m
.
eval
())
return
ans
class
Precision
(
MetricBase
):
"""
Precision (also called positive predictive value) is the fraction of
relevant instances among the retrieved instances.
https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers
Note Precision is different with Accuracy in binary classifiers.
accuracy = true positive / total instances
precision = true positive / all positive instance
Examples:
.. code-block:: python
metric = fluid.metrics.Precision()
for pass in range(PASSES):
metric.reset()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
metric.update(preds=preds, labels=labels)
numpy_precision = metric.eval()
"""
def
__init__
(
self
,
name
=
None
):
super
(
Precision
,
self
).
__init__
(
name
)
self
.
tp
=
0
# true positive
self
.
fp
=
0
# false positive
def
update
(
self
,
preds
,
labels
):
if
not
_is_numpy_
(
preds
):
raise
ValueError
(
"The 'preds' must be a numpy ndarray."
)
if
not
_is_numpy_
(
labels
):
raise
ValueError
(
"The 'labels' must be a numpy ndarray."
)
sample_num
=
labels
[
0
]
for
i
in
range
(
sample_num
):
pred
=
preds
[
i
].
astype
(
"int32"
)
label
=
labels
[
i
]
if
label
==
1
:
if
pred
==
label
:
self
.
tp
+=
1
else
:
self
.
fp
+=
1
def
eval
(
self
):
ap
=
self
.
tp
+
self
.
fp
return
float
(
self
.
tp
)
/
ap
if
ap
!=
0
else
.
0
class
Recall
(
MetricBase
):
"""
Recall (also known as sensitivity) is the fraction of
relevant instances that have been retrieved over the
total amount of relevant instances
https://en.wikipedia.org/wiki/Precision_and_recall
Examples:
.. code-block:: python
metric = fluid.metrics.Recall()
for pass in range(PASSES):
metric.reset()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
metric.update(preds=preds, labels=labels)
numpy_recall = metric.eval()
"""
def
__init__
(
self
,
name
=
None
):
super
(
Recall
,
self
).
__init__
(
name
)
self
.
tp
=
0
# true positive
self
.
fn
=
0
# false negtive
def
update
(
self
,
preds
,
labels
):
if
not
_is_numpy_
(
preds
):
raise
ValueError
(
"The 'preds' must be a numpy ndarray."
)
if
not
_is_numpy_
(
labels
):
raise
ValueError
(
"The 'labels' must be a numpy ndarray."
)
sample_num
=
labels
[
0
]
for
i
in
range
(
sample_num
):
pred
=
preds
[
i
].
astype
(
"int32"
)
label
=
labels
[
i
]
if
label
==
1
:
if
pred
==
label
:
self
.
tp
+=
1
else
:
if
pred
!=
label
:
self
.
fn
+=
1
def
eval
(
self
):
recall
=
self
.
tp
+
self
.
fn
return
float
(
self
.
tp
)
/
recall
if
recall
!=
0
else
.
0
class
Accuracy
(
MetricBase
):
"""
Accumulate the accuracy from minibatches and compute the average accuracy
for every pass.
https://en.wikipedia.org/wiki/Accuracy_and_precision
Args:
name: the metrics name
Example:
minibatch_accuracy = fluid.layers.accuracy(pred, label)
accuracy_evaluator = fluid.metrics.Accuracy()
for epoch in PASS_NUM:
accuracy_evaluator.reset()
for data in batches:
loss = exe.run(fetch_list=[cost, minibatch_accuracy])
accuracy_evaluator.update(value=minibatch_accuracy, weight=batches)
accuracy = accuracy_evaluator.eval()
Examples:
.. code-block:: python
labels = fluid.layers.data(name="data", shape=[1], dtype="int32")
data = fluid.layers.data(name="data", shape=[32, 32], dtype="int32")
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
minibatch_accuracy = fluid.layers.accuracy(pred, label)
accuracy_evaluator = fluid.metrics.Accuracy()
for pass in range(PASSES):
accuracy_evaluator.reset()
for data in train_reader():
batch_size = data[0]
loss = exe.run(fetch_list=[cost, minibatch_accuracy])
accuracy_evaluator.update(value=minibatch_accuracy, weight=batch_size)
numpy_acc = accuracy_evaluator.eval()
"""
def
__init__
(
self
,
name
=
None
):
...
...
@@ -153,6 +326,13 @@ class Accuracy(MetricBase):
self
.
weight
=
.
0
def
update
(
self
,
value
,
weight
):
"""
Update minibatch states.
Args:
value(float|numpy.array): accuracy of one minibatch.
weight(int|float): batch size.
"""
if
not
_is_number_or_matrix_
(
value
):
raise
ValueError
(
"The 'value' must be a number(int, float) or a numpy ndarray."
)
...
...
@@ -163,9 +343,8 @@ class Accuracy(MetricBase):
def
eval
(
self
):
if
self
.
weight
==
0
:
raise
ValueError
(
"There is no data in Accuracy Metrics. Please check layers.accuracy output has added to Accuracy."
)
raise
ValueError
(
"There is no data in Accuracy Metrics.
\
Please check layers.accuracy output has added to Accuracy."
)
return
self
.
value
/
self
.
weight
...
...
@@ -174,6 +353,25 @@ class ChunkEvaluator(MetricBase):
Accumulate counter numbers output by chunk_eval from mini-batches and
compute the precision recall and F1-score using the accumulated counter
numbers.
For some basics of chunking, please refer to
'Chunking with Support Vector Machines <https://aclanthology.info/pdf/N/N01/N01-1025.pdf>'.
ChunkEvalEvaluator computes the precision, recall, and F1-score of chunk detection,
and supports IOB, IOE, IOBES and IO (also known as plain) tagging schemes.
Examples:
.. code-block:: python
labels = fluid.layers.data(name="data", shape=[1], dtype="int32")
data = fluid.layers.data(name="data", shape=[32, 32], dtype="int32")
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks = layers.chunk_eval(
input=pred,
label=label)
metric = fluid.metrics.ChunkEvaluator()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
metric.update(num_infer_chunks, num_label_chunks, num_correct_chunks)
numpy_precision, numpy_recall, numpy_f1 = metric.eval()
"""
def
__init__
(
self
,
name
=
None
):
...
...
@@ -183,9 +381,17 @@ class ChunkEvaluator(MetricBase):
self
.
num_correct_chunks
=
0
def
update
(
self
,
num_infer_chunks
,
num_label_chunks
,
num_correct_chunks
):
"""
Update the states based on the layers.chunk_eval() ouputs.
Args:
num_infer_chunks(int|numpy.array): The number of chunks in Inference on the given minibatch.
num_label_chunks(int|numpy.array): The number of chunks in Label on the given mini-batch.
num_correct_chunks(int|float|numpy.array): The number of chunks both in Inference and Label on the
given mini-batch.
"""
if
not
_is_number_or_matrix_
(
num_infer_chunks
):
raise
ValueError
(
"The 'num_infer_chunks' must be a number(int
, float
) or a numpy ndarray."
"The 'num_infer_chunks' must be a number(int) or a numpy ndarray."
)
if
not
_is_number_or_matrix_
(
num_label_chunks
):
raise
ValueError
(
...
...
@@ -212,21 +418,28 @@ class ChunkEvaluator(MetricBase):
class
EditDistance
(
MetricBase
):
"""
Edit distance is a way of quantifying how dissimilar two strings
(e.g., words) are to one another by counting the minimum number
of operations required to transform one string into the other.
Refer to https://en.wikipedia.org/wiki/Edit_distance
Accumulate edit distance sum and sequence number from mini-batches and
compute the average edit_distance and instance error of all batches.
Args:
name: the metrics name
Example:
edit_distance_metrics = fluid.layers.edit_distance(input, label)
distance_evaluator = fluid.metrics.EditDistance()
for epoch in PASS_NUM:
distance_evaluator.reset()
for data in batches:
loss = exe.run(fetch_list=[cost] + list(edit_distance_metrics))
distance_evaluator.update(*edit_distance_metrics)
distance, instance_error = distance_evaluator.eval()
Examples:
.. code-block:: python
distances, seq_num = fluid.layers.edit_distance(input, label)
distance_evaluator = fluid.metrics.EditDistance()
for epoch in PASS_NUM:
distance_evaluator.reset()
for data in batches:
loss = exe.run(fetch_list=[cost] + list(edit_distance_metrics))
distance_evaluator.update(distances, seq_num)
distance, instance_error = distance_evaluator.eval()
In the above example:
'distance' is the average of the edit distance in a pass.
...
...
@@ -264,16 +477,38 @@ class EditDistance(MetricBase):
class
DetectionMAP
(
MetricBase
):
"""
Calculate the detection mean average precision (mAP).
TODO (Dang Qingqing): update the following doc.
The general steps are as follows:
1. calculate the true positive and false positive according to the input
of detection and labels.
2. calculate mAP value, support two versions: '11 point' and 'integral'.
mAP is the metric to measure the accuracy of object detectors
like Faster R-CNN, SSD, etc.
It is the average of the maximum precisions at different recall values.
Please get more information from the following articles:
https://sanchom.wordpress.com/tag/average-precision/
https://arxiv.org/abs/1512.02325
The general steps are as follows:
1. calculate the true positive and false positive according to the input
of detection and labels.
2. calculate mAP value, support two versions: '11 point' and 'integral'.
Examples:
.. code-block:: python
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
batch_map = layers.detection_map(
input,
label,
class_num,
background_label,
overlap_threshold=overlap_threshold,
evaluate_difficult=evaluate_difficult,
ap_version=ap_version)
metric = fluid.metrics.DetectionMAP()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, batch_map])
batch_size = data[0]
metric.update(value=batch_map, weight=batch_size)
numpy_map = metric.eval()
"""
def
__init__
(
self
,
name
=
None
):
...
...
@@ -302,17 +537,18 @@ class DetectionMAP(MetricBase):
class
Auc
(
MetricBase
):
"""
Auc Metrics which adapts to binary classification.
Need to note that auc metrics compute the value via Python natively.
Auc metric adapts to the binary classification.
Refer to https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve
Need to note that auc metric compute the value via Python natively.
If you concern the speed, please use the fluid.layers.auc instead.
The `auc` function creates four local variables, `true_positives`,
`true_negatives`, `false_positives` and `false_negatives` that are used to
compute the AUC. To discretize the AUC curve, a linearly spaced set of
thresholds is used to compute pairs of recall and precision values. The area
under the ROC-curve is therefore computed using the height of the recall
values by the false positive rate, while the area under the PR-curve is the
computed using the height of the precision values by the recall.
`true_negatives`, `false_positives` and `false_negatives` that are used to
compute the AUC. To discretize the AUC curve, a linearly spaced set of
thresholds is used to compute pairs of recall and precision values. The area
under the ROC-curve is therefore computed using the height of the recall
values by the false positive rate, while the area under the PR-curve is the
computed using the height of the precision values by the recall.
Args:
name: metric name
...
...
@@ -322,6 +558,16 @@ class Auc(MetricBase):
curve.
"NOTE: only implement the ROC curve type via Python now."
Examples:
.. code-block:: python
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
metric = fluid.metrics.Auc()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
metric.update(preds, labels)
numpy_auc = metric.eval()
"""
def
__init__
(
self
,
name
,
curve
=
'ROC'
,
num_thresholds
=
200
):
...
...
@@ -334,10 +580,10 @@ class Auc(MetricBase):
self
.
tn_list
=
np
.
zeros
((
num_thresholds
,
))
self
.
fp_list
=
np
.
zeros
((
num_thresholds
,
))
def
update
(
self
,
labels
,
predictions
,
axis
=
1
):
def
update
(
self
,
preds
,
labels
):
if
not
_is_numpy_
(
labels
):
raise
ValueError
(
"The 'labels' must be a numpy ndarray."
)
if
not
_is_numpy_
(
pred
iction
s
):
if
not
_is_numpy_
(
preds
):
raise
ValueError
(
"The 'predictions' must be a numpy ndarray."
)
kepsilon
=
1e-7
# to account for floating point imprecisions
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录