metrics.py 18.6 KB
Newer Older
W
weishengyu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
from cmath import nan
W
weishengyu 已提交
16 17 18
import numpy as np
import paddle
import paddle.nn as nn
C
cuicheng01 已提交
19 20 21 22 23 24
import paddle.nn.functional as F

from sklearn.metrics import hamming_loss
from sklearn.metrics import accuracy_score as accuracy_metric
from sklearn.metrics import multilabel_confusion_matrix
from sklearn.preprocessing import binarize
W
weishengyu 已提交
25

Z
zhiboniu 已提交
26 27
from easydict import EasyDict

C
cuicheng01 已提交
28
from ppcls.metric.avg_metrics import AvgMetrics
Z
zhiboniu 已提交
29
from ppcls.utils.misc import AverageMeter, AttrMeter
30
from ppcls.utils import logger
D
dongshuilong 已提交
31

C
cuicheng01 已提交
32 33

class TopkAcc(AvgMetrics):
    """Top-k classification accuracy with one running meter per k.

    Args:
        topk (int | list | tuple): the k values to evaluate, e.g. ``(1, 5)``.
            A single int is wrapped into a one-element list.
    """

    def __init__(self, topk=(1, 5)):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        self.topk = [topk] if isinstance(topk, int) else topk
        self.reset()

    def reset(self):
        """Create a fresh ``AverageMeter`` named ``top{k}`` for every k."""
        meters = {}
        for k in self.topk:
            name = f"top{k}"
            meters[name] = AverageMeter(name)
        self.avg_meters = meters

    def forward(self, x, label):
        """Compute top-k accuracy for one batch and update the meters.

        Args:
            x: logits tensor, or a dict holding it under ``"logits"``.
            label: ground-truth labels passed straight to
                ``paddle.metric.accuracy`` (presumably shape [N, 1]; TODO
                confirm).

        Returns:
            dict: ``{"top{k}": accuracy_tensor}`` for every supported k.

        Any k larger than the last dimension of ``x`` is logged, its meter
        is dropped, and it is removed from ``self.topk`` for later calls.
        """
        logits = x["logits"] if isinstance(x, dict) else x
        num_classes = logits.shape[-1]
        batch_size = logits.shape[0]

        metric_dict = {}
        for k in self.topk:
            key = f"top{k}"
            if num_classes < k:
                logger.warning(
                    f"The output dims({num_classes}) is less than k({k}), and the argument {k} of Topk has been removed."
                )
                self.avg_meters.pop(key)
                continue
            acc = paddle.metric.accuracy(logits, label, k=k)
            metric_dict[key] = acc
            self.avg_meters[key].update(acc, batch_size)

        # Permanently forget the k values this output cannot support.
        self.topk = [k for k in self.topk if k <= num_classes]
        return metric_dict

D
dongshuilong 已提交
69

W
weishengyu 已提交
70
class mAP(nn.Layer):
    """Mean average precision for retrieval evaluation.

    Args:
        descending (bool): forwarded to ``paddle.argsort``; True ranks the
            gallery by decreasing similarity score.
    """

    def __init__(self, descending=True):
        super().__init__()
        self.descending = descending

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Compute mAP over the queries that have at least one positive.

        Args:
            similarities_matrix: score tensor, presumably
                [num_query, num_gallery] (sorted along axis 1).
            query_img_id: query ids; presumably [num_query, 1] so that it
                broadcasts against the ranked gallery ids (TODO confirm).
            gallery_img_id: gallery ids; transposed with perm [1, 0], so it
                must be 2-D (presumably [num_gallery, 1]).
            keep_mask: optional mask aligned with ``similarities_matrix``;
                entries equal to 0 are never counted as positives.

        Returns:
            dict: ``{"mAP": float}``, or ``{"mAP": nan}`` when no query has a
            positive.
        """
        metric_dict = dict()

        # Rank the gallery for every query.
        ranked_idx = paddle.argsort(
            similarities_matrix, axis=1, descending=self.descending)
        num_query = ranked_idx.shape[0]

        # Replicate the gallery id row per query, then reorder by rank.
        gallery_row = paddle.transpose(gallery_img_id, [1, 0])
        gallery_rows = paddle.broadcast_to(
            gallery_row, shape=[num_query, gallery_row.shape[1]])
        ranked_ids = paddle.index_sample(gallery_rows, ranked_idx)

        hits = paddle.equal(ranked_ids, query_img_id)
        if keep_mask is not None:
            ranked_keep = paddle.index_sample(
                keep_mask.astype('float32'), ranked_idx)
            hits = paddle.logical_and(hits, ranked_keep.astype('bool'))
        hits = paddle.cast(hits, 'float32')

        # Only queries with at least one positive contribute to the mean.
        has_pos = paddle.greater_than(
            paddle.sum(hits, axis=1), paddle.to_tensor(0.))
        valid_rows = paddle.nonzero(has_pos.astype("int"))
        valid_rows = paddle.reshape(valid_rows, [valid_rows.shape[0]])
        if paddle.numel(valid_rows).item() == 0:
            metric_dict["mAP"] = np.nan
            return metric_dict
        hits = paddle.index_select(hits, valid_rows, axis=0)

        # precision at every 1-based rank position
        ranks = paddle.arange(hits.shape[1]).astype("float32") + 1
        precision = paddle.divide(paddle.cumsum(hits, axis=1), ranks)

        # AP = sum of precision at positive positions / number of positives
        relevant_precision = paddle.multiply(hits, precision)
        ap = paddle.sum(relevant_precision, axis=1) / paddle.sum(hits, axis=1)
        metric_dict["mAP"] = float(paddle.mean(ap))
        return metric_dict

D
dongshuilong 已提交
119

W
weishengyu 已提交
120
class mINP(nn.Layer):
    """Mean inverse negative penalty (mINP) for retrieval evaluation.

    For each query with at least one positive in the gallery,
    INP = (number of positives) / (1-based rank of its last positive);
    mINP is the mean INP over those queries.

    Args:
        descending (bool): forwarded to ``paddle.argsort``; True ranks the
            gallery by decreasing similarity score.
    """

    def __init__(self, descending=True):
        super().__init__()
        self.descending = descending

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Compute mINP for a batch of queries against the gallery.

        Args:
            similarities_matrix: score tensor, presumably
                [num_query, num_gallery] (sorted along axis 1).
            query_img_id: query ids; presumably [num_query, 1] so that it
                broadcasts against the ranked gallery ids (TODO confirm).
            gallery_img_id: gallery ids; transposed with perm [1, 0], so it
                must be 2-D (presumably [num_gallery, 1]).
            keep_mask: optional mask aligned with ``similarities_matrix``;
                entries equal to 0 are never counted as positives.

        Returns:
            dict: ``{"mINP": float}``, or ``{"mINP": nan}`` when no query has
            a positive (same convention as ``mAP``).
        """
        metric_dict = dict()

        choosen_indices = paddle.argsort(
            similarities_matrix, axis=1, descending=self.descending)
        gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
        gallery_labels_transpose = paddle.broadcast_to(
            gallery_labels_transpose,
            shape=[
                choosen_indices.shape[0], gallery_labels_transpose.shape[1]
            ])
        choosen_label = paddle.index_sample(gallery_labels_transpose,
                                            choosen_indices)
        equal_flag = paddle.equal(choosen_label, query_img_id)
        if keep_mask is not None:
            # paddle.index_sample (a misspelled name here used to raise
            # AttributeError whenever a keep_mask was supplied).
            keep_mask = paddle.index_sample(
                keep_mask.astype('float32'), choosen_indices)
            equal_flag = paddle.logical_and(equal_flag,
                                            keep_mask.astype('bool'))
        equal_flag = paddle.cast(equal_flag, 'float32')

        # Keep only queries that have at least one positive.
        num_rel = paddle.sum(equal_flag, axis=1)
        num_rel = paddle.greater_than(num_rel, paddle.to_tensor(0.))
        num_rel_index = paddle.nonzero(num_rel.astype("int"))
        num_rel_index = paddle.reshape(num_rel_index, [num_rel_index.shape[0]])
        if paddle.numel(num_rel_index).item() == 0:
            metric_dict["mINP"] = np.nan
            return metric_dict
        equal_flag = paddle.index_select(equal_flag, num_rel_index, axis=0)

        # Find the last (hardest) positive per row: positive entries become
        # (j + 1) / (j + 2), strictly increasing in the 0-based position j,
        # and non-positives become 0, so argmax picks the last positive.
        div = paddle.arange(equal_flag.shape[1]).astype("float32") + 2
        minus = paddle.divide(equal_flag, div)
        auxilary = paddle.subtract(equal_flag, minus)
        hard_index = paddle.argmax(auxilary, axis=1).astype("float32")
        # argmax is 0-based; the rank of the hardest positive is index + 1.
        # Dividing by the raw index is off by one and divides by zero when
        # that positive is ranked first.
        hard_rank = hard_index + 1
        all_INP = paddle.divide(paddle.sum(equal_flag, axis=1), hard_rank)
        mINP = paddle.mean(all_INP)
        metric_dict["mINP"] = float(mINP)
        return metric_dict

D
dongshuilong 已提交
163

C
cuicheng01 已提交
164
class TprAtFpr(nn.Layer):
    """Highest true-positive rate reachable while FPR <= ``max_fpr``.

    Binary classification: the score of a sample is its softmax probability
    at class index 1. Samples whose ``label[i][0] == 0`` are negatives; all
    others are positives. Scores are accumulated across ``forward`` calls
    and evaluated on demand by ``avg_info``.

    Args:
        max_fpr (float): upper bound on the false-positive rate.
    """

    def __init__(self, max_fpr=1 / 1000.):
        super().__init__()
        self.gt_pos_score_list = []
        self.gt_neg_score_list = []
        self.softmax = nn.Softmax(axis=-1)
        self.max_fpr = max_fpr
        self.max_tpr = 0.

    def forward(self, x, label):
        """Store the class-1 probability of every sample; returns ``{}``."""
        if isinstance(x, dict):
            x = x["logits"]
        x = self.softmax(x)
        for i, label_i in enumerate(label):
            if label_i[0] == 0:
                self.gt_neg_score_list.append(x[i][1].numpy())
            else:
                self.gt_pos_score_list.append(x[i][1].numpy())
        return {}

    def reset(self):
        """Drop accumulated scores and the cached best TPR."""
        self.gt_pos_score_list = []
        self.gt_neg_score_list = []
        self.max_tpr = 0.

    @property
    def avg(self):
        """Best TPR found by the most recent ``avg_info`` evaluation."""
        return self.max_tpr

    @property
    def avg_info(self):
        """Sweep thresholds 0.0000 .. 0.9999 and report the best one.

        A sample is predicted positive when its score is strictly greater
        than the threshold. Sets ``self.max_tpr`` and returns a summary
        string ("" when no positives were seen or no threshold qualifies).
        """
        max_tpr = 0.
        result = ""
        gt_pos_score_list = np.array(self.gt_pos_score_list)
        gt_neg_score_list = np.array(self.gt_neg_score_list)
        num_pos = len(gt_pos_score_list)
        num_neg = len(gt_neg_score_list)

        # TPR is undefined without positives: nothing to sweep.
        if num_pos == 0:
            self.max_tpr = max_tpr
            return result

        if num_neg == 0:
            # FPR is undefined without negatives; warn once and report the
            # best TPR with fpr taken as 0 instead of dividing by zero.
            msg = "The number of negative samples is 0, please add negative samples."
            logger.warning(msg)

        for i in range(0, 10000):
            threshold = i / 10000.
            tpr = np.sum(gt_pos_score_list > threshold) / num_pos
            if num_neg == 0:
                if tpr > max_tpr:
                    max_tpr = tpr
                    result = "threshold: {}, fpr: 0.0, tpr: {:.5f}".format(
                        threshold, tpr)
                continue
            fpr = np.sum(gt_neg_score_list > threshold) / num_neg
            if fpr <= self.max_fpr and tpr > max_tpr:
                max_tpr = tpr
                result = "threshold: {}, fpr: {}, tpr: {:.5f}".format(
                    threshold, fpr, tpr)
        self.max_tpr = max_tpr
        return result


weixin_46524038's avatar
weixin_46524038 已提交
221
class MultilabelMeanAccuracy(nn.Layer):
    """Best mean per-label accuracy over a grid of sigmoid thresholds.

    Scores and labels are accumulated across ``forward`` calls. ``avg_info``
    evaluates ``num_iterations`` thresholds starting at ``start_threshold``
    and stepping by ``(end_threshold - start_threshold) / num_iterations``
    (``end_threshold`` itself is therefore not evaluated), keeping the best.

    Args:
        start_threshold (float): first threshold evaluated.
        num_iterations (int): number of thresholds evaluated.
        end_threshold (float): exclusive upper end of the threshold grid.
    """

    def __init__(self,
                 start_threshold=0.4,
                 num_iterations=10,
                 end_threshold=0.9):
        super().__init__()
        self.start_threshold = start_threshold
        self.num_iterations = num_iterations
        self.end_threshold = end_threshold
        self.gt_all_score_list = []
        self.gt_label_score_list = []
        self.max_acc = 0.

    def forward(self, x, label):
        """Accumulate sigmoid scores and labels for one batch; returns ``{}``.

        ``label`` is indexed as ``label[:, 0, :]``, so it must be 3-D —
        presumably [batch, 1, num_labels] (TODO confirm).
        """
        if isinstance(x, dict):
            x = x["logits"]
        x = F.sigmoid(x)
        label = label[:, 0, :]
        for i in range(len(x)):
            self.gt_all_score_list.append(x[i].numpy())
            self.gt_label_score_list.append(label[i].numpy())
        return {}

    def reset(self):
        """Drop accumulated scores/labels and the cached best accuracy."""
        self.gt_all_score_list = []
        self.gt_label_score_list = []
        self.max_acc = 0.

    @property
    def avg(self):
        """Best mean accuracy found by the most recent ``avg_info`` call."""
        return self.max_acc

    @property
    def avg_info(self):
        """Sweep the threshold grid; set ``self.max_acc`` and return a summary.

        Returns "" (and sets ``max_acc`` to 0) when nothing was accumulated.
        Ties are resolved in favour of the later (higher) threshold.
        """
        max_acc = 0.
        result = ""
        gt_all_score_list = np.array(self.gt_all_score_list)
        gt_label_score_list = np.array(self.gt_label_score_list)
        if len(gt_label_score_list) == 0:
            # Nothing accumulated: avoid IndexError on gt_label_score_list[0].
            self.max_acc = max_acc
            return result

        num_samples = len(gt_all_score_list)
        num_labels = len(gt_label_score_list[0])
        for i in range(self.num_iterations):
            threshold = self.start_threshold + i * (self.end_threshold -
                                                    self.start_threshold
                                                    ) / self.num_iterations
            pred_label = (gt_all_score_list > threshold).astype(int)
            TP = np.sum(
                (gt_label_score_list == 1) * (pred_label == 1)).astype(float)
            TN = np.sum(
                (gt_label_score_list == 0) * (pred_label == 0)).astype(float)
            # Correct label decisions per sample; dividing by num_labels
            # (below) turns this into a mean per-label accuracy.
            acc = (TP + TN) / num_samples
            if max_acc <= acc:
                max_acc = acc
                result = "threshold: {}, mean_acc: {}".format(
                    threshold, max_acc / num_labels)
        self.max_acc = max_acc / num_labels
        return result


W
weishengyu 已提交
277
class Recallk(nn.Layer):
    """Retrieval recall@k (CMC curve) over queries that have a positive.

    Args:
        topk (int | list | tuple): ranks k at which recall is reported.
        descending (bool): forwarded to ``paddle.argsort``; True ranks the
            gallery by decreasing similarity score.
    """

    def __init__(self, topk=(1, 5), descending=True):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        self.topk = [topk] if isinstance(topk, int) else topk
        self.descending = descending

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Return ``{"recall{k}": value}`` for every configured k.

        recall@k is the fraction of queries (among those with at least one
        positive) that have a positive within the first k ranked results.
        ``keep_mask`` entries equal to 0 are never counted as positives.
        """
        # get cmc
        order = paddle.argsort(
            similarities_matrix, axis=1, descending=self.descending)
        gallery_row = paddle.transpose(gallery_img_id, [1, 0])
        gallery_rows = paddle.broadcast_to(
            gallery_row, shape=[order.shape[0], gallery_row.shape[1]])
        ranked_ids = paddle.index_sample(gallery_rows, order)

        hits = paddle.equal(ranked_ids, query_img_id)
        if keep_mask is not None:
            kept = paddle.index_sample(keep_mask.astype('float32'), order)
            hits = paddle.logical_and(hits, kept.astype('bool'))
        hits = paddle.cast(hits, 'float32')

        zero = paddle.to_tensor(0.)
        # number of queries that have at least one positive
        positives_per_query = paddle.sum(hits, axis=1)
        real_query_num = paddle.sum(
            paddle.greater_than(positives_per_query, zero).astype("float32"))

        # 1 from the first positive onwards, per query
        found_yet = paddle.greater_than(
            paddle.cumsum(hits, axis=1), zero).astype("float32")
        all_cmc = (paddle.sum(found_yet, axis=0) / real_query_num).numpy()

        return {f"recall{k}": all_cmc[k - 1] for k in self.topk}

D
dongshuilong 已提交
322

B
Bin Lu 已提交
323
class Precisionk(nn.Layer):
    """Retrieval precision@k averaged over all queries.

    Args:
        topk (int | list | tuple): ranks k at which precision is reported.
        descending (bool): forwarded to ``paddle.argsort``; True ranks the
            gallery by decreasing similarity score.
    """

    def __init__(self, topk=(1, 5), descending=True):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        if isinstance(topk, int):
            topk = [topk]
        self.topk = topk
        self.descending = descending

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Return ``{"precision@{k}": value}`` for every configured k.

        precision@k is (number of positives among the first k ranked
        results) / k, averaged over every query in the batch. ``keep_mask``
        entries equal to 0 are never counted as positives.
        """
        metric_dict = dict()

        # rank the gallery and mark positives in ranked order
        choosen_indices = paddle.argsort(
            similarities_matrix, axis=1, descending=self.descending)
        gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
        gallery_labels_transpose = paddle.broadcast_to(
            gallery_labels_transpose,
            shape=[
                choosen_indices.shape[0], gallery_labels_transpose.shape[1]
            ])
        choosen_label = paddle.index_sample(gallery_labels_transpose,
                                            choosen_indices)
        equal_flag = paddle.equal(choosen_label, query_img_id)
        if keep_mask is not None:
            keep_mask = paddle.index_sample(
                keep_mask.astype('float32'), choosen_indices)
            equal_flag = paddle.logical_and(equal_flag,
                                            keep_mask.astype('bool'))
        equal_flag = paddle.cast(equal_flag, 'float32')

        # 1-based ranks as float32 (same construction as mAP) so the division
        # below does not rely on implicit int64 -> float32 promotion.
        Ns = paddle.arange(equal_flag.shape[1]).astype("float32") + 1
        equal_flag_cumsum = paddle.cumsum(equal_flag, axis=1)
        Precision_at_k = (paddle.mean(equal_flag_cumsum, axis=0) / Ns).numpy()

        for k in self.topk:
            metric_dict["precision@{}".format(k)] = Precision_at_k[k - 1]

        return metric_dict


365 366 367 368 369 370 371
class DistillationTopkAcc(TopkAcc):
    """``TopkAcc`` applied to one entry of a multi-model output dict.

    Args:
        model_key: key selecting the sub-model output when ``x`` is a dict.
        feature_key: optional second-level key applied after ``model_key``.
        topk (int | list | tuple): forwarded to ``TopkAcc``.
    """

    def __init__(self, model_key, feature_key=None, topk=(1, 5)):
        super().__init__(topk=topk)
        self.model_key = model_key
        self.feature_key = feature_key

    def forward(self, x, label):
        """Select the configured sub-output, then defer to ``TopkAcc``."""
        logits = x[self.model_key] if isinstance(x, dict) else x
        if self.feature_key is not None:
            logits = logits[self.feature_key]
        return super().forward(logits, label)
C
cuicheng01 已提交
377 378 379 380 381 382 383 384 385 386 387 388


class GoogLeNetTopkAcc(TopkAcc):
    """``TopkAcc`` for models whose output is a sequence of logits.

    Only ``x[0]`` is scored (presumably the main classifier head of
    GoogLeNet; auxiliary heads are ignored — TODO confirm).

    Args:
        topk (int | list | tuple): forwarded to ``TopkAcc``.
    """

    def __init__(self, topk=(1, 5)):
        # Forward topk so TopkAcc.reset() builds meters for the requested k
        # values; setting self.topk afterwards left the default (1, 5)
        # meters in place and broke any custom topk.
        super().__init__(topk=topk)

    def forward(self, x, label):
        """Score only the first element of ``x``."""
        return super().forward(x[0], label)
C
cuicheng01 已提交
389 390


C
cuicheng01 已提交
391 392 393 394
class MultiLabelMetric(AvgMetrics):
    """Base class for multi-label metrics that need 0/1 predictions.

    Args:
        bi_threshold (float): sigmoid scores above this become 1, others 0.
    """

    def __init__(self, bi_threshold=0.5):
        super().__init__()
        self.bi_threshold = bi_threshold

    def _multi_hot_encode(self, output):
        """Sigmoid ``output`` and binarize it; returns a numpy array."""
        probabilities = F.sigmoid(output).numpy()
        multi_hot = binarize(probabilities, threshold=self.bi_threshold)
        return multi_hot
C
cuicheng01 已提交
399 400


C
cuicheng01 已提交
401
class HammingDistance(MultiLabelMetric):
    """Hamming loss between binarized predictions and multi-label targets.

    Predictions come from ``_multi_hot_encode`` (sigmoid + threshold) and the
    loss from sklearn ``hamming_loss``. Smaller values mean a better model.
    """

    def __init__(self):
        super().__init__()
        self.reset()

    def reset(self):
        """Start a fresh running average."""
        meter_name = "HammingDistance"
        self.avg_meters = {meter_name: AverageMeter(meter_name)}

    def forward(self, output, target):
        """Score one batch and fold it into the running average.

        Returns:
            dict: ``{"HammingDistance": tensor holding the batch loss}``.
        """
        binary_preds = super()._multi_hot_encode(output)
        loss = paddle.to_tensor(hamming_loss(target, binary_preds))
        self.avg_meters["HammingDistance"].update(
            float(loss), output.shape[0])
        return {"HammingDistance": loss}


C
cuicheng01 已提交
425
class AccuracyScore(MultiLabelMetric):
    """Hard accuracy metric for multi-label classification.

    Args:
        base (str): one of ["sample", "label"], default="label".
            "sample": sklearn ``accuracy_score`` on the binarized
                predictions (for multi-label input sklearn counts a sample
                as correct only when all of its labels match).
            "label": (TP + TN) / (TP + TN + FN + FP), pooled over every
                label via ``multilabel_confusion_matrix``.

    forward returns:
        dict: ``{"AccuracyScore": tensor holding the batch accuracy}``.
    """

    def __init__(self, base="label"):
        super().__init__()
        assert base in ["sample", "label"
                        ], 'must be one of ["sample", "label"]'
        self.base = base
        self.reset()

    def reset(self):
        """Start a fresh running average."""
        self.avg_meters = {"AccuracyScore": AverageMeter("AccuracyScore")}

    def forward(self, output, target):
        """Score one batch and fold it into the running average."""
        preds = super()._multi_hot_encode(output)
        metric_dict = dict()
        if self.base == "sample":
            accuracy = accuracy_metric(target, preds)
        elif self.base == "label":
            # One 2x2 confusion matrix per label, indexed as below.
            mcm = multilabel_confusion_matrix(target, preds)
            tns = mcm[:, 0, 0]
            fns = mcm[:, 1, 0]
            tps = mcm[:, 1, 1]
            fps = mcm[:, 0, 1]
            correct = tps.sum() + tns.sum()
            accuracy = correct / (correct + fns.sum() + fps.sum())
        metric_dict["AccuracyScore"] = paddle.to_tensor(accuracy)
        self.avg_meters["AccuracyScore"].update(
            float(metric_dict["AccuracyScore"]), output.shape[0])
        return metric_dict
Z
zhiboniu 已提交
463 464 465 466 467


def get_attr_metrics(gt_label, preds_probs, threshold):
    """Collect per-label and per-instance counts for attribute recognition.

    adapted from "https://github.com/valencebond/Rethinking_of_PAR/blob/master/metrics/pedestrian_metrics.py"

    Args:
        gt_label (np.ndarray): ground truth with 1 = positive, 0 = negative
            and -1 = ignored. Must be 2-D; presumably
            (num_samples, num_labels) given the axis=0 (per-label) and
            axis=1 (per-instance) reductions below.
        preds_probs (np.ndarray): predicted scores, same shape as gt_label.
        threshold (float): scores strictly greater than this are predicted
            positive.

    Returns:
        EasyDict with float arrays:
            gt_pos, gt_neg, true_pos, true_neg, false_pos, false_neg:
                per-label counts summed over samples (axis=0).
            gt_pos_ins: ground-truth positives per sample.
            true_pos_ins: predicted positives per sample (despite the name).
            intersect_pos: per-sample labels positive in both gt and pred.
            union_pos: per-sample labels positive in gt or pred.
    """
    pred_label = (preds_probs > threshold).astype(int)

    result = EasyDict()

    # Ignored entries (gt == -1) get prediction -1 as well, so they match
    # none of the ==1 / ==0 tests below.
    ignored = gt_label == -1
    pred_label[ignored] = -1

    ###############################
    # label metrics
    # TP + FN
    result.gt_pos = np.sum((gt_label == 1), axis=0).astype(float)
    # TN + FP
    result.gt_neg = np.sum((gt_label == 0), axis=0).astype(float)
    # TP
    result.true_pos = np.sum((gt_label == 1) * (pred_label == 1),
                             axis=0).astype(float)
    # TN
    result.true_neg = np.sum((gt_label == 0) * (pred_label == 0),
                             axis=0).astype(float)
    # FP
    result.false_pos = np.sum(((gt_label == 0) * (pred_label == 1)),
                              axis=0).astype(float)
    # FN
    result.false_neg = np.sum(((gt_label == 1) * (pred_label == 0)),
                              axis=0).astype(float)

    ################
    # instance metrics
    result.gt_pos_ins = np.sum((gt_label == 1), axis=1).astype(float)
    # number of predicted positives per sample
    result.true_pos_ins = np.sum((pred_label == 1), axis=1).astype(float)
    # true positive
    result.intersect_pos = np.sum((gt_label == 1) * (pred_label == 1),
                                  axis=1).astype(float)
    # union for IOU: bool + bool is a logical OR in numpy
    result.union_pos = np.sum(((gt_label == 1) + (pred_label == 1)),
                              axis=1).astype(float)

    return result


class ATTRMetric(nn.Layer):
    """Attribute-recognition metric that accumulates into an ``AttrMeter``.

    Args:
        threshold (float): score threshold used both for the per-batch
            statistics and for the accumulated ``AttrMeter``.
    """

    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold
        # Create the meter up front (as the other metrics do) so forward()
        # works even before an explicit reset().
        self.reset()

    def reset(self):
        """Start a fresh meter using the configured threshold."""
        self.attrmeter = AttrMeter(threshold=self.threshold)

    def forward(self, output, target):
        """Compute batch statistics and add them to the meter.

        ``target`` is indexed as ``target[:, 0, :]`` (so it must be 3-D).
        ``output`` is compared directly against ``threshold``, so it is
        presumably already a probability tensor — TODO confirm.
        """
        metric_dict = get_attr_metrics(target[:, 0, :].numpy(),
                                       output.numpy(), self.threshold)
        self.attrmeter.update(metric_dict)
        return metric_dict