# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import defaultdict

import numpy as np
import paddle
import scipy.sparse

from ppcls.utils import all_gather, logger


def retrieval_eval(engine, epoch_id=0):
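    """Evaluate retrieval performance (e.g. Recall@k, mAP) on query/gallery data.

    Args:
        engine: Engine object holding the model, dataloaders, config and
            the metric function.
        epoch_id (int, optional): Current epoch id, used for logging. Defaults to 0.

    Returns:
        float: Value of the first metric in the computed metric dict.
    """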
    engine.model.eval()
    # step1. prepare query and gallery features
    if engine.gallery_query_dataloader is not None:
        gallery_feat, gallery_label, gallery_camera = compute_feature(
            engine, "gallery_query")
        query_feat, query_label, query_camera = gallery_feat, gallery_label, gallery_camera
    else:
        gallery_feat, gallery_label, gallery_camera = compute_feature(
            engine, "gallery")
        query_feat, query_label, query_camera = compute_feature(engine,
                                                                "query")

    # step2. split features into feature blocks for saving memory
    num_query = len(query_feat)
    block_size = engine.config["Global"].get("sim_block_size", 64)
    sections = [block_size] * (num_query // block_size)
    if num_query % block_size > 0:
        sections.append(num_query % block_size)

    query_feat_blocks = paddle.split(query_feat, sections)
    query_label_blocks = paddle.split(query_label, sections)
    query_camera_blocks = paddle.split(
        query_camera, sections) if query_camera is not None else None
    metric_key = None

    # step3. compute metric
    if engine.eval_metric_func is None:
        metric_dict = {metric_key: 0.0}
    else:
        use_reranking = engine.config["Global"].get("re_ranking", False)
        logger.info(f"re_ranking={use_reranking}")
        if use_reranking:
            # compute distance matrix
            distmat = compute_re_ranking_dist(
                query_feat, gallery_feat, engine.config["Global"].get(
                    "feature_normalize", True), 20, 6, 0.3)
            # exclude illegal distance
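            # a pair is "illegal" when query and gallery share both label and camera id
            # (same identity captured by the same camera); such pairs get pushed past
            # the current maximum distance so they are never retrieved first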
            if query_camera is not None:
                camera_mask = query_camera != gallery_camera.t()
                label_mask = query_label != gallery_label.t()
                keep_mask = label_mask | camera_mask
                distmat = keep_mask.astype(query_feat.dtype) * distmat + (
                    ~keep_mask).astype(query_feat.dtype) * (distmat.max() + 1)
            else:
                keep_mask = None
            # compute metric with all samples
            metric_dict = engine.eval_metric_func(-distmat, query_label,
                                                  gallery_label, keep_mask)
        else:
            metric_dict = defaultdict(float)
            for block_idx, block_feat in enumerate(query_feat_blocks):
                # compute distance matrix
                distmat = paddle.matmul(
                    block_feat, gallery_feat, transpose_y=True)
                # exclude illegal distance
                if query_camera is not None:
                    camera_mask = query_camera_blocks[
                        block_idx] != gallery_camera.t()
                    label_mask = query_label_blocks[
                        block_idx] != gallery_label.t()
                    keep_mask = label_mask | camera_mask
                    distmat = keep_mask.astype(query_feat.dtype) * distmat
                else:
                    keep_mask = None
                # compute metric by block
                metric_block = engine.eval_metric_func(
                    distmat, query_label_blocks[block_idx], gallery_label,
                    keep_mask)
                # accumulate metric
                for key in metric_block:
                    metric_dict[key] += metric_block[key] * block_feat.shape[
                        0] / num_query

    metric_info_list = []
    for key, value in metric_dict.items():
        metric_info_list.append(f"{key}: {value:.5f}")
        if metric_key is None:
            metric_key = key
    metric_msg = ", ".join(metric_info_list)
    logger.info(f"[Eval][Epoch {epoch_id}][Avg]{metric_msg}")

    return metric_dict[metric_key]


def compute_feature(engine, name="gallery"):
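    """Compute features, labels and (optional) camera ids over a whole dataloader.

    Args:
        engine: Engine object holding the dataloaders, model and config.
        name (str, optional): Which dataset to use, one of "gallery", "query"
            or "gallery_query". Defaults to "gallery".

    Returns:
        tuple: (all_feat, all_label, all_camera), where all_camera is None if
            the dataset carries no camera id.
    """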
    if name == "gallery":
        dataloader = engine.gallery_dataloader
    elif name == "query":
        dataloader = engine.query_dataloader
    elif name == "gallery_query":
        dataloader = engine.gallery_query_dataloader
    else:
        raise ValueError(
            f"Only 'gallery', 'query' and 'gallery_query' datasets are supported, but got {name}"
        )

    all_feat = []
    all_label = []
    all_camera = []
    has_camera = False
    for idx, batch in enumerate(dataloader):  # data loading is time-consuming
        if idx % engine.config["Global"]["print_batch_step"] == 0:
            logger.info(
                f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
            )

        batch = [paddle.to_tensor(x) for x in batch]
        batch[1] = batch[1].reshape([-1, 1]).astype("int64")
        if len(batch) >= 3:
            has_camera = True
            batch[2] = batch[2].reshape([-1, 1]).astype("int64")
        if engine.amp and engine.amp_eval:
            with paddle.amp.auto_cast(
                    custom_black_list={
                        "flatten_contiguous_range", "greater_than"
                    },
                    level=engine.amp_level):
                out = engine.model(batch[0])
        else:
            out = engine.model(batch[0])
        if "Student" in out:
            out = out["Student"]

        # get features
        if engine.config["Global"].get("retrieval_feature_from",
                                       "features") == "features":
            # use output from neck as feature
            batch_feat = out["features"]
        else:
            # use output from backbone as feature
            batch_feat = out["backbone"]

        # L2-normalize features (optional)
        if engine.config["Global"].get("feature_normalize", True):
            batch_feat = paddle.nn.functional.normalize(batch_feat, p=2)

        # binarize features (optional)
        if engine.config["Global"].get("feature_binarize") == "round":
            batch_feat = paddle.round(batch_feat).astype("float32") * 2.0 - 1.0
        elif engine.config["Global"].get("feature_binarize") == "sign":
            batch_feat = paddle.sign(batch_feat).astype("float32")

        if paddle.distributed.get_world_size() > 1:
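            # gather features, labels (and camera ids) from all ranks so that
            # every process ends up with the complete set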
            all_feat.append(all_gather(batch_feat))
            all_label.append(all_gather(batch[1]))
            if has_camera:
                all_camera.append(all_gather(batch[2]))
        else:
            all_feat.append(batch_feat)
            all_label.append(batch[1])
            if has_camera:
                all_camera.append(batch[2])

    if engine.use_dali:
        dataloader.reset()

    all_feat = paddle.concat(all_feat)
    all_label = paddle.concat(all_label)
    if has_camera:
        all_camera = paddle.concat(all_camera)
    else:
        all_camera = None
    # discard redundant padding sample(s) at the end
    total_samples = dataloader.size if engine.use_dali else len(
        dataloader.dataset)
    all_feat = all_feat[:total_samples]
    all_label = all_label[:total_samples]
    if has_camera:
        all_camera = all_camera[:total_samples]

    logger.info(f"Build {name} done, all feat shape: {all_feat.shape}")
    return all_feat, all_label, all_camera


def k_reciprocal_neighbor(rank: np.ndarray, p: int, k: int) -> np.ndarray:
    """Implementation of k-reciprocal nearest neighbors, i.e. R(p, k)

    Args:
        rank (np.ndarray): Rank mat with shape of [N, N].
        p (int): Probe index.
        k (int): Parameter k for k-reciprocal nearest neighbors algorithm.

    Returns:
        np.ndarray: K-reciprocal nearest neighbors of probe p with shape of [M, ].
    """
    # take the k + 1 nearest because the first neighbor of p is p itself
    forward_k_neigh_index = rank[p, :k + 1]
    backward_k_neigh_index = rank[forward_k_neigh_index, :k + 1]
    candidate = np.where(backward_k_neigh_index == p)[0]
    return forward_k_neigh_index[candidate]


def compute_re_ranking_dist(query_feat: paddle.Tensor,
                            gallery_feat: paddle.Tensor,
                            feature_normed: bool=True,
                            k1: int=20,
                            k2: int=6,
                            lamb: float=0.5) -> paddle.Tensor:
    """
    Re-ranking Person Re-identification with k-reciprocal Encoding
    Reference: https://arxiv.org/abs/1701.08398
    Code reference: https://github.com/michuanhaohao/reid-strong-baseline/blob/master/utils/re_ranking.py

    Args:
        query_feat (paddle.Tensor): Query features with shape of [num_query, feature_dim].
        gallery_feat (paddle.Tensor):  Gallery features with shape of [num_gallery, feature_dim].
        feature_normed (bool, optional):  Whether input features are normalized.
        k1 (int, optional): Parameter for K-reciprocal nearest neighbors. Defaults to 20.
        k2 (int, optional): Parameter for K-nearest neighbors. Defaults to 6.
        lamb (float, optional): Penalty factor. Defaults to 0.5.

    Returns:
        paddle.Tensor: (1 - lamb) x Dj + lamb x D, with shape of [num_query, num_gallery].
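
    Examples:
        # a minimal sketch with random, L2-normalized features; shapes and
        # parameter values are only illustrative
        query_feat = paddle.nn.functional.normalize(paddle.rand([4, 128]), axis=1)
        gallery_feat = paddle.nn.functional.normalize(paddle.rand([6, 128]), axis=1)
        dist = compute_re_ranking_dist(
            query_feat, gallery_feat, feature_normed=True, k1=3, k2=2, lamb=0.3)
        # dist is a paddle.Tensor of shape [4, 6]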
    """
    num_query = query_feat.shape[0]
    num_gallery = gallery_feat.shape[0]
    num_all = num_query + num_gallery
    feat = paddle.concat([query_feat, gallery_feat], 0)
    logger.info("Computing original distance matrix")
    # use L2 distance
    if feature_normed:
        original_dist = 2 - 2 * paddle.matmul(feat, feat, transpose_y=True)
    else:
        original_dist = paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]) + \
            paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]).t()
        original_dist = paddle.addmm(
            original_dist, feat, feat.t(), beta=1.0, alpha=-2.0)
    original_dist = original_dist.numpy()
    del feat

    original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float16)
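    # partial sort each row so that its first k1 + 1 columns are the indices of
    # the k1 + 1 nearest samples (the sample itself included), in ascending order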
    initial_rank = np.argpartition(original_dist, range(1, k1 + 1))
    logger.info("Start re-ranking...")

    for p in range(num_all):
        # compute R(p,k1)
        p_k_reciprocal_ind = k_reciprocal_neighbor(initial_rank, p, k1)

        # compute R*(p,k1)=R(p,k1)∪R(q,k1/2)
        # s.t. |R(p,k1)∩R(q,k1/2)|>=2/3|R(q,k1/2)|, ∀q∈R(p,k1)
        p_k_reciprocal_exp_ind = p_k_reciprocal_ind
        for _, q in enumerate(p_k_reciprocal_ind):
            q_k_reciprocal_ind = k_reciprocal_neighbor(initial_rank, q,
                                                       int(np.around(k1 / 2)))
            if len(
                    np.intersect1d(
                        p_k_reciprocal_ind,
                        q_k_reciprocal_ind,
                        assume_unique=True)) > 2 / 3 * len(q_k_reciprocal_ind):
                p_k_reciprocal_exp_ind = np.append(p_k_reciprocal_exp_ind,
                                                   q_k_reciprocal_ind)
        p_k_reciprocal_exp_ind = np.unique(p_k_reciprocal_exp_ind)
        # reweight distance using gaussian kernel
        weight = np.exp(-original_dist[p, p_k_reciprocal_exp_ind])
        V[p, p_k_reciprocal_exp_ind] = weight / np.sum(weight)

    # local query expansion
    original_dist = original_dist[:num_query, ]
    if k2 > 1:
        try:
            # use sparse tensor to speed up query expansion
            indices = (np.repeat(np.arange(num_all), k2),
                       initial_rank[:, :k2].reshape([-1, ]))
            values = np.array(
                [1 / k2 for _ in range(num_all * k2)], dtype="float16")
            V = scipy.sparse.coo_matrix(
                (values, indices), V.shape,
                dtype="float16") @ V.astype("float16")
        except Exception as e:
            logger.info(
                f"Failed to do local query expansion with sparse tensor for reason:\n{e}\n"
                f"falling back to the for-loop implementation")
            # use vanilla for-loop
            V_qe = np.zeros_like(V, dtype=np.float16)
            for i in range(num_all):
                V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
            V = V_qe
            del V_qe
    del initial_rank

    # cache, for each sample gj, the indices of samples whose k-reciprocal set contains gj
    invIndex = []
    for gj in range(num_all):
        invIndex.append(np.nonzero(V[:, gj])[0])

    # compute jaccard distance
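    # with each row of V summing to 1, the Jaccard distance reduces to
    # 1 - sum(min(V[p], V[gi])) / (2 - sum(min(V[p], V[gi])))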
    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
    for p in range(num_query):
        sum_min = np.zeros(shape=[1, num_all], dtype=np.float16)
        gj_ind = np.nonzero(V[p, :])[0]
        gj_ind_inv = [invIndex[gj] for gj in gj_ind]
        for j, gj in enumerate(gj_ind):
            gi = gj_ind_inv[j]
            sum_min[0, gi] += np.minimum(V[p, gj], V[gi, gj])
        jaccard_dist[p] = 1 - sum_min / (2 - sum_min)

    # fuse jaccard distance with original distance
    final_dist = (1 - lamb) * jaccard_dist + lamb * original_dist
    del original_dist
    del V
    del jaccard_dist
    final_dist = final_dist[:num_query, num_query:]
    final_dist = paddle.to_tensor(final_dist)
    return final_dist