# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import paddle.fluid as fluid

from paddlerec.core.metric import Metric
from paddle.fluid.layers import accuracy
from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.tensor import Variable


class RecallK(Metric):
    """Recall@K metric for Paddle fluid (static-graph) models.

    Builds graph ops that accumulate, across batches, two persistable
    float32 counters:

    * ``InsCnt``    -- number of instances seen so far;
    * ``RecallCnt`` -- number of instances counted as hits by
      ``fluid.layers.accuracy`` with top-``k``;

    and exposes their ratio as ``Acc(Recall@k)``.
    """

    def __init__(self, input, label, k=20):
        """Create the counters and the ops that update them every batch.

        Args:
            input (Variable): prediction scores fed to
                ``fluid.layers.accuracy``; presumably shaped
                (batch, num_classes) -- TODO confirm against callers.
            label (Variable): ground-truth labels fed to
                ``fluid.layers.accuracy``; presumably (batch, 1) -- verify.
            k (int): how many top predictions count as a hit. Default 20.

        Raises:
            ValueError: if ``input`` or ``label`` is not a fluid ``Variable``.
        """
        # Same keyword set the original passed via locals() (minus self).
        kwargs = {"input": input, "label": label, "k": k}
        self.k = k

        if not isinstance(input, Variable):
            raise ValueError("input must be Variable, but received %s" %
                             type(input))
        if not isinstance(label, Variable):
            raise ValueError("label must be Variable, but received %s" %
                             type(label))

        helper = LayerHelper("PaddleRec_RecallK", **kwargs)

        # Ask the accuracy op for its integer hit count directly. The
        # previous form (batch_size * float_accuracy) reconstructed the
        # count through float32 rounding and could drift from an integer.
        batch_correct = helper.create_variable_for_type_inference(
            dtype="int32")
        accuracy(input, label, self.k, correct=batch_correct)

        global_ins_cnt, _ = helper.create_or_get_global_variable(
            name="ins_cnt", persistable=True, dtype='float32', shape=[1])
        global_pos_cnt, _ = helper.create_or_get_global_variable(
            name="pos_cnt", persistable=True, dtype='float32', shape=[1])

        # Both accumulators start at zero.
        for var in (global_ins_cnt, global_pos_cnt):
            helper.set_variable_initializer(
                var, Constant(
                    value=0.0, force_cpu=True))

        # Instances in this batch: sum of a ones-tensor shaped like label.
        tmp_ones = fluid.layers.fill_constant(
            shape=fluid.layers.shape(label), dtype="float32", value=1.0)
        batch_ins = fluid.layers.reduce_sum(tmp_ones)
        batch_pos = fluid.layers.cast(batch_correct, dtype="float32")

        # In-place accumulation: Out aliases X, so the globals grow each run.
        helper.append_op(
            type="elementwise_add",
            inputs={"X": [global_ins_cnt],
                    "Y": [batch_ins]},
            outputs={"Out": [global_ins_cnt]})
        helper.append_op(
            type="elementwise_add",
            inputs={"X": [global_pos_cnt],
                    "Y": [batch_pos]},
            outputs={"Out": [global_pos_cnt]})

        # NOTE(review): evaluates to NaN/inf if fetched while ins_cnt is 0;
        # _calculate() guards that case for the textual report.
        self.acc = global_pos_cnt / global_ins_cnt

        # Names/dtypes of the persistable state read back by _calculate().
        self._global_metric_state_vars = dict()
        self._global_metric_state_vars['ins_cnt'] = (global_ins_cnt.name,
                                                     "float32")
        self._global_metric_state_vars['pos_cnt'] = (global_pos_cnt.name,
                                                     "float32")

        metric_name = "Acc(Recall@%d)" % self.k
        self.metrics = dict()
        self.metrics["InsCnt"] = global_ins_cnt
        self.metrics["RecallCnt"] = global_pos_cnt
        self.metrics[metric_name] = self.acc

    def _calculate(self, global_metrics):
        """Format the accumulated counters as a one-line report.

        Args:
            global_metrics (dict): maps each key of
                ``self._global_metric_state_vars`` (``ins_cnt``, ``pos_cnt``)
                to an indexable value whose first element is the counter.

        Returns:
            str: ``"InsCnt=... RecallCnt=... Acc(Recall@k)=..."``; the
            accuracy is reported as 0 when no instance was counted.

        Raises:
            ValueError: if a required state entry is missing.
        """
        for key in self._global_metric_state_vars:
            if key not in global_metrics:
                raise ValueError("%s not existed" % key)
        ins_cnt = global_metrics['ins_cnt'][0]
        pos_cnt = global_metrics['pos_cnt'][0]
        # Avoid dividing by zero before any instance has been seen.
        if ins_cnt == 0:
            acc = 0
        else:
            acc = float(pos_cnt) / ins_cnt
        return "InsCnt=%s RecallCnt=%s Acc(Recall@%d)=%s" % (
            str(ins_cnt), str(pos_cnt), self.k, str(acc))

    def get_result(self):
        """Return the dict of metric name -> graph Variable built in __init__."""
        return self.metrics