mean_average_precision_calculator.py 4.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculate the mean average precision.

It provides an interface for calculating mean average precision
for an entire list or the top-n ranked items.

Example usages:
We first call the function accumulate many times to process parts of the ranked
list. After processing all the parts, we call peek_map_at_n
to calculate the mean average precision.

```
import random

p = np.array([[random.random() for _ in xrange(50)] for _ in xrange(1000)])
a = np.array([[random.choice([0, 1]) for _ in xrange(50)]
     for _ in xrange(1000)])

# mean average precision for 50 classes.
calculator = mean_average_precision_calculator.MeanAveragePrecisionCalculator(
            num_class=50)
calculator.accumulate(p, a)
aps = calculator.peek_map_at_n()
```
"""

import numpy
from . import average_precision_calculator


class MeanAveragePrecisionCalculator(object):
    """This class is to calculate mean average precision.
  """

    def __init__(self, num_class):
        """Construct a calculator to calculate the (macro) average precision.

    Args:
      num_class: A positive Integer specifying the number of classes.
      top_n_array: A list of positive integers specifying the top n for each
      class. The top n in each class will be used to calculate its average
      precision at n.
      The size of the array must be num_class.

    Raises:
      ValueError: An error occurred when num_class is not a positive integer;
      or the top_n_array is not a list of positive integers.
    """
        if not isinstance(num_class, int) or num_class <= 1:
            raise ValueError("num_class must be a positive integer.")

        self._ap_calculators = []  # member of AveragePrecisionCalculator
        self._num_class = num_class  # total number of classes
        for i in range(num_class):
            self._ap_calculators.append(
                average_precision_calculator.AveragePrecisionCalculator())

    def accumulate(self, predictions, actuals, num_positives=None):
        """Accumulate the predictions and their ground truth labels.

    Args:
      predictions: A list of lists storing the prediction scores. The outer
      dimension corresponds to classes.
      actuals: A list of lists storing the ground truth labels. The dimensions
      should correspond to the predictions input. Any value
      larger than 0 will be treated as positives, otherwise as negatives.
      num_positives: If provided, it is a list of numbers representing the
      number of true positives for each class. If not provided, the number of
      true positives will be inferred from the 'actuals' array.

    Raises:
      ValueError: An error occurred when the shape of predictions and actuals
      does not match.
    """
        if not num_positives:
            num_positives = [None for i in predictions.shape[1]]

        calculators = self._ap_calculators
        for i in range(len(predictions)):
            calculators[i].accumulate(predictions[i], actuals[i],
                                      num_positives[i])

    def clear(self):
        for calculator in self._ap_calculators:
            calculator.clear()

    def is_empty(self):
        return ([calculator.heap_size for calculator in self._ap_calculators] ==
                [0 for _ in range(self._num_class)])

    def peek_map_at_n(self):
        """Peek the non-interpolated mean average precision at n.

    Returns:
      An array of non-interpolated average precision at n (default 0) for each
      class.
    """
        aps = [
            self._ap_calculators[i].peek_ap_at_n()
            for i in range(self._num_class)
        ]
        return aps