未验证 提交 ae24d832 编写于 作者: C cuicheng01 提交者: GitHub

Merge pull request #793 from RainFrost1/develop_reg

修复metrics一些问题
......@@ -148,9 +148,9 @@ Infer:
Metric:
Train:
- Topk:
k: [1, 5]
- TopkAcc:
topk: [1, 5]
Eval:
- Topk:
k: [1, 5]
- TopkAcc:
topk: [1, 5]
# global configs
Trainer:
name: TrainerReID
Global:
checkpoints: null
pretrained_model: null
......@@ -16,8 +14,7 @@ Global:
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: "./inference"
num_split: 1
feature_normalize: True
eval_mode: "retrieval"
# model architecture
Arch:
......@@ -99,10 +96,10 @@ DataLoader:
loader:
num_workers: 6
use_shared_memory: False
Query:
Eval:
Query:
# TOTO: modify to the latest trainer
dataset:
dataset:
name: "VeriWild"
image_root: "/work/dataset/VeRI-Wild/images"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id_query.txt"
......@@ -114,18 +111,18 @@ DataLoader:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
loader:
num_workers: 6
use_shared_memory: False
Gallery:
Gallery:
# TOTO: modify to the latest trainer
dataset:
dataset:
name: "VeriWild"
image_root: "/work/dataset/VeRI-Wild/images"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id.txt"
......@@ -137,15 +134,21 @@ DataLoader:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
loader:
num_workers: 6
use_shared_memory: False
Metric:
Eval:
- Recallk:
topk: [1, 5]
- mAP: {}
Infer:
infer_imgs: "docs/images/whl/demo.jpg"
batch_size: 10
......
......@@ -15,6 +15,7 @@
import numpy as np
import paddle
import paddle.nn as nn
from functools import lru_cache
# TODO: fix the format
......@@ -38,23 +39,13 @@ class TopkAcc(nn.Layer):
class mAP(nn.Layer):
def __init__(self, max_rank=50):
def __init__(self):
super().__init__()
self.max_rank = max_rank
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict()
num_q, num_g = similarities_matrix.shape
q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
if num_g < self.max_rank:
self.max_rank = num_g
print('Note: number of gallery samples is quite small, got {}'.
format(num_g))
indices = paddle.argsort(
similarities_matrix, axis=1, descending=True).numpy()
_, all_AP, _ = get_metrics(indices, num_q, num_g, q_pids, g_pids,
self.max_rank)
_, all_AP, _ = get_metrics(similarities_matrix, query_img_id,
gallery_img_id)
mAP = np.mean(all_AP)
metric_dict["mAP"] = mAP
......@@ -62,23 +53,13 @@ class mAP(nn.Layer):
class mINP(nn.Layer):
def __init__(self, max_rank=50):
def __init__(self):
super().__init__()
self.max_rank = max_rank
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict()
num_q, num_g = similarities_matrix.shape
q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
if num_g < self.max_rank:
max_rank = num_g
print('Note: number of gallery samples is quite small, got {}'.
format(num_g))
indices = paddle.argsort(
similarities_matrix, axis=1, descending=True).numpy()
_, _, all_INP = get_metrics(indices, num_q, num_g, q_pids, g_pids,
self.max_rank)
_, _, all_INP = get_metrics(similarities_matrix, query_img_id,
gallery_img_id)
mINP = np.mean(all_INP)
metric_dict["mINP"] = mINP
......@@ -86,34 +67,37 @@ class mINP(nn.Layer):
class Recallk(nn.Layer):
def __init__(self, max_rank=50, topk=(1, 5)):
def __init__(self, topk=(1, 5)):
super().__init__()
self.max_rank = max_rank
assert isinstance(topk, (int, list))
if isinstance(topk, int):
topk = [topk]
self.topk = topk
self.max_rank = max(self.topk) if max(self.topk) > 50 else 50
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict()
num_q, num_g = similarities_matrix.shape
q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
if num_g < self.max_rank:
max_rank = num_g
print('Note: number of gallery samples is quite small, got {}'.
format(num_g))
indices = paddle.argsort(
similarities_matrix, axis=1, descending=True).numpy()
all_cmc, _, _ = get_metrics(indices, num_q, num_g, q_pids, g_pids,
self.max_rank)
all_cmc, _, _ = get_metrics(similarities_matrix, query_img_id,
gallery_img_id, self.max_rank)
for k in self.topk:
metric_dict["recall{}".format(k)] = all_cmc[k - 1]
return metric_dict
def get_metrics(indices, num_q, num_g, q_pids, g_pids, max_rank=50):
@lru_cache()
def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
max_rank=50):
num_q, num_g = similarities_matrix.shape
q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
if num_g < max_rank:
max_rank = num_g
print('Note: number of gallery samples is quite small, got {}'.format(
num_g))
indices = paddle.argsort(
similarities_matrix, axis=1, descending=True).numpy()
all_cmc = []
all_AP = []
all_INP = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册