retrieval.py 13.5 KB
Newer Older
D
dongshuilong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
D
dongshuilong 已提交
17

D
dongshuilong 已提交
18
import platform
19
from collections import defaultdict
H
HydrogenSulfate 已提交
20 21

import numpy as np
D
dongshuilong 已提交
22
import paddle
23
import scipy
24 25

from ppcls.utils import all_gather, logger
D
dongshuilong 已提交
26 27


W
weishengyu 已提交
28 29
def retrieval_eval(engine, epoch_id=0):
    """Evaluate retrieval quality of ``engine.model`` and return the main metric.

    Query and gallery features are extracted with ``compute_feature``, a
    query-gallery score matrix is built (inner product per query block, or
    k-reciprocal re-ranking distance when ``Global.re_ranking`` is set) and
    ``engine.eval_metric_func`` is applied to it.

    Args:
        engine: Evaluation engine. This function reads ``model``, ``config``,
            ``gallery_query_dataloader`` and ``eval_metric_func`` from it;
            ``compute_feature`` reads further attributes.
        epoch_id (int, optional): Epoch index, only used in the log line.
            Defaults to 0.

    Returns:
        The value of the first metric returned by ``engine.eval_metric_func``,
        or 0.0 when no metric function is configured.
    """
    engine.model.eval()

    # step1. prepare query and gallery features
    if engine.gallery_query_dataloader is not None:
        # a single dataset acts as both gallery and query
        gallery_feat, gallery_label, gallery_camera = compute_feature(
            engine, "gallery_query")
        query_feat, query_label, query_camera = (gallery_feat, gallery_label,
                                                 gallery_camera)
    else:
        gallery_feat, gallery_label, gallery_camera = compute_feature(
            engine, "gallery")
        query_feat, query_label, query_camera = compute_feature(engine,
                                                                "query")

    global_cfg = engine.config["Global"]
    num_query = len(query_feat)
    metric_key = None

    # step2. compute metric
    # NOTE: the guard must test the callable that is invoked below; testing
    # the loss function here would either call None or skip the metrics.
    if engine.eval_metric_func is None:
        metric_dict = {metric_key: 0.0}
    else:
        use_reranking = global_cfg.get("re_ranking", False)
        logger.info(f"re_ranking={use_reranking}")
        if use_reranking:
            # compute distance matrix over all queries at once
            distmat = compute_re_ranking_dist(
                query_feat,
                gallery_feat,
                feature_normed=global_cfg.get("feature_normalize", True),
                k1=20,
                k2=6,
                lamb=0.3)
            # exclude illegal distance: pairs sharing both label and camera
            # are pushed beyond the largest distance
            if query_camera is not None:
                camera_mask = query_camera != gallery_camera.t()
                label_mask = query_label != gallery_label.t()
                keep_mask = label_mask | camera_mask
                distmat = keep_mask.astype(query_feat.dtype) * distmat + (
                    ~keep_mask).astype(query_feat.dtype) * (distmat.max() + 1)
            else:
                keep_mask = None
            # the metric function receives the negated distance as its score
            metric_dict = engine.eval_metric_func(-distmat, query_label,
                                                  gallery_label, keep_mask)
        else:
            # split queries into blocks so that only a
            # [block_size, num_gallery] score matrix is alive at a time
            block_size = global_cfg.get("sim_block_size", 64)
            sections = [block_size] * (num_query // block_size)
            if num_query % block_size > 0:
                sections.append(num_query % block_size)

            query_feat_blocks = paddle.split(query_feat, sections)
            query_label_blocks = paddle.split(query_label, sections)
            query_camera_blocks = paddle.split(
                query_camera, sections) if query_camera is not None else None

            metric_dict = defaultdict(float)
            for block_idx, block_feat in enumerate(query_feat_blocks):
                # inner-product score between this query block and the gallery
                distmat = paddle.matmul(
                    block_feat, gallery_feat, transpose_y=True)
                # exclude illegal pairs by zeroing their score
                if query_camera is not None:
                    camera_mask = query_camera_blocks[
                        block_idx] != gallery_camera.t()
                    label_mask = query_label_blocks[
                        block_idx] != gallery_label.t()
                    keep_mask = label_mask | camera_mask
                    distmat = keep_mask.astype(query_feat.dtype) * distmat
                else:
                    keep_mask = None
                # compute metric by block
                metric_block = engine.eval_metric_func(
                    distmat, query_label_blocks[block_idx], gallery_label,
                    keep_mask)
                # accumulate metric, weighted by the block's share of queries
                for key in metric_block:
                    metric_dict[key] += metric_block[key] * block_feat.shape[
                        0] / num_query

    metric_info_list = []
    for key, value in metric_dict.items():
        metric_info_list.append(f"{key}: {value:.5f}")
        # the first reported metric is the one returned to the caller
        if metric_key is None:
            metric_key = key
    metric_msg = ", ".join(metric_info_list)
    logger.info(f"[Eval][Epoch {epoch_id}][Avg]{metric_msg}")

    return metric_dict[metric_key]


113 114
def compute_feature(engine, name="gallery"):
    """Run the model over one evaluation dataloader and collect its outputs.

    Args:
        engine: Evaluation engine providing the dataloaders, ``model``,
            ``config``, AMP settings and ``use_dali``.
        name (str, optional): Which dataloader to use, one of ``"gallery"``,
            ``"query"`` or ``"gallery_query"``. Defaults to ``"gallery"``.

    Returns:
        tuple: ``(all_feat, all_label, all_camera)``. ``all_camera`` is None
        when the batches carry fewer than three elements.

    Raises:
        ValueError: If ``name`` is not one of the supported names.
    """
    if name not in ("gallery", "query", "gallery_query"):
        raise ValueError(
            f"Only support gallery or query or gallery_query dataset, but got {name}"
        )
    # the engine exposes one "<name>_dataloader" attribute per supported name
    dataloader = getattr(engine, f"{name}_dataloader")

    feat_chunks, label_chunks, camera_chunks = [], [], []
    has_camera = False
    global_cfg = engine.config["Global"]

    for idx, batch in enumerate(dataloader):  # load is very time-consuming
        if idx % global_cfg["print_batch_step"] == 0:
            logger.info(
                f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
            )

        tensors = [paddle.to_tensor(item) for item in batch]
        # labels (and camera ids, if present) become int64 column vectors
        tensors[1] = tensors[1].reshape([-1, 1]).astype("int64")
        if len(tensors) >= 3:
            has_camera = True
            tensors[2] = tensors[2].reshape([-1, 1]).astype("int64")

        # forward pass, under auto_cast only when AMP evaluation is enabled
        if engine.amp and engine.amp_eval:
            with paddle.amp.auto_cast(
                    custom_black_list={
                        "flatten_contiguous_range", "greater_than"
                    },
                    level=engine.amp_level):
                out = engine.model(tensors[0])
        else:
            out = engine.model(tensors[0])
        if "Student" in out:
            out = out["Student"]

        # get features: neck output ("features") unless configured otherwise,
        # in which case the backbone output is used
        feature_source = global_cfg.get("retrieval_feature_from", "features")
        feat_key = "features" if feature_source == "features" else "backbone"
        batch_feat = out[feat_key]

        # do norm(optional)
        if global_cfg.get("feature_normalize", True):
            batch_feat = paddle.nn.functional.normalize(batch_feat, p=2)

        # do binarize(optional)
        binarize_mode = global_cfg.get("feature_binarize")
        if binarize_mode == "round":
            batch_feat = paddle.round(batch_feat).astype("float32") * 2.0 - 1.0
        elif binarize_mode == "sign":
            batch_feat = paddle.sign(batch_feat).astype("float32")

        # gather from all ranks when distributed; order is feat, label, camera
        collected = [batch_feat, tensors[1]]
        if has_camera:
            collected.append(tensors[2])
        if paddle.distributed.get_world_size() > 1:
            collected = [all_gather(item) for item in collected]

        feat_chunks.append(collected[0])
        label_chunks.append(collected[1])
        if has_camera:
            camera_chunks.append(collected[2])

    if engine.use_dali:
        dataloader.reset()

    all_feat = paddle.concat(feat_chunks)
    all_label = paddle.concat(label_chunks)
    all_camera = paddle.concat(camera_chunks) if has_camera else None

    # discard redundant padding sample(s) at the end
    if engine.use_dali:
        total_samples = dataloader.size
    else:
        total_samples = len(dataloader.dataset)
    all_feat = all_feat[:total_samples]
    all_label = all_label[:total_samples]
    if has_camera:
        all_camera = all_camera[:total_samples]

    logger.info(f"Build {name} done, all feat shape: {all_feat.shape}")
    return all_feat, all_label, all_camera
H
HydrogenSulfate 已提交
201 202


203 204
def k_reciprocal_neighbor(rank: np.ndarray, p: int, k: int) -> np.ndarray:
    """Return the k-reciprocal nearest neighbors of probe ``p``, i.e. R(p, k).

    A candidate ``q`` is kept when it is among the first ``k + 1`` entries of
    ``p``'s ranking and ``p`` is among the first ``k + 1`` entries of ``q``'s
    ranking.

    Args:
        rank (np.ndarray): Rank mat with shape of [N, N].
        p (int): Probe index.
        k (int): Parameter k for k-reciprocal nearest neighbors algorithm.

    Returns:
        np.ndarray: K-reciprocal nearest neighbors of probe p with shape of [M, ].
    """
    # one extra column because each ranking is expected to start with the
    # sample itself
    width = k + 1
    neighbors_of_p = rank[p, :width]
    neighbors_of_neighbors = rank[neighbors_of_p, :width]
    # row indices of the neighbors whose own top list contains p
    reciprocal_rows = np.nonzero(neighbors_of_neighbors == p)[0]
    return neighbors_of_p[reciprocal_rows]


221 222
def compute_re_ranking_dist(query_feat: paddle.Tensor,
                            gallery_feat: paddle.Tensor,
                            feature_normed: bool=True,
                            k1: int=20,
                            k2: int=6,
                            lamb: float=0.5) -> paddle.Tensor:
    """
    Re-ranking Person Re-identification with k-reciprocal Encoding
    Reference: https://arxiv.org/abs/1701.08398
    Code reference: https://github.com/michuanhaohao/reid-strong-baseline/blob/master/utils/re_ranking.py

    Args:
        query_feat (paddle.Tensor): Query features with shape of [num_query, feature_dim].
        gallery_feat (paddle.Tensor):  Gallery features with shape of [num_gallery, feature_dim].
        feature_normed (bool, optional):  Whether input features are normalized. Defaults to True.
        k1 (int, optional): Parameter for K-reciprocal nearest neighbors. Defaults to 20.
        k2 (int, optional): Parameter for K-nearest neighbors. Defaults to 6.
        lamb (float, optional): Penalty factor. Defaults to 0.5.

    Returns:
        paddle.Tensor: (1 - lamb) x Dj + lamb x D, with shape of [num_query, num_gallery].
    """
    num_query = query_feat.shape[0]
    num_gallery = gallery_feat.shape[0]
    num_all = num_query + num_gallery
    feat = paddle.concat([query_feat, gallery_feat], 0)
    logger.info("Using GPU to compute original distance matrix")
    # use (squared) L2 distance: |a|^2 + |b|^2 - 2 * a.b
    inner_prod = paddle.matmul(feat, feat, transpose_y=True)
    if feature_normed:
        # unit-norm features: |a|^2 + |b|^2 == 2
        original_dist = 2 - 2 * inner_prod
    else:
        # Written out explicitly instead of Tensor.addmm: passing
        # (-2.0, 1.0) positionally to addmm scales the wrong operand.
        sq_norm = paddle.pow(feat, 2).sum(axis=1, keepdim=True)
        original_dist = sq_norm + sq_norm.t() - 2 * inner_prod
    original_dist = original_dist.numpy()
    del feat
    del inner_prod

    # scale each column by its maximum, then transpose
    original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float16)
    if num_all > k1:
        # only positions 1..k1 of every row need to be in sorted order
        initial_rank = np.argpartition(original_dist, range(1, k1 + 1))
    else:
        # fewer samples than k1: argpartition would reject the out-of-range
        # kth values, so fall back to a full sort of each row
        initial_rank = np.argsort(original_dist, axis=1)
    logger.info("Start re-ranking...")

    for p in range(num_all):
        # compute R(p,k1)
        p_k_reciprocal_ind = k_reciprocal_neighbor(initial_rank, p, k1)

        # compute R*(p,k1)=R(p,k1)∪R(q,k1/2)
        # s.t. |R(p,k1)∩R(q,k1/2)|>=2/3|R(q,k1/2)|, ∀q∈R(p,k1)
        p_k_reciprocal_exp_ind = p_k_reciprocal_ind
        for q in p_k_reciprocal_ind:
            q_k_reciprocal_ind = k_reciprocal_neighbor(initial_rank, q,
                                                       int(np.around(k1 / 2)))
            if len(
                    np.intersect1d(
                        p_k_reciprocal_ind,
                        q_k_reciprocal_ind,
                        assume_unique=True)) > 2 / 3 * len(q_k_reciprocal_ind):
                p_k_reciprocal_exp_ind = np.append(p_k_reciprocal_exp_ind,
                                                   q_k_reciprocal_ind)
        p_k_reciprocal_exp_ind = np.unique(p_k_reciprocal_exp_ind)
        # reweight distance using gaussian kernel
        weight = np.exp(-original_dist[p, p_k_reciprocal_exp_ind])
        V[p, p_k_reciprocal_exp_ind] = weight / np.sum(weight)

    # local query expansion
    original_dist = original_dist[:num_query, ]
    if k2 > 1:
        try:
            # use sparse tensor to speed up query expansion: row i of the
            # sparse matrix averages the V rows of i's first k2 ranked samples
            indices = (np.repeat(np.arange(num_all), k2),
                       initial_rank[:, :k2].reshape([-1, ]))
            values = np.full(num_all * k2, 1 / k2, dtype="float16")
            V = scipy.sparse.coo_matrix(
                (values, indices), V.shape,
                dtype="float16") @V.astype("float16")
        except Exception as e:
            # deliberate best-effort: any failure of the sparse path falls
            # back to the dense loop below
            logger.info(
                f"Failed to do local query expansion with sparse tensor for reason: \n{e}\n"
                f"now use for-loop instead")
            # use vanilla for-loop
            V_qe = np.zeros_like(V, dtype=np.float16)
            for i in range(num_all):
                V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
            V = V_qe
            del V_qe
    del initial_rank

    # cache k-reciprocal sets which contains gj
    invIndex = [np.nonzero(V[:, gj])[0] for gj in range(num_all)]

    # compute jaccard distance
    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
    for p in range(num_query):
        sum_min = np.zeros(shape=[1, num_all], dtype=np.float16)
        gj_ind = np.nonzero(V[p, :])[0]
        for gj in gj_ind:
            # samples whose encoding is also non-zero at gj
            gi = invIndex[gj]
            sum_min[0, gi] += np.minimum(V[p, gj], V[gi, gj])
        jaccard_dist[p] = 1 - sum_min / (2 - sum_min)

    # fuse jaccard distance with original distance
    final_dist = (1 - lamb) * jaccard_dist + lamb * original_dist
    del original_dist
    del V
    del jaccard_dist
    # keep only the query-vs-gallery part
    final_dist = final_dist[:num_query, num_query:]
    final_dist = paddle.to_tensor(final_dist)
    return final_dist