diff --git a/ppcls/configs/Cartoonface/ResNet50_icartoon.yaml b/ppcls/configs/Cartoonface/ResNet50_icartoon.yaml index f18f3346bd6abf6a997ee1461cc815799c61d6a6..9640b7be91cdc7fc805898ab24c458c4dab950c1 100644 --- a/ppcls/configs/Cartoonface/ResNet50_icartoon.yaml +++ b/ppcls/configs/Cartoonface/ResNet50_icartoon.yaml @@ -138,4 +138,4 @@ Metric: topk: [1, 5] Eval: - Recallk: - topk: 1 + topk: [1] diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index d2e66bc54dc298ec329b45305684eb33f39da11b..8ec438eced5feac5edd7f40982d88e4d99d0b177 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -69,7 +69,7 @@ class mINP(nn.Layer): class Recallk(nn.Layer): def __init__(self, topk=(1, 5)): super().__init__() - assert isinstance(topk, (int, list)) + assert isinstance(topk, (int, list, tuple)) if isinstance(topk, int): topk = [topk] self.topk = topk @@ -97,6 +97,9 @@ class RetriMetric(nn.Layer): gallery_img_id, self.max_rank) if "Recallk" in self.config.keys(): topk = self.config['Recallk']['topk'] + assert isinstance(topk, (int, list, tuple)) + if isinstance(topk, int): + topk = [topk] for k in topk: metric_dict["recall{}".format(k)] = all_cmc[k - 1] if "mAP" in self.config.keys():