未验证 提交 78ca8cf0 编写于 作者: Q qingqing01 提交者: GitHub

Unify the metrics implementation between low-level and high-level API. (#26158)

* Move paddle/incubate/hapi/metrics to paddle/metric
* Add Precision, Recall and Auc metric
上级 da1efe24
......@@ -91,6 +91,7 @@ set(PADDLE_PYTHON_PACKAGE_DIR ${CMAKE_CURRENT_BINARY_DIR}/dist/)
if (WITH_TESTING)
add_subdirectory(paddle/reader/tests)
add_subdirectory(paddle/dataset/tests)
add_subdirectory(paddle/tests)
add_subdirectory(paddle/fluid/tests)
add_subdirectory(paddle/fluid/contrib/tests)
add_subdirectory(paddle/fluid/contrib/slim/tests)
......
......@@ -3682,5 +3682,32 @@ class TestBook(LayerTest):
batch_first=batch_first)
class TestMetricsDetectionMap(unittest.TestCase):
def test_detection_map(self):
program = fluid.Program()
with program_guard(program):
detect_res = fluid.layers.data(
name='detect_res',
shape=[10, 6],
append_batch_size=False,
dtype='float32')
label = fluid.layers.data(
name='label',
shape=[10, 1],
append_batch_size=False,
dtype='float32')
box = fluid.layers.data(
name='bbox',
shape=[10, 4],
append_batch_size=False,
dtype='float32')
map_eval = fluid.metrics.DetectionMAP(
detect_res, label, box, class_num=21)
cur_map, accm_map = map_eval.get_map_var()
self.assertIsNotNone(cur_map)
self.assertIsNotNone(accm_map)
print(str(program))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
from paddle.fluid.framework import Program, program_guard
class TestMetricsDetectionMap(unittest.TestCase):
def test_detection_map(self):
program = fluid.Program()
with program_guard(program):
detect_res = fluid.layers.data(
name='detect_res',
shape=[10, 6],
append_batch_size=False,
dtype='float32')
label = fluid.layers.data(
name='label',
shape=[10, 1],
append_batch_size=False,
dtype='float32')
box = fluid.layers.data(
name='bbox',
shape=[10, 4],
append_batch_size=False,
dtype='float32')
map_eval = fluid.metrics.DetectionMAP(
detect_res, label, box, class_num=21)
cur_map, accm_map = map_eval.get_map_var()
self.assertIsNotNone(cur_map)
self.assertIsNotNone(accm_map)
print(str(program))
if __name__ == '__main__':
unittest.main()
......@@ -20,7 +20,6 @@ from . import download
from . import model
from .model import *
from . import metrics
from . import datasets
from . import distributed
from . import vision
......@@ -39,7 +38,6 @@ __all__ = [
'datasets',
'distributed',
'download',
'metrics',
'vision',
'text',
'utils',
......
......@@ -305,8 +305,8 @@ class ProgBarLogger(Callback):
optim = fluid.optimizer.Adam(0.001)
model.prepare(optimizer=optim,
loss_function=paddle.nn.CrossEntropyLoss(),
metrics=hapi.metrics.Accuracy())
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy())
callback = hapi.callbacks.ProgBarLogger(log_freq=10)
model.fit(train_dataset, batch_size=64, callbacks=callback)
......@@ -441,8 +441,8 @@ class ModelCheckpoint(Callback):
optim = fluid.optimizer.Adam(0.001)
model.prepare(optimizer=optim,
loss_function=paddle.nn.CrossEntropyLoss(),
metrics=hapi.metrics.Accuracy())
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy())
callback = hapi.callbacks.ModelCheckpoint(save_dir='./temp')
model.fit(train_dataset, batch_size=64, callbacks=callback)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import abc
import numpy as np
import paddle.fluid as fluid
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
__all__ = ['Metric', 'Accuracy']
@six.add_metaclass(abc.ABCMeta)
class Metric(object):
"""
Base class for metric, encapsulates metric logic and APIs
Usage:
m = SomeMetric()
for prediction, label in ...:
m.update(prediction, label)
m.accumulate()
Advanced usage for :code:`add_metric_op`
Metric calculation can be accelerated by calculating metric states
from model outputs and labels by Paddle OPs in :code:`add_metric_op`,
metric states will be fetch as numpy array and call :code:`update`
with states in numpy format.
Metric calculated as follows (operations in Model and Metric are
indicated with curly brackets, while data nodes not):
inputs & labels || ------------------
| ||
{model} ||
| ||
outputs & labels ||
| || tensor data
{Metric.add_metric_op} ||
| ||
metric states(tensor) ||
| ||
{fetch as numpy} || ------------------
| ||
metric states(numpy) || numpy data
| ||
{Metric.update} \/ ------------------
Examples:
For :code:`Accuracy` metric, which takes :code:`pred` and :code:`label`
as inputs, we can calculate the correct prediction matrix between
:code:`pred` and :code:`label` in :code:`add_metric_op`.
For examples, prediction results contains 10 classes, while :code:`pred`
shape is [N, 10], :code:`label` shape is [N, 1], N is mini-batch size,
and we only need to calculate accurary of top-1 and top-5, we could
calculated the correct prediction matrix of the top-5 scores of the
prediction of each sample like follows, while the correct prediction
matrix shape is [N, 5].
.. code-block:: python
def add_metric_op(pred, label):
# sort prediction and slice the top-5 scores
pred = fluid.layers.argsort(pred, descending=True)[1][:, :5]
# calculate whether the predictions are correct
correct = pred == label
return fluid.layers.cast(correct, dtype='float32')
With the :code:`add_metric_op`, we split some calculations to OPs(which
may run on GPU devices, will be faster), and only fetch 1 tensor with
shape as [N, 5] instead of 2 tensors with shapes as [N, 10] and [N, 1].
:code:`update` can be define as follows:
.. code-block:: python
def update(self, correct):
accs = []
for i, k in enumerate(self.topk):
num_corrects = correct[:, :k].sum()
num_samples = len(correct)
accs.append(float(num_corrects) / num_samples)
self.total[i] += num_corrects
self.count[i] += num_samples
return accs
"""
def __init__(self):
pass
@abc.abstractmethod
def reset(self):
"""
Reset states and result
"""
raise NotImplementedError("function 'reset' not implemented in {}.".
format(self.__class__.__name__))
@abc.abstractmethod
def update(self, *args):
"""
Update states for metric
Inputs of :code:`update` is the outputs of :code:`Metric.add_metric_op`,
if :code:`add_metric_op` is not defined, the inputs of :code:`update`
will be flatten arguments of **output** of mode and **label** from data:
:code:`update(output1, output2, ..., label1, label2,...)`
see :code:`Metric.add_metric_op`
"""
raise NotImplementedError("function 'update' not implemented in {}.".
format(self.__class__.__name__))
@abc.abstractmethod
def accumulate(self):
"""
Accumulates statistics, computes and returns the metric value
"""
raise NotImplementedError(
"function 'accumulate' not implemented in {}.".format(
self.__class__.__name__))
@abc.abstractmethod
def name(self):
"""
Returns metric name
"""
raise NotImplementedError("function 'name' not implemented in {}.".
format(self.__class__.__name__))
def add_metric_op(self, *args):
"""
This API is advanced usage to accelerate metric calculating, calulations
from outputs of model to the states which should be updated by Metric can
be defined here, where Paddle OPs is also supported. Outputs of this API
will be the inputs of "Metric.update".
If :code:`add_metric_op` is defined, it will be called with **outputs**
of model and **labels** from data as arguments, all outputs and labels
will be concatenated and flatten and each filed as a separate argument
as follows:
:code:`add_metric_op(output1, output2, ..., label1, label2,...)`
If :code:`add_metric_op` is not defined, default behaviour is to pass
input to output, so output format will be:
:code:`return output1, output2, ..., label1, label2,...`
see :code:`Metric.update`
"""
return args
class Accuracy(Metric):
"""
Encapsulates accuracy metric logic
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import paddle.incubate.hapi as hapi
fluid.enable_dygraph()
train_dataset = hapi.datasets.MNIST(mode='train')
model = hapi.Model(hapi.vision.LeNet(classifier_activation=None))
optim = fluid.optimizer.Adam(
learning_rate=0.001, parameter_list=model.parameters())
model.prepare(
optim,
loss_function=paddle.nn.CrossEntropyLoss(),
metrics=hapi.metrics.Accuracy())
model.fit(train_dataset, batch_size=64)
"""
def __init__(self, topk=(1, ), name=None, *args, **kwargs):
super(Accuracy, self).__init__(*args, **kwargs)
self.topk = topk
self.maxk = max(topk)
self._init_name(name)
self.reset()
def add_metric_op(self, pred, label, *args):
pred = fluid.layers.argsort(pred, descending=True)[1][:, :self.maxk]
correct = pred == label
return fluid.layers.cast(correct, dtype='float32')
def update(self, correct, *args):
accs = []
for i, k in enumerate(self.topk):
num_corrects = correct[:, :k].sum()
num_samples = len(correct)
accs.append(float(num_corrects) / num_samples)
self.total[i] += num_corrects
self.count[i] += num_samples
return accs
def reset(self):
self.total = [0.] * len(self.topk)
self.count = [0] * len(self.topk)
def accumulate(self):
res = []
for t, c in zip(self.total, self.count):
res.append(float(t) / c)
return res
def _init_name(self, name):
name = name or 'acc'
if self.maxk != 1:
self._name = ['{}_top{}'.format(name, k) for k in self.topk]
else:
self._name = [name]
def name(self):
return self._name
......@@ -24,6 +24,7 @@ import six
import warnings
from collections import Iterable
import paddle
from paddle import fluid
# Note: Use alias `Input` temporarily before releasing hapi feature.
from paddle.static import InputSpec as Input
......@@ -36,9 +37,9 @@ from paddle.fluid.layers.utils import flatten
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
from paddle.fluid.incubate.fleet.base import role_maker
from paddle.io import DataLoader, Dataset
from paddle.metric import Metric
from .distributed import DistributedBatchSampler, _all_gather, prepare_distributed_context, _parallel_context_initialized
from .metrics import Metric
from .callbacks import config_callbacks
from .utils import to_list, to_numpy, flatten_list, restore_flatten_list, extract_args
from .device import _get_device
......@@ -361,8 +362,8 @@ class StaticGraphAdapter(object):
self._label_vars[mode] = labels
outputs = to_list(self.model.network.forward(*inputs))
if mode != 'test' and self.model._loss_function:
losses = self.model._loss_function(*(outputs + labels))
if mode != 'test' and self.model._loss:
losses = self.model._loss(*(outputs + labels))
if self._nranks > 1 and mode != 'train':
outputs = [_all_gather(o, self._nranks) for o in outputs]
......@@ -371,8 +372,7 @@ class StaticGraphAdapter(object):
if mode != 'test':
for metric in self.model._metrics:
metrics.append(
to_list(metric.add_metric_op(*(outputs + labels))))
metrics.append(to_list(metric.compute(*(outputs + labels))))
if mode == 'train' and self.model._optimizer:
self._loss_endpoint = fluid.layers.sum(losses)
......@@ -477,7 +477,7 @@ class DynamicGraphAdapter(object):
if self._nranks > 1:
outputs = self.ddp_model.forward(* [to_variable(x) for x in inputs])
losses = self.model._loss_function(*(to_list(outputs) + labels))
losses = self.model._loss(*(to_list(outputs) + labels))
losses = to_list(losses)
final_loss = fluid.layers.sum(losses)
final_loss = self.ddp_model.scale_loss(final_loss)
......@@ -486,7 +486,7 @@ class DynamicGraphAdapter(object):
else:
outputs = self.model.network.forward(
* [to_variable(x) for x in inputs])
losses = self.model._loss_function(*(to_list(outputs) + labels))
losses = self.model._loss(*(to_list(outputs) + labels))
losses = to_list(losses)
final_loss = fluid.layers.sum(losses)
final_loss.backward()
......@@ -495,7 +495,7 @@ class DynamicGraphAdapter(object):
self.model.network.clear_gradients()
metrics = []
for metric in self.model._metrics:
metric_outs = metric.add_metric_op(*(to_list(outputs) + labels))
metric_outs = metric.compute(*(to_list(outputs) + labels))
m = metric.update(* [to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
......@@ -510,8 +510,8 @@ class DynamicGraphAdapter(object):
labels = [to_variable(l) for l in to_list(labels)]
outputs = self.model.network.forward(* [to_variable(x) for x in inputs])
if self.model._loss_function:
losses = self.model._loss_function(*(to_list(outputs) + labels))
if self.model._loss:
losses = self.model._loss(*(to_list(outputs) + labels))
losses = to_list(losses)
if self._nranks > 1:
......@@ -539,13 +539,13 @@ class DynamicGraphAdapter(object):
self._merge_count[self.mode + '_total'] += samples
self._merge_count[self.mode + '_batch'] = samples
metric_outs = metric.add_metric_op(*(to_list(outputs) + labels))
metric_outs = metric.compute(*(to_list(outputs) + labels))
m = metric.update(* [to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
if self.model._loss_function and len(metrics):
if self.model._loss and len(metrics):
return [to_numpy(l) for l in losses], metrics
elif self.model._loss_function:
elif self.model._loss:
return [to_numpy(l) for l in losses]
else:
return metrics
......@@ -633,21 +633,21 @@ class Model(object):
"""
An Model object is network with training and inference features.
Dynamic graph and static graph are supported at the same time,
switched by `fluid.enable_dygraph()`. The usage is as follows.
switched by `paddle.disable_static()`. The usage is as follows.
But note, the switching between dynamic and static should be before
instantiating a Model. The input description, i.e, hapi.Input,
must be required for static graph.
Args:
network (fluid.dygraph.Layer): The network is an instance of
fluid.dygraph.Layer.
network (paddle.nn.Layer): The network is an instance of
paddle.nn.Layer.
inputs (Input|list|dict|None): `inputs`, entry points of network,
could be a Input layer, or lits of Input layers,
or dict (name: Input), or None. For static graph,
inputs must be set. For dynamic graph, it could be None.
labels (Input|list|None): `labels`, entry points of network,
could be a Input layer or lits of Input layers, or None.
For static graph, if labels is required in loss_function,
For static graph, if labels is required in loss,
labels must be set. Otherwise, it could be None.
......@@ -655,13 +655,12 @@ class Model(object):
.. code-block:: python
import paddle
import paddle.fluid as fluid
import paddle.incubate.hapi as hapi
class MyNet(fluid.dygraph.Layer):
class MyNet(paddle.nn.Layer):
def __init__(self, classifier_act=None):
super(MyNet, self).__init__()
self._fc1 = fluid.dygraph.Linear(784, 200, act=classifier_act)
self._fc1 = paddle.nn.Linear(784, 200, act=classifier_act)
def forward(self, x):
y = self._fc1(x)
......@@ -669,18 +668,18 @@ class Model(object):
device = hapi.set_device('gpu')
# if use static graph, do not set
fluid.enable_dygraph(device)
paddle.disable_static(device)
# inputs and labels are not required for dynamic graph.
input = hapi.Input([None, 784], 'float32', 'x')
label = hapi.Input([None, 1], 'int64', 'label')
model = hapi.Model(MyNet(), input, label)
optim = fluid.optimizer.SGD(learning_rate=1e-3,
optim = paddle.optimizer.SGD(learning_rate=1e-3,
parameter_list=model.parameters())
model.prepare(optim,
paddle.nn.CrossEntropyLoss(),
hapi.metrics.Accuracy())
paddle.metric.Accuracy())
mnist_data = hapi.datasets.MNIST(mode='train', chw_format=False)
model.fit(mnist_data, epochs=2, batch_size=32, verbose=1)
......@@ -692,7 +691,7 @@ class Model(object):
self.network = network
self._inputs = None
self._labels = None
self._loss_function = None
self._loss = None
self._loss_weights = None
self._optimizer = None
self._optimizer = None
......@@ -732,25 +731,24 @@ class Model(object):
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.incubate.hapi as hapi
class MyNet(fluid.dygraph.Layer):
class MyNet(paddle.nn.Layer):
def __init__(self, classifier_act=None):
super(MyNet, self).__init__()
self._fc = fluid.dygraph.Linear(784, 10, act=classifier_act)
self._fc = paddle.nn.Linear(784, 10, act=classifier_act)
def forward(self, x):
y = self._fc(x)
return y
device = hapi.set_device('gpu')
fluid.enable_dygraph(device)
paddle.disable_static(device)
input = hapi.Input([None, 784], 'float32', 'x')
label = hapi.Input([None, 1], 'int64', 'label')
model = hapi.Model(MyNet(), input, label)
optim = fluid.optimizer.SGD(learning_rate=1e-3,
optim = paddle.optimizer.SGD(learning_rate=1e-3,
parameter_list=model.parameters())
model.prepare(optim, paddle.nn.CrossEntropyLoss())
data = np.random.random(size=(4,784)).astype(np.float32)
......@@ -781,25 +779,24 @@ class Model(object):
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.incubate.hapi as hapi
class MyNet(fluid.dygraph.Layer):
class MyNet(paddle.nn.Layer):
def __init__(self, classifier_act=None):
super(MyNet, self).__init__()
self._fc = fluid.dygraph.Linear(784, 10, act=classifier_act)
self._fc = paddle.nn.Linear(784, 10, act=classifier_act)
def forward(self, x):
y = self._fc(x)
return y
device = hapi.set_device('gpu')
fluid.enable_dygraph(device)
paddle.disable_static(device)
input = hapi.Input([None, 784], 'float32', 'x')
label = hapi.Input([None, 1], 'int64', 'label')
model = hapi.Model(MyNet(), input, label)
optim = fluid.optimizer.SGD(learning_rate=1e-3,
optim = paddle.optimizer.SGD(learning_rate=1e-3,
parameter_list=model.parameters())
model.prepare(optim,
paddle.nn.CrossEntropyLoss())
......@@ -827,24 +824,24 @@ class Model(object):
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
import paddle
import paddle.incubate.hapi as hapi
class MyNet(fluid.dygraph.Layer):
class MyNet(paddle.nn.Layer):
def __init__(self):
super(MyNet, self).__init__()
self._fc = fluid.dygraph.Linear(784, 1, act='softmax')
self._fc = paddle.nn.Linear(784, 1, act='softmax')
def forward(self, x):
y = self._fc(x)
return y
device = hapi.set_device('gpu')
fluid.enable_dygraph(device)
paddle.disable_static(device)
model = hapi.Model(MyNet())
model.prepare()
data = np.random.random(size=(4,784)).astype(np.float32)
out = model.eval_batch([data])
out = model.test_batch([data])
print(out)
"""
return self._adapter.test_batch(inputs)
......@@ -875,19 +872,19 @@ class Model(object):
.. code-block:: python
import paddle.fluid as fluid
import paddle
import paddle.incubate.hapi as hapi
class MyNet(fluid.dygraph.Layer):
class MyNet(paddle.nn.Layer):
def __init__(self):
super(MyNet, self).__init__()
self._fc = fluid.dygraph.Linear(784, 1, act='softmax')
self._fc = paddle.nn.Linear(784, 1, act='softmax')
def forward(self, x):
y = self._fc(x)
return y
device = hapi.set_device('cpu')
fluid.enable_dygraph(device)
paddle.disable_static(device)
model = hapi.Model(MyNet())
model.save('checkpoint/test')
"""
......@@ -927,19 +924,19 @@ class Model(object):
.. code-block:: python
import paddle.fluid as fluid
import paddle
import paddle.incubate.hapi as hapi
class MyNet(fluid.dygraph.Layer):
class MyNet(paddle.nn.Layer):
def __init__(self):
super(MyNet, self).__init__()
self._fc = fluid.dygraph.Linear(784, 1, act='softmax')
self._fc = paddle.nn.Linear(784, 1, act='softmax')
def forward(self, x):
y = self._fc(x)
return y
device = hapi.set_device('cpu')
fluid.enable_dygraph(device)
paddle.disable_static(device)
model = hapi.Model(MyNet())
model.load('checkpoint/test')
"""
......@@ -1002,24 +999,24 @@ class Model(object):
.. code-block:: python
import paddle.fluid as fluid
import paddle
from paddle.incubate.hapi import Model
class MyNet(fluid.dygraph.Layer):
class MyNet(paddle.nn.Layer):
def __init__(self):
super(MyNet, self).__init__()
self._fc = fluid.dygraph.Linear(20, 10, act='softmax')
self._fc = paddle.nn.Linear(20, 10, act='softmax')
def forward(self, x):
y = self._fc(x)
return y
fluid.enable_dygraph()
paddle.disable_static()
model = Model(MyNet())
params = model.parameters()
"""
return self._adapter.parameters()
def prepare(self, optimizer=None, loss_function=None, metrics=None):
def prepare(self, optimizer=None, loss=None, metrics=None):
"""
Configures the model before runing.
......@@ -1027,8 +1024,8 @@ class Model(object):
optimizer (Optimizer|None): Optimizer must be set in training
and should be a Optimizer instance. It can be None in eval
and test mode.
loss_function (Loss|callable function|None): Loss function can
be a `fluid.dygraph.Layer` instance or any callable function
loss (Loss|callable function|None): Loss function can
be a `paddle.nn.Layer` instance or any callable function
taken the predicted values and ground truth values as input.
It can be None when there is no loss.
metrics (Metric|list of Metric|None): If metrics is set, all
......@@ -1047,7 +1044,7 @@ class Model(object):
startup_prog_seed = fluid.default_startup_program(
).random_seed
fluid.disable_dygraph()
fluid.enable_dygraph(self._place)
paddle.disable_static(self._place)
# enable_dygraph would create and switch to a new program,
# thus also copy seed to the new program
fluid.default_main_program().random_seed = main_prog_seed
......@@ -1059,12 +1056,11 @@ class Model(object):
_parallel_context_initialized = True
self._optimizer = optimizer
if loss_function:
if not isinstance(loss_function, fluid.dygraph.Layer) or \
not callable(loss_function):
raise TypeError("'loss_function' must be sub classes of \
`fluid.dygraph.Layer` or any callable function.")
self._loss_function = loss_function
if loss is not None:
if not isinstance(loss, paddle.nn.Layer) and not callable(loss):
raise TypeError("'loss' must be sub classes of " \
"`paddle.nn.Layer` or any callable function.")
self._loss = loss
metrics = metrics or []
for metric in to_list(metrics):
......@@ -1144,12 +1140,11 @@ class Model(object):
.. code-block:: python
import paddle
import paddle.fluid as fluid
import paddle.incubate.hapi as hapi
dynamic = True
device = hapi.set_device('gpu')
fluid.enable_dygraph(device) if dynamic else None
paddle.disable_static(device) if dynamic else None
train_dataset = hapi.datasets.MNIST(mode='train')
val_dataset = hapi.datasets.MNIST(mode='test')
......@@ -1159,12 +1154,12 @@ class Model(object):
model = hapi.Model(hapi.vision.LeNet(classifier_activation=None),
input, label)
optim = fluid.optimizer.Adam(
learning_rate=0.001, parameter_list=model.parameters())
optim = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
model.prepare(
optim,
paddle.nn.CrossEntropyLoss(),
hapi.metrics.Accuracy(topk=(1, 2)))
paddle.metric.Accuracy(topk=(1, 2)))
model.fit(train_dataset,
val_dataset,
epochs=2,
......@@ -1177,18 +1172,17 @@ class Model(object):
.. code-block:: python
import paddle
import paddle.fluid as fluid
import paddle.incubate.hapi as hapi
dynamic = True
device = hapi.set_device('gpu')
fluid.enable_dygraph(device) if dynamic else None
paddle.disable_static(device) if dynamic else None
train_dataset = hapi.datasets.MNIST(mode='train')
train_loader = fluid.io.DataLoader(train_dataset,
train_loader = paddle.io.DataLoader(train_dataset,
places=device, batch_size=64)
val_dataset = hapi.datasets.MNIST(mode='test')
val_loader = fluid.io.DataLoader(val_dataset,
val_loader = paddle.io.DataLoader(val_dataset,
places=device, batch_size=64)
input = hapi.Input([None, 1, 28, 28], 'float32', 'image')
......@@ -1196,12 +1190,12 @@ class Model(object):
model = hapi.Model(hapi.vision.LeNet(classifier_activation=None),
input, label)
optim = fluid.optimizer.Adam(
learning_rate=0.001, parameter_list=model.parameters())
optim = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
model.prepare(
optim,
paddle.nn.CrossEntropyLoss(),
hapi.metrics.Accuracy(topk=(1, 2)))
paddle.metric.Accuracy(topk=(1, 2)))
model.fit(train_loader,
val_loader,
epochs=2,
......@@ -1313,7 +1307,7 @@ class Model(object):
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle
import paddle.incubate.hapi as hapi
# declarative mode
......@@ -1322,15 +1316,15 @@ class Model(object):
input = hapi.Input([-1, 1, 28, 28], 'float32', 'image')
label = hapi.Input([None, 1], 'int64', 'label')
model = hapi.Model(hapi.vision.LeNet(), input, label)
model.prepare(metrics=hapi.metrics.Accuracy())
model.prepare(metrics=paddle.metric.Accuracy())
result = model.evaluate(val_dataset, batch_size=64)
print(result)
# imperative mode
fluid.enable_dygraph()
paddle.disable_static()
model = hapi.Model(hapi.vision.LeNet())
model.prepare(metrics=hapi.metrics.Accuracy())
model.prepare(metrics=paddle.metric.Accuracy())
result = model.evaluate(val_dataset, batch_size=64)
print(result)
......@@ -1407,7 +1401,7 @@ class Model(object):
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
import paddle
import paddle.incubate.hapi as hapi
class MnistDataset(hapi.datasets.MNIST):
......@@ -1436,7 +1430,7 @@ class Model(object):
# imperative mode
device = hapi.set_device('cpu')
fluid.enable_dygraph(device)
paddle.disable_static(device)
model = hapi.Model(hapi.vision.LeNet())
model.prepare()
result = model.predict(test_dataset, batch_size=64)
......@@ -1506,7 +1500,6 @@ class Model(object):
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.incubate.hapi as hapi
input = hapi.Input([-1, 1, 28, 28], 'float32', 'image')
......@@ -1562,9 +1555,9 @@ class Model(object):
if mode != 'test':
outs = getattr(self, mode + '_batch')(data[:len(self._inputs)],
data[len(self._inputs):])
if self._metrics and self._loss_function:
if self._metrics and self._loss:
metrics = [[l[0] for l in outs[0]]]
elif self._loss_function:
elif self._loss:
metrics = [[l[0] for l in outs]]
else:
metrics = []
......@@ -1635,7 +1628,7 @@ class Model(object):
metric.reset()
def _metrics_name(self):
metrics_name = ['loss'] if self._loss_function else []
metrics_name = ['loss'] if self._loss else []
for m in self._metrics:
metrics_name.extend(to_list(m.name()))
return metrics_name
......
......@@ -25,7 +25,7 @@ from paddle import fluid
from paddle.incubate.hapi import Model, Input, set_device
from paddle.nn.layer.loss import CrossEntropyLoss
from paddle.incubate.hapi.vision.models import LeNet
from paddle.incubate.hapi.metrics import Accuracy
from paddle.metric import Accuracy
from paddle.incubate.hapi.callbacks import ProgBarLogger
from paddle.incubate.hapi.datasets import MNIST
......
......@@ -25,7 +25,7 @@ from paddle import fluid
from paddle.incubate.hapi import Model, Input, set_device
from paddle.nn.layer.loss import CrossEntropyLoss
from paddle.incubate.hapi.vision.models import LeNet
from paddle.incubate.hapi.metrics import Accuracy
from paddle.metric import Accuracy
from paddle.incubate.hapi.callbacks import ProgBarLogger
from paddle.incubate.hapi.datasets import MNIST
......
......@@ -29,7 +29,7 @@ from paddle.fluid.dygraph.base import to_variable
import paddle.incubate.hapi as hapi
from paddle.incubate.hapi import Model, Input
from paddle.nn.layer.loss import CrossEntropyLoss
from paddle.incubate.hapi.metrics import Accuracy
from paddle.metric import Accuracy
from paddle.incubate.hapi.datasets import MNIST
from paddle.incubate.hapi.vision.models import LeNet
from paddle.incubate.hapi.distributed import DistributedBatchSampler, prepare_distributed_context
......@@ -202,7 +202,7 @@ class TestModel(unittest.TestCase):
model = Model(net, inputs=self.inputs, labels=self.labels)
model.prepare(
optim_new,
loss_function=CrossEntropyLoss(reduction="sum"),
loss=CrossEntropyLoss(reduction="sum"),
metrics=Accuracy())
model.fit(self.train_dataset, batch_size=64, shuffle=False)
......@@ -333,8 +333,7 @@ class TestModelFunction(unittest.TestCase):
inputs = [Input([None, dim], 'float32', 'x')]
labels = [Input([None, 1], 'int64', 'label')]
model = Model(net, inputs, labels)
model.prepare(
optim2, loss_function=CrossEntropyLoss(reduction="sum"))
model.prepare(optim2, loss=CrossEntropyLoss(reduction="sum"))
loss, = model.train_batch([data], [label])
np.testing.assert_allclose(loss.flatten(), ref.flatten())
......@@ -379,8 +378,7 @@ class TestModelFunction(unittest.TestCase):
parameter_list=net.parameters())
model = Model(net, inputs, labels)
model.prepare(
optimizer=optim,
loss_function=CrossEntropyLoss(reduction="sum"))
optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
model.save(path + '/test')
model.load(path + '/test')
shutil.rmtree(path)
......@@ -394,8 +392,7 @@ class TestModelFunction(unittest.TestCase):
model = Model(MyModel(classifier_activation=None))
optim = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=model.parameters())
model.prepare(
optimizer=optim, loss_function=CrossEntropyLoss(reduction="sum"))
model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
model.save(path + '/test')
fluid.disable_dygraph()
......@@ -404,8 +401,7 @@ class TestModelFunction(unittest.TestCase):
model = Model(MyModel(classifier_activation=None), inputs, labels)
optim = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=model.parameters())
model.prepare(
optimizer=optim, loss_function=CrossEntropyLoss(reduction="sum"))
model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
model.load(path + '/test')
shutil.rmtree(path)
......@@ -418,8 +414,7 @@ class TestModelFunction(unittest.TestCase):
optim = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=net.parameters())
model = Model(net, inputs, labels)
model.prepare(
optimizer=optim, loss_function=CrossEntropyLoss(reduction="sum"))
model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
model.save(path + '/test')
device = hapi.set_device('cpu')
......@@ -431,8 +426,7 @@ class TestModelFunction(unittest.TestCase):
optim = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=net.parameters())
model = Model(net, inputs, labels)
model.prepare(
optimizer=optim, loss_function=CrossEntropyLoss(reduction="sum"))
model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
model.load(path + '/test')
shutil.rmtree(path)
fluid.disable_dygraph()
......
......@@ -12,17 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: define the functions to calculate metric in this directory
__all__ = [
'Accuracy', 'Auc', 'ChunkEvaluator', 'CompositeMetric', 'DetectionMAP',
'EditDistance', 'Precision', 'Recall', 'accuracy', 'auc', 'chunk_eval',
'cos_sim', 'mean_iou'
]
from ..fluid.metrics import Accuracy, Auc, ChunkEvaluator, CompositeMetric, DetectionMAP, EditDistance, \
Precision, Recall
from .metrics import *
from . import metrics
from ..fluid.layers.metric_op import accuracy, auc
from ..fluid.layers.nn import chunk_eval, cos_sim, mean_iou
__all__ = metrics.__all__ + [
'accuracy',
'auc',
'chunk_eval',
'cos_sim',
'mean_iou',
]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import abc
import numpy as np
import paddle
__all__ = ['Metric', 'Accuracy', 'Precision', 'Recall', 'Auc']
def _is_numpy_(var):
return isinstance(var, (np.ndarray, np.generic))
@six.add_metaclass(abc.ABCMeta)
class Metric(object):
"""
Base class for metric, encapsulates metric logic and APIs
Usage:
m = SomeMetric()
for prediction, label in ...:
m.update(prediction, label)
m.accumulate()
Advanced usage for :code:`compute`:
Metric calculation can be accelerated by calculating metric states
from model outputs and labels by build-in operators not by Python/NumPy
in :code:`compute`, metric states will be fetched as NumPy array and
call :code:`update` with states in NumPy format.
Metric calculated as follows (operations in Model and Metric are
indicated with curly brackets, while data nodes not):
inputs & labels || ------------------
| ||
{model} ||
| ||
outputs & labels ||
| || tensor data
{Metric.compute} ||
| ||
metric states(tensor) ||
| ||
{fetch as numpy} || ------------------
| ||
metric states(numpy) || numpy data
| ||
{Metric.update} \/ ------------------
Examples:
For :code:`Accuracy` metric, which takes :code:`pred` and :code:`label`
as inputs, we can calculate the correct prediction matrix between
:code:`pred` and :code:`label` in :code:`compute`.
For examples, prediction results contains 10 classes, while :code:`pred`
shape is [N, 10], :code:`label` shape is [N, 1], N is mini-batch size,
and we only need to calculate accurary of top-1 and top-5, we could
calculate the correct prediction matrix of the top-5 scores of the
prediction of each sample like follows, while the correct prediction
matrix shape is [N, 5].
.. code-block:: python
def compute(pred, label):
# sort prediction and slice the top-5 scores
pred = paddle.argsort(pred, descending=True)[:, :5]
# calculate whether the predictions are correct
correct = pred == label
return paddle.cast(correct, dtype='float32')
With the :code:`compute`, we split some calculations to OPs (which
may run on GPU devices, will be faster), and only fetch 1 tensor with
shape as [N, 5] instead of 2 tensors with shapes as [N, 10] and [N, 1].
:code:`update` can be define as follows:
.. code-block:: python
def update(self, correct):
accs = []
for i, k in enumerate(self.topk):
num_corrects = correct[:, :k].sum()
num_samples = len(correct)
accs.append(float(num_corrects) / num_samples)
self.total[i] += num_corrects
self.count[i] += num_samples
return accs
"""
def __init__(self):
pass
@abc.abstractmethod
def reset(self):
"""
Reset states and result
"""
raise NotImplementedError("function 'reset' not implemented in {}.".
format(self.__class__.__name__))
@abc.abstractmethod
def update(self, *args):
"""
Update states for metric
Inputs of :code:`update` is the outputs of :code:`Metric.compute`,
if :code:`compute` is not defined, the inputs of :code:`update`
will be flatten arguments of **output** of mode and **label** from data:
:code:`update(output1, output2, ..., label1, label2,...)`
see :code:`Metric.compute`
"""
raise NotImplementedError("function 'update' not implemented in {}.".
format(self.__class__.__name__))
@abc.abstractmethod
def accumulate(self):
"""
Accumulates statistics, computes and returns the metric value
"""
raise NotImplementedError(
"function 'accumulate' not implemented in {}.".format(
self.__class__.__name__))
@abc.abstractmethod
def name(self):
"""
Returns metric name
"""
raise NotImplementedError("function 'name' not implemented in {}.".
format(self.__class__.__name__))
def compute(self, *args):
"""
This API is advanced usage to accelerate metric calculating, calulations
from outputs of model to the states which should be updated by Metric can
be defined here, where Paddle OPs is also supported. Outputs of this API
will be the inputs of "Metric.update".
If :code:`compute` is defined, it will be called with **outputs**
of model and **labels** from data as arguments, all outputs and labels
will be concatenated and flatten and each filed as a separate argument
as follows:
:code:`compute(output1, output2, ..., label1, label2,...)`
If :code:`compute` is not defined, default behaviour is to pass
input to output, so output format will be:
:code:`return output1, output2, ..., label1, label2,...`
see :code:`Metric.update`
"""
return args
class Accuracy(Metric):
"""
Encapsulates accuracy metric logic.
Args:
topk (int|tuple(int)): Number of top elements to look at
for computing accuracy. Default is (1,).
name (str, optional): String name of the metric instance. Default
is `acc`.
Example by standalone:
.. code-block:: python
import numpy as np
import paddle
paddle.disable_static()
x = paddle.to_tensor(np.array([
[0.1, 0.2, 0.3, 0.4],
[0.1, 0.4, 0.3, 0.2],
[0.1, 0.2, 0.4, 0.3],
[0.1, 0.2, 0.3, 0.4]]))
y = paddle.to_tensor(np.array([[0], [1], [2], [3]]))
m = paddle.metric.Accuracy()
correct = m.compute(x, y)
m.update(correct)
res = m.accumulate()
print(res) # 0.75
Example with Model API:
.. code-block:: python
import paddle
import paddle.incubate.hapi as hapi
paddle.disable_static()
train_dataset = hapi.datasets.MNIST(mode='train')
model = hapi.Model(hapi.vision.LeNet(classifier_activation=None))
optim = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
model.prepare(
optim,
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy())
model.fit(train_dataset, batch_size=64)
"""
def __init__(self, topk=(1, ), name=None, *args, **kwargs):
super(Accuracy, self).__init__(*args, **kwargs)
self.topk = topk
self.maxk = max(topk)
self._init_name(name)
self.reset()
def compute(self, pred, label, *args):
"""
Compute the top-k (maxinum value in `topk`) indices.
Args:
pred (Tensor): The predicted value is a Tensor wit type
float32 or float64.
label (Tensor): The ground truth value is a 2D Tensor, its
shape is [batch_size, 1] and type is int64.
Return:
Tensor: Correct mask, a tensor with shape [batch_size, topk].
"""
pred = paddle.argsort(pred, descending=True)[:, :self.maxk]
correct = pred == label
return paddle.cast(correct, dtype='float32')
def update(self, correct, *args):
"""
Update the metrics states (correct count and total count), in order to
calculate cumulative accuracy of all instances. This function also
returns the accuracy of current step.
Args:
correct: Correct mask, a tensor with shape [batch_size, topk].
Return:
Tensor: the accuracy of current step.
"""
if isinstance(correct, paddle.Tensor):
correct = correct.numpy()
accs = []
for i, k in enumerate(self.topk):
num_corrects = correct[:, :k].sum()
num_samples = len(correct)
accs.append(float(num_corrects) / num_samples)
self.total[i] += num_corrects
self.count[i] += num_samples
accs = accs[0] if len(self.topk) == 1 else accs
return accs
def reset(self):
"""
Resets all of the metric state.
"""
self.total = [0.] * len(self.topk)
self.count = [0] * len(self.topk)
def accumulate(self):
"""
Computes and returns the accumulated metric.
"""
res = []
for t, c in zip(self.total, self.count):
r = float(t) / c if c > 0 else 0.
res.append(r)
res = res[0] if len(self.topk) == 1 else res
return res
def _init_name(self, name):
name = name or 'acc'
if self.maxk != 1:
self._name = ['{}_top{}'.format(name, k) for k in self.topk]
else:
self._name = [name]
def name(self):
"""
Return name of metric instance.
"""
return self._name
class Precision(Metric):
"""
Precision (also called positive predictive value) is the fraction of
relevant instances among the retrieved instances. Refer to
https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers
Noted that this class manages the precision score only for binary
classification task.
Args:
name (str, optional): String name of the metric instance.
Default is `precision`.
Example by standalone:
.. code-block:: python
import numpy as np
import paddle
x = np.array([0.1, 0.5, 0.6, 0.7])
y = np.array([0, 1, 1, 1])
m = paddle.metric.Precision()
m.update(x, y)
res = m.accumulate()
print(res) # 1.0
Example with Model API:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn as nn
import paddle.incubate.hapi as hapi
class Data(paddle.io.Dataset):
def __init__(self):
super(Data, self).__init__()
self.n = 1024
self.x = np.random.randn(self.n, 10).astype('float32')
self.y = np.random.randint(2, size=(self.n, 1)).astype('float32')
def __getitem__(self, idx):
return self.x[idx], self.y[idx]
def __len__(self):
return self.n
paddle.disable_static()
model = hapi.Model(nn.Sequential(
nn.Linear(10, 1),
nn.Sigmoid()
))
optim = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
model.prepare(
optim,
loss=nn.BCELoss(),
metrics=paddle.metric.Precision())
data = Data()
model.fit(data, batch_size=16)
"""
def __init__(self, name='precision', *args, **kwargs):
super(Precision, self).__init__(*args, **kwargs)
self.tp = 0 # true positive
self.fp = 0 # false positive
self._name = name
def update(self, preds, labels):
"""
Update the states based on the current mini-batch prediction results.
Args:
preds (numpy.ndarray): The prediction result, usually the output
of two-class sigmoid function. It should be a vector (column
vector or row vector) with data type: 'float64' or 'float32'.
labels (numpy.ndarray): The ground truth (labels),
the shape should keep the same as preds.
The data type is 'int32' or 'int64'.
"""
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
elif not _is_numpy_(preds):
raise ValueError("The 'preds' must be a numpy ndarray or Tensor.")
if isinstance(labels, paddle.Tensor):
labels = labels.numpy()
elif not _is_numpy_(labels):
raise ValueError("The 'labels' must be a numpy ndarray or Tensor.")
sample_num = labels.shape[0]
preds = np.floor(preds + 0.5).astype("int32")
for i in range(sample_num):
pred = preds[i]
label = labels[i]
if pred == 1:
if pred == label:
self.tp += 1
else:
self.fp += 1
def reset(self):
"""
Resets all of the metric state.
"""
self.tp = 0
self.fp = 0
def accumulate(self):
"""
Calculate the final precision.
Returns:
A scaler float: results of the calculated precision.
"""
ap = self.tp + self.fp
return float(self.tp) / ap if ap != 0 else .0
def name(self):
"""
Returns metric name
"""
return self._name
class Recall(Metric):
"""
Recall (also known as sensitivity) is the fraction of
relevant instances that have been retrieved over the
total amount of relevant instances
Refer to:
https://en.wikipedia.org/wiki/Precision_and_recall
Noted that this class manages the recall score only for
binary classification task.
Args:
name (str, optional): String name of the metric instance.
Default is `recall`.
Example by standalone:
.. code-block:: python
import numpy as np
import paddle
x = np.array([0.1, 0.5, 0.6, 0.7])
y = np.array([1, 0, 1, 1])
m = paddle.metric.Recall()
m.update(x, y)
res = m.accumulate()
print(res) # 2.0 / 3.0
Example with Model API:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn as nn
import paddle.incubate.hapi as hapi
class Data(paddle.io.Dataset):
def __init__(self):
super(Data, self).__init__()
self.n = 1024
self.x = np.random.randn(self.n, 10).astype('float32')
self.y = np.random.randint(2, size=(self.n, 1)).astype('float32')
def __getitem__(self, idx):
return self.x[idx], self.y[idx]
def __len__(self):
return self.n
paddle.disable_static()
model = hapi.Model(nn.Sequential(
nn.Linear(10, 1),
nn.Sigmoid()
))
optim = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
model.prepare(
optim,
loss=nn.BCELoss(),
metrics=[paddle.metric.Precision(), paddle.metric.Recall()])
data = Data()
model.fit(data, batch_size=16)
"""
def __init__(self, name='recall', *args, **kwargs):
super(Recall, self).__init__(*args, **kwargs)
self.tp = 0 # true positive
self.fn = 0 # false negative
self._name = name
def update(self, preds, labels):
"""
Update the states based on the current mini-batch prediction results.
Args:
preds(numpy.array): prediction results of current mini-batch,
the output of two-class sigmoid function.
Shape: [batch_size, 1]. Dtype: 'float64' or 'float32'.
labels(numpy.array): ground truth (labels) of current mini-batch,
the shape should keep the same as preds.
Shape: [batch_size, 1], Dtype: 'int32' or 'int64'.
"""
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
elif not _is_numpy_(preds):
raise ValueError("The 'preds' must be a numpy ndarray or Tensor.")
if isinstance(labels, paddle.Tensor):
labels = labels.numpy()
elif not _is_numpy_(labels):
raise ValueError("The 'labels' must be a numpy ndarray or Tensor.")
sample_num = labels.shape[0]
preds = np.rint(preds).astype("int32")
for i in range(sample_num):
pred = preds[i]
label = labels[i]
if label == 1:
if pred == label:
self.tp += 1
else:
self.fn += 1
def accumulate(self):
"""
Calculate the final recall.
Returns:
A scaler float: results of the calculated Recall.
"""
recall = self.tp + self.fn
return float(self.tp) / recall if recall != 0 else .0
def reset(self):
"""
Resets all of the metric state.
"""
self.tp = 0
self.fn = 0
def name(self):
"""
Returns metric name
"""
return self._name
class Auc(Metric):
"""
The auc metric is for binary classification.
Refer to https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.
Please notice that the auc metric is implemented with python, which may be a little bit slow.
The `auc` function creates four local variables, `true_positives`,
`true_negatives`, `false_positives` and `false_negatives` that are used to
compute the AUC. To discretize the AUC curve, a linearly spaced set of
thresholds is used to compute pairs of recall and precision values. The area
under the ROC-curve is therefore computed using the height of the recall
values by the false positive rate, while the area under the PR-curve is the
computed using the height of the precision values by the recall.
Args:
curve (str): Specifies the mode of the curve to be computed,
'ROC' or 'PR' for the Precision-Recall-curve. Default is 'ROC'.
num_thresholds (int): The number of thresholds to use when
discretizing the roc curve. Default is 4095.
'ROC' or 'PR' for the Precision-Recall-curve. Default is 'ROC'.
name (str, optional): String name of the metric instance. Default
is `auc`.
"NOTE: only implement the ROC curve type via Python now."
Example by standalone:
.. code-block:: python
import numpy as np
import paddle
m = paddle.metric.Auc()
n = 8
class0_preds = np.random.random(size = (n, 1))
class1_preds = 1 - class0_preds
preds = np.concatenate((class0_preds, class1_preds), axis=1)
labels = np.random.randint(2, size = (n, 1))
m.update(preds=preds, labels=labels)
res = m.accumulate()
Example with Model API:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn as nn
import paddle.incubate.hapi as hapi
class Data(paddle.io.Dataset):
def __init__(self):
super(Data, self).__init__()
self.n = 1024
self.x = np.random.randn(self.n, 10).astype('float32')
self.y = np.random.randint(2, size=(self.n, 1)).astype('int64')
def __getitem__(self, idx):
return self.x[idx], self.y[idx]
def __len__(self):
return self.n
paddle.disable_static()
model = hapi.Model(nn.Sequential(
nn.Linear(10, 2, act='softmax'),
))
optim = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
def loss(x, y):
return nn.functional.nll_loss(paddle.log(x), y)
model.prepare(
optim,
loss=loss,
metrics=paddle.metric.Auc())
data = Data()
model.fit(data, batch_size=16)
"""
def __init__(self,
curve='ROC',
num_thresholds=4095,
name='auc',
*args,
**kwargs):
super(Auc, self).__init__(*args, **kwargs)
self._curve = curve
self._num_thresholds = num_thresholds
_num_pred_buckets = num_thresholds + 1
self._stat_pos = np.zeros(_num_pred_buckets)
self._stat_neg = np.zeros(_num_pred_buckets)
self._name = name
def update(self, preds, labels):
"""
Update the auc curve with the given predictions and labels.
Args:
preds (numpy.array): An numpy array in the shape of
(batch_size, 2), preds[i][j] denotes the probability of
classifying the instance i into the class j.
labels (numpy.array): an numpy array in the shape of
(batch_size, 1), labels[i] is either o or 1,
representing the label of the instance i.
"""
if isinstance(labels, paddle.Tensor):
labels = labels.numpy()
elif not _is_numpy_(labels):
raise ValueError("The 'labels' must be a numpy ndarray or Tensor.")
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
elif not _is_numpy_(preds):
raise ValueError("The 'preds' must be a numpy ndarray or Tensor.")
for i, lbl in enumerate(labels):
value = preds[i, 1]
bin_idx = int(value * self._num_thresholds)
assert bin_idx <= self._num_thresholds
if lbl:
self._stat_pos[bin_idx] += 1.0
else:
self._stat_neg[bin_idx] += 1.0
@staticmethod
def trapezoid_area(x1, x2, y1, y2):
return abs(x1 - x2) * (y1 + y2) / 2.0
def accumulate(self):
"""
Return the area (a float score) under auc curve
Return:
float: the area under auc curve
"""
tot_pos = 0.0
tot_neg = 0.0
auc = 0.0
idx = self._num_thresholds
while idx >= 0:
tot_pos_prev = tot_pos
tot_neg_prev = tot_neg
tot_pos += self._stat_pos[idx]
tot_neg += self._stat_neg[idx]
auc += self.trapezoid_area(tot_neg, tot_neg_prev, tot_pos,
tot_pos_prev)
idx -= 1
return auc / tot_pos / tot_neg if tot_pos > 0.0 and tot_neg > 0.0 else 0.0
def reset(self):
"""
Reset states and result
"""
_num_pred_buckets = self._num_thresholds + 1
self._stat_pos = np.zeros(_num_pred_buckets)
self._stat_neg = np.zeros(_num_pred_buckets)
def name(self):
"""
Returns metric name
"""
return self._name
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
......@@ -19,10 +19,9 @@ import os
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.incubate.hapi.metrics import *
from paddle.incubate.hapi.utils import to_list
......@@ -35,7 +34,7 @@ def accuracy(pred, label, topk=(1, )):
res = []
for k in topk:
correct_k = correct[:, :k].sum()
res.append(correct_k / batch_size)
res.append(float(correct_k) / batch_size)
return res
......@@ -47,6 +46,41 @@ def convert_to_one_hot(y, C):
return oh
class TestAccuracy(unittest.TestCase):
def test_acc(self):
paddle.disable_static()
x = paddle.to_tensor(
np.array([[0.1, 0.2, 0.3, 0.4], [0.1, 0.4, 0.3, 0.2],
[0.1, 0.2, 0.4, 0.3], [0.1, 0.2, 0.3, 0.4]]))
y = paddle.to_tensor(np.array([[0], [1], [2], [3]]))
m = paddle.metric.Accuracy(name='my_acc')
# check name
self.assertEqual(m.name(), ['my_acc'])
correct = m.compute(x, y)
# check results
self.assertEqual(m.update(correct), 0.75)
self.assertEqual(m.accumulate(), 0.75)
x = paddle.to_tensor(
np.array([[0.1, 0.2, 0.3, 0.4], [0.1, 0.3, 0.4, 0.2],
[0.1, 0.2, 0.4, 0.3], [0.1, 0.2, 0.3, 0.4]]))
y = paddle.to_tensor(np.array([[0], [1], [2], [3]]))
correct = m.compute(x, y)
# check results
self.assertEqual(m.update(correct), 0.5)
self.assertEqual(m.accumulate(), 0.625)
# check reset
m.reset()
self.assertEqual(m.total[0], 0.0)
self.assertEqual(m.count[0], 0.0)
paddle.enable_static()
class TestAccuracyDynamic(unittest.TestCase):
def setUp(self):
self.topk = (1, )
......@@ -66,17 +100,18 @@ class TestAccuracyDynamic(unittest.TestCase):
def test_main(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
acc = Accuracy(topk=self.topk, name=self.name)
acc = paddle.metric.Accuracy(topk=self.topk, name=self.name)
for _ in range(10):
label, pred = self.random_pred_label()
label_var = to_variable(label)
pred_var = to_variable(pred)
state = to_list(acc.add_metric_op(pred_var, label_var))
label_var = paddle.to_tensor(label)
pred_var = paddle.to_tensor(pred)
state = to_list(acc.compute(pred_var, label_var))
acc.update(* [s.numpy() for s in state])
res_m = acc.accumulate()
res_f = accuracy(pred, label, self.topk)
assert np.all(np.isclose(np.array(res_m, dtype='float64'), np.array(res_f, dtype='float64'), rtol=1e-3)), \
"Accuracy precision error: {} != {}".format(res_m, res_f)
assert np.all(np.isclose(np.array(res_m, dtype='float64'),
np.array(res_f, dtype='float64'), rtol=1e-3)), \
"Accuracy precision error: {} != {}".format(res_m, res_f)
acc.reset()
assert np.sum(acc.total) == 0
assert np.sum(acc.count) == 0
......@@ -94,12 +129,14 @@ class TestAccuracyStatic(TestAccuracyDynamic):
def test_main(self):
main_prog = fluid.Program()
startup_prog = fluid.Program()
main_prog.random_seed = 1024
startup_prog.random_seed = 1024
with fluid.program_guard(main_prog, startup_prog):
pred = fluid.data(
name='pred', shape=[None, self.class_num], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
acc = Accuracy(topk=self.topk, name=self.name)
state = acc.add_metric_op(pred, label)
acc = paddle.metric.Accuracy(topk=self.topk, name=self.name)
state = acc.compute(pred, label)
exe = fluid.Executor(fluid.CPUPlace())
compiled_main_prog = fluid.CompiledProgram(main_prog)
......@@ -114,7 +151,7 @@ class TestAccuracyStatic(TestAccuracyDynamic):
acc.update(*state_ret)
res_m = acc.accumulate()
res_f = accuracy(pred, label, self.topk)
assert np.all(np.isclose(np.array(res_m, dtype='float64'), np.array(res_f, dtype='float64'), rtol=1e-3)), \
assert np.all(np.isclose(np.array(res_m), np.array(res_f), rtol=1e-3)), \
"Accuracy precision error: {} != {}".format(res_m, res_f)
acc.reset()
assert np.sum(acc.total) == 0
......@@ -125,9 +162,114 @@ class TestAccuracyStaticMultiTopk(TestAccuracyStatic):
def setUp(self):
self.topk = (1, 5)
self.class_num = 10
self.sample_num = 1000
self.sample_num = 100
self.name = "accuracy"
class TestPrecision(unittest.TestCase):
def test_1d(self):
paddle.disable_static()
x = np.array([0.1, 0.5, 0.6, 0.7])
y = np.array([1, 0, 1, 1])
m = paddle.metric.Precision()
m.update(x, y)
r = m.accumulate()
self.assertAlmostEqual(r, 2. / 3.)
x = paddle.to_tensor(np.array([0.1, 0.5, 0.6, 0.7, 0.2]))
y = paddle.to_tensor(np.array([1, 0, 1, 1, 1]))
m.update(x, y)
r = m.accumulate()
self.assertAlmostEqual(r, 4. / 6.)
paddle.enable_static()
def test_2d(self):
paddle.disable_static()
x = np.array([0.1, 0.5, 0.6, 0.7]).reshape(-1, 1)
y = np.array([1, 0, 1, 1]).reshape(-1, 1)
m = paddle.metric.Precision()
m.update(x, y)
r = m.accumulate()
self.assertAlmostEqual(r, 2. / 3.)
x = np.array([0.1, 0.5, 0.6, 0.7, 0.2]).reshape(-1, 1)
y = np.array([1, 0, 1, 1, 1]).reshape(-1, 1)
m.update(x, y)
r = m.accumulate()
self.assertAlmostEqual(r, 4. / 6.)
# check reset
m.reset()
self.assertEqual(m.tp, 0.0)
self.assertEqual(m.fp, 0.0)
self.assertEqual(m.accumulate(), 0.0)
paddle.enable_static()
class TestRecall(unittest.TestCase):
def test_1d(self):
paddle.disable_static()
x = np.array([0.1, 0.5, 0.6, 0.7])
y = np.array([1, 0, 1, 1])
m = paddle.metric.Recall()
m.update(x, y)
r = m.accumulate()
self.assertAlmostEqual(r, 2. / 3.)
x = paddle.to_tensor(np.array([0.1, 0.5, 0.6, 0.7]))
y = paddle.to_tensor(np.array([1, 0, 0, 1]))
m.update(x, y)
r = m.accumulate()
self.assertAlmostEqual(r, 3. / 5.)
# check reset
m.reset()
self.assertEqual(m.tp, 0.0)
self.assertEqual(m.fn, 0.0)
self.assertEqual(m.accumulate(), 0.0)
paddle.enable_static()
class TestAuc(unittest.TestCase):
def test_auc_numpy(self):
paddle.disable_static()
x = np.array([[0.78, 0.22], [0.62, 0.38], [0.55, 0.45], [0.30, 0.70],
[0.14, 0.86], [0.59, 0.41], [0.91, 0.08], [0.16, 0.84]])
y = np.array([[0], [1], [1], [0], [1], [0], [0], [1]])
m = paddle.metric.Auc()
m.update(x, y)
r = m.accumulate()
self.assertAlmostEqual(r, 0.8125)
m.reset()
self.assertEqual(m.accumulate(), 0.0)
paddle.enable_static()
def test_auc_tensor(self):
paddle.disable_static()
x = paddle.to_tensor(
np.array([[0.78, 0.22], [0.62, 0.38], [0.55, 0.45], [0.30, 0.70],
[0.14, 0.86], [0.59, 0.41], [0.91, 0.08], [0.16, 0.84]]))
y = paddle.to_tensor(np.array([[0], [1], [1], [0], [1], [0], [0], [1]]))
m = paddle.metric.Auc()
m.update(x, y)
r = m.accumulate()
self.assertAlmostEqual(r, 0.8125)
m.reset()
self.assertEqual(m.accumulate(), 0.0)
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
......@@ -106,12 +106,27 @@
"Metric.update",
"Metric.accumulate",
"Metric.name",
"Metric.add_metric_op",
"Metric.compute",
"Accuracy.reset",
"Accuracy.update",
"Accuracy.accumulate",
"Accuracy.name",
"Accuracy.add_metric_op",
"Accuracy.compute",
"Precision.reset",
"Precision.update",
"Precision.accumulate",
"Precision.name",
"Precision.compute",
"Recall.reset",
"Recall.update",
"Recall.accumulate",
"Recall.name",
"Recall.compute",
"Auc.reset",
"Auc.update",
"Auc.accumulate",
"Auc.name",
"Auc.compute",
"Callback.set_params",
"Callback.on_train_begin",
"Callback.on_train_end",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册