提交 50be6e6b 编写于 作者: D dengkaipeng

refine doc for metrics

上级 581b4944
......@@ -33,10 +33,78 @@ class Metric(object):
Base class for metric, encapsulates metric logic and APIs
Usage:
m = SomeMetric()
for prediction, label in ...:
m.update(prediction, label)
m.accumulate()
.. code-block:: python
m = SomeMetric()
for prediction, label in ...:
m.update(prediction, label)
m.accumulate()
Advanced usage for :code:`add_metric_op`
Metric calculating con be accelerate by calucateing metric states
from model outputs and labels by Paddle OPs in :code:`add_metric_op`,
metric states will be fetch as numpy array and call :code:`update`
with states in numpy format.
Metric calculated as follows (operations in Model and Metric are
indicated with curly brackets, while data nodes not):
inputs & labels ||
| ||
{model} ||
| ||
outputs labels ||
| ||
{Metric.add_metric_op} ||
| ||
metric states ||
| ||
{fetch as numpy} ||
| ||
metric states(numpy) ||
| ||
{Metric.update} \/
Examples:
For :code:`Accuracy` metric, which takes :code:`pred` and :code:`label`
as inputs, we can calculate the correct prediction matrix between
:code:`pred` and :code:`label` in :code:`add_metric_op`.
For examples, prediction results contains 10 classes, while :code:`pred`
shape is [N, 10], :code:`label` shape is [N, 1], N is mini-batch size,
and we only need to calculate accurary of top-1 and top-5, we could
calculated the correct prediction matrix of the top-5 scores of the
prediction of each sample like follows, while the correct prediction
matrix shape is [N, 5].
.. code-block:: python
def add_metric_op(pred, label):
# sort prediction and slice the top-5 scores
pred = fluid.layers.argsort(pred, descending=True)[1][:, :5]
# calculate whether the predictions are correct
correct = pred == label
return fluid.layers.cast(correct, dtype='float32')
With the :code:`add_metric_op`, we split some calculations to OPs(which
may run on GPU devices, will be faster), and only fetch 1 tensor with
shape as [N, 5] instead of 2 tensors with shapes as [N, 10] and [N, 1].
:code:`update` can be define as follows:
.. code-block:: python
def update(self, correct):
accs = []
for i, k in enumerate(self.topk):
num_corrects = correct[:, :k].sum()
num_samples = len(correct)
accs.append(float(num_corrects) / num_samples)
self.total[i] += num_corrects
self.count[i] += num_samples
return accs
"""
@abc.abstractmethod
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册