From 0354f4457412db633138773dea82acc6df5c072b Mon Sep 17 00:00:00 2001 From: Bin Lu Date: Mon, 23 Aug 2021 19:29:39 +0800 Subject: [PATCH] Update metrics.py --- ppcls/metric/metrics.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 9908b354..204d2af0 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -168,6 +168,47 @@ class Recallk(nn.Layer): return metric_dict +class Precisionk(nn.Layer): + def __init__(self, topk=(1, 5)): + super().__init__() + assert isinstance(topk, (int, list, tuple)) + if isinstance(topk, int): + topk = [topk] + self.topk = topk + + def forward(self, similarities_matrix, query_img_id, gallery_img_id, + keep_mask): + metric_dict = dict() + + #get cmc + choosen_indices = paddle.argsort( + similarities_matrix, axis=1, descending=True) + gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0]) + gallery_labels_transpose = paddle.broadcast_to( + gallery_labels_transpose, + shape=[ + choosen_indices.shape[0], gallery_labels_transpose.shape[1] + ]) + choosen_label = paddle.index_sample(gallery_labels_transpose, + choosen_indices) + equal_flag = paddle.equal(choosen_label, query_img_id) + if keep_mask is not None: + keep_mask = paddle.index_sample( + keep_mask.astype('float32'), choosen_indices) + equal_flag = paddle.logical_and(equal_flag, + keep_mask.astype('bool')) + equal_flag = paddle.cast(equal_flag, 'float32') + + Ns = paddle.arange(gallery_img_id.shape[0]) + 1 + equal_flag_cumsum = paddle.cumsum(equal_flag, axis=1) + Precision_at_k = (paddle.mean(equal_flag_cumsum, axis=0) / Ns).numpy() + + for k in self.topk: + metric_dict["precision@{}".format(k)] = Precision_at_k[k - 1] + + return metric_dict + + class DistillationTopkAcc(TopkAcc): def __init__(self, model_key, feature_key=None, topk=(1, 5)): super().__init__(topk=topk) -- GitLab