diff --git a/ppcls/arch/gears/fc.py b/ppcls/arch/gears/fc.py index 9d7308854a0424970981dd7223eed47bfc7c0b5c..0262c100bb80a7f12a9fe4085a444539e3a5cd27 100644 --- a/ppcls/arch/gears/fc.py +++ b/ppcls/arch/gears/fc.py @@ -30,6 +30,6 @@ class FC(nn.Layer): self.fc = paddle.nn.Linear( self.embedding_size, self.class_num, weight_attr=weight_attr) - def forward(self, input, label): + def forward(self, input): out = self.fc(input) return out diff --git a/ppcls/engine/trainer.py b/ppcls/engine/trainer.py index abbc8579aeb462ff999eb581635b65e4d53e85ff..22a401e99572188418864a250f2d94f39a87dd41 100644 --- a/ppcls/engine/trainer.py +++ b/ppcls/engine/trainer.py @@ -31,7 +31,7 @@ from ppcls.utils import logger from ppcls.data import build_dataloader from ppcls.arch import build_model from ppcls.loss import build_loss -from ppcls.arch.loss_metrics import build_metrics +from ppcls.metric import build_metrics from ppcls.optimizer import build_optimizer from ppcls.utils.save_load import load_dygraph_pretrain from ppcls.utils.save_load import init_model @@ -81,43 +81,35 @@ class Trainer(object): self.vdl_writer = LogWriter(logdir=vdl_writer_path) logger.info('train with paddle {} and device {}'.format( paddle.__version__, self.device)) - - def _build_metric_info(self, metric_config, mode="train"): - """ - _build_metric_info: build metrics according to current mode - Return: - metric: dict of the metrics info - """ - metric = None - mode = mode.capitalize() - if mode in metric_config and metric_config[mode] is not None: - metric = build_metrics(metric_config[mode]) - return metric - - def _build_loss_info(self, loss_config, mode="train"): - """ - _build_loss_info: build loss according to current mode - Return: - loss_dict: dict of the loss info - """ - loss = None - mode = mode.capitalize() - if mode in loss_config and loss_config[mode] is not None: - loss = build_loss(loss_config[mode]) - return loss + # init members + self.train_dataloader = None + self.eval_dataloader = None + 
self.gallery_dataloader = None + self.query_dataloader = None + self.eval_mode = self.config["Global"].get("eval_mode", + "classification") + self.train_loss_func = None + self.eval_loss_func = None + self.train_metric_func = None + self.eval_metric_func = None def train(self): # build train loss and metric info - loss_func = self._build_loss_info(self.config["Loss"]) - if "Metric" in self.config: - metric_func = self._build_metric_info(self.config["Metric"]) - else: - metric_func = None - - train_dataloader = build_dataloader(self.config["DataLoader"], "Train", - self.device) - - step_each_epoch = len(train_dataloader) + if self.train_loss_func is None: + loss_info = self.config["Loss"]["Train"] + self.train_loss_func = build_loss(loss_info) + if self.train_metric_func is None: + metric_config = self.config.get("Metric") + if metric_config is not None: + metric_config = metric_config.get("Train") + if metric_config is not None: + self.train_metric_func = build_metrics(metric_config) + + if self.train_dataloader is None: + self.train_dataloader = build_dataloader(self.config["DataLoader"], + "Train", self.device) + + step_each_epoch = len(self.train_dataloader) optimizer, lr_sch = build_optimizer(self.config["Optimizer"], self.config["Global"]["epochs"], @@ -146,8 +138,7 @@ class Trainer(object): for epoch_id in range(best_metric["epoch"] + 1, self.config["Global"]["epochs"] + 1): acc = 0.0 - self.model.train() - for iter_id, batch in enumerate(train_dataloader()): + for iter_id, batch in enumerate(self.train_dataloader()): batch_size = batch[0].shape[0] batch[1] = paddle.to_tensor(batch[1].numpy().astype("int64") .reshape([-1, 1])) @@ -158,15 +149,15 @@ class Trainer(object): else: out = self.model(batch[0], batch[1]) # calc loss - loss_dict = loss_func(out, batch[1]) + loss_dict = self.train_loss_func(out, batch[1]) for key in loss_dict: if not key in output_info: output_info[key] = AverageMeter(key, '7.5f') output_info[key].update(loss_dict[key].numpy()[0], 
batch_size) # calc metric - if metric_func is not None: - metric_dict = metric_func(out, batch[-1]) + if self.train_metric_func is not None: + metric_dict = self.train_metric_func(out, batch[-1]) for key in metric_dict: if not key in output_info: output_info[key] = AverageMeter(key, '7.5f') @@ -181,7 +172,7 @@ class Trainer(object): ]) logger.info("[Train][Epoch {}][Iter: {}/{}]{}, {}".format( epoch_id, iter_id, - len(train_dataloader), lr_msg, metric_msg)) + len(self.train_dataloader), lr_msg, metric_msg)) # step opt and lr loss_dict["loss"].backward() @@ -212,6 +203,7 @@ class Trainer(object): self.output_dir, model_name=self.config["Arch"]["name"], prefix="best_model") + self.model.train() # save model if epoch_id % save_interval == 0: @@ -228,20 +220,56 @@ class Trainer(object): @paddle.no_grad() def eval(self, epoch_id=0): - output_info = dict() - - eval_dataloader = build_dataloader(self.config["DataLoader"], "Eval", - self.device) - self.model.eval() + if self.eval_loss_func is None: + loss_config = self.config.get("Loss", None) + if loss_config is not None: + loss_config = loss_config.get("Eval") + if loss_config is not None: + self.eval_loss_func = build_loss(loss_config) + if self.eval_mode == "classification": + if self.eval_dataloader is None: + self.eval_dataloader = build_dataloader( + self.config["DataLoader"], "Eval", self.device) + + if self.eval_metric_func is None: + metric_config = self.config.get("Metric") + if metric_config is not None: + metric_config = metric_config.get("Eval") + if metric_config is not None: + self.eval_metric_func = build_metrics(metric_config) + + eval_result = self.eval_cls(epoch_id) + + elif self.eval_mode == "retrieval": + if self.gallery_dataloader is None: + self.gallery_dataloader = build_dataloader( + self.config["DataLoader"]["Eval"], "Gallery", self.device) + + if self.query_dataloader is None: + self.query_dataloader = build_dataloader( + self.config["DataLoader"]["Eval"], "Query", self.device) + # build metric 
info + if self.eval_metric_func is None: + metric_config = self.config.get("Metric", None) + if metric_config is None: + metric_config = [{"name": "Recallk", "topk": (1, 5)}] + else: + metric_config = metric_config["Eval"] + self.eval_metric_func = build_metrics(metric_config) + eval_result = self.eval_retrieval(epoch_id) + else: + logger.warning("Invalid eval mode: {}".format(self.eval_mode)) + eval_result = None + self.model.train() + return eval_result + + def eval_cls(self, epoch_id=0): + output_info = dict() print_batch_step = self.config["Global"]["print_batch_step"] - # build train loss and metric info - loss_func = self._build_loss_info(self.config["Loss"], "eval") - metric_func = self._build_metric_info(self.config["Metric"], "eval") metric_key = None - - for iter_id, batch in enumerate(eval_dataloader()): + for iter_id, batch in enumerate(self.eval_dataloader()): batch_size = batch[0].shape[0] batch[0] = paddle.to_tensor(batch[0]).astype("float32") batch[1] = paddle.to_tensor(batch[1]).reshape([-1, 1]) @@ -250,32 +278,32 @@ class Trainer(object): out = self.model(batch[0], batch[1]) else: out = self.model(batch[0]) - # calc build - if loss_func is not None: - loss_dict = loss_func(out, batch[-1]) + # calc loss + if self.eval_loss_func is not None: + loss_dict = self.eval_loss_func(out, batch[-1]) for key in loss_dict: if not key in output_info: output_info[key] = AverageMeter(key, '7.5f') output_info[key].update(loss_dict[key].numpy()[0], batch_size) - # calc metric - if metric_func is not None: - metric_dict = metric_func(out, batch[-1]) - if paddle.distributed.get_world_size() > 1: - for key in metric_dict: - paddle.distributed.all_reduce( - metric_dict[key], - op=paddle.distributed.ReduceOp.SUM) - metric_dict[key] = metric_dict[ - key] / paddle.distributed.get_world_size() + # calc metric + if self.eval_metric_func is not None: + metric_dict = self.eval_metric_func(out, batch[-1]) + if paddle.distributed.get_world_size() > 1: for key in metric_dict: - 
if metric_key is None: - metric_key = key - if not key in output_info: - output_info[key] = AverageMeter(key, '7.5f') + paddle.distributed.all_reduce( + metric_dict[key], + op=paddle.distributed.ReduceOp.SUM) + metric_dict[key] = metric_dict[ + key] / paddle.distributed.get_world_size() + for key in metric_dict: + if metric_key is None: + metric_key = key + if not key in output_info: + output_info[key] = AverageMeter(key, '7.5f') - output_info[key].update(metric_dict[key].numpy()[0], - batch_size) + output_info[key].update(metric_dict[key].numpy()[0], + batch_size) if iter_id % print_batch_step == 0: metric_msg = ", ".join([ @@ -283,7 +311,7 @@ class Trainer(object): for key in output_info ]) logger.info("[Eval][Epoch {}][Iter: {}/{}]{}".format( - epoch_id, iter_id, len(eval_dataloader), metric_msg)) + epoch_id, iter_id, len(self.eval_dataloader), metric_msg)) metric_msg = ", ".join([ "{}: {:.5f}".format(key, output_info[key].avg) @@ -291,13 +319,128 @@ class Trainer(object): ]) logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg)) - self.model.train() # do not try to save best model - if metric_func is None: + if self.eval_metric_func is None: return -1 # return 1st metric in the dict return output_info[metric_key].avg + def eval_retrieval(self, epoch_id=0): + self.model.eval() + cum_similarity_matrix = None + # step1. build gallery + gallery_feas, gallery_img_id, gallery_camera_id = self._cal_feature( + name='gallery') + query_feas, query_img_id, query_camera_id = self._cal_feature( + name='query') + gallery_img_id = gallery_img_id + # if gallery_camera_id is not None: + # gallery_camera_id = gallery_camera_id + # step2. 
do evaluation + sim_block_size = self.config["Global"].get("sim_block_size", 64) + sections = [sim_block_size] * (len(query_feas) // sim_block_size) + if len(query_feas) % sim_block_size: + sections.append(len(query_feas) % sim_block_size) + fea_blocks = paddle.split(query_feas, num_or_sections=sections) + if query_camera_id is not None: + camera_id_blocks = paddle.split( + query_camera_id, num_or_sections=sections) + image_id_blocks = paddle.split( + query_img_id, num_or_sections=sections) + metric_key = None + + for block_idx, block_fea in enumerate(fea_blocks): + similarity_matrix = paddle.matmul( + block_fea, gallery_feas, transpose_y=True) + if query_camera_id is not None: + camera_id_block = camera_id_blocks[block_idx] + camera_id_mask = (camera_id_block != gallery_camera_id.t()) + + image_id_block = image_id_blocks[block_idx] + image_id_mask = (image_id_block != gallery_img_id.t()) + + keep_mask = paddle.logical_or(camera_id_mask, image_id_mask) + similarity_matrix = similarity_matrix * keep_mask.astype( + "float32") + if cum_similarity_matrix is None: + cum_similarity_matrix = similarity_matrix + else: + cum_similarity_matrix = paddle.concat( + [cum_similarity_matrix, similarity_matrix], axis=0) + + # calc metric + if self.eval_metric_func is not None: + metric_dict = self.eval_metric_func(cum_similarity_matrix, + query_img_id, gallery_img_id) + else: + metric_dict = {metric_key: 0.} + metric_info_list = [] + + for key in metric_dict: + if metric_key is None: + metric_key = key + metric_info_list.append("{}: {:.5f}".format(key, metric_dict[key])) + metric_msg = ", ".join(metric_info_list) + logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg)) + + return metric_dict[metric_key] + + def _cal_feature(self, name='gallery'): + all_feas = None + all_image_id = None + all_camera_id = None + if name == 'gallery': + dataloader = self.gallery_dataloader + elif name == 'query': + dataloader = self.query_dataloader + else: + raise RuntimeError("Only 
support gallery or query dataset") + + has_cam_id = False + for idx, batch in enumerate(dataloader( + )): # load is very time-consuming + batch = [paddle.to_tensor(x) for x in batch] + batch[1] = batch[1].reshape([-1, 1]) + if len(batch) == 3: + has_cam_id = True + batch[2] = batch[2].reshape([-1, 1]) + out = self.model(batch[0], batch[1]) + batch_feas = out["features"] + + # do norm + if self.config["Global"].get("feature_normalize", True): + feas_norm = paddle.sqrt( + paddle.sum(paddle.square(batch_feas), axis=1, + keepdim=True)) + batch_feas = paddle.divide(batch_feas, feas_norm) + + if all_feas is None: + all_feas = batch_feas + if has_cam_id: + all_camera_id = batch[2] + all_image_id = batch[1] + else: + all_feas = paddle.concat([all_feas, batch_feas]) + all_image_id = paddle.concat([all_image_id, batch[1]]) + if has_cam_id: + all_camera_id = paddle.concat([all_camera_id, batch[2]]) + + if paddle.distributed.get_world_size() > 1: + feat_list = [] + img_id_list = [] + cam_id_list = [] + paddle.distributed.all_gather(feat_list, all_feas) + paddle.distributed.all_gather(img_id_list, all_image_id) + all_feas = paddle.concat(feat_list, axis=0) + all_image_id = paddle.concat(img_id_list, axis=0) + if has_cam_id: + paddle.distributed.all_gather(cam_id_list, all_camera_id) + all_camera_id = paddle.concat(cam_id_list, axis=0) + + logger.info("Build {} done, all feat shape: {}, begin to eval..". + format(name, all_feas.shape)) + return all_feas, all_image_id, all_camera_id + @paddle.no_grad() def infer(self, ): total_trainer = paddle.distributed.get_world_size() diff --git a/ppcls/engine/trainer_reid.py b/ppcls/engine/trainer_reid.py deleted file mode 100644 index 361de6c6a2a7263f6aaa9848457bafa0f72abafd..0000000000000000000000000000000000000000 --- a/ppcls/engine/trainer_reid.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import os -import sys -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../../'))) - -import numpy as np -import paddle -from .trainer import Trainer -from ppcls.utils import logger -from ppcls.data import build_dataloader - - -class TrainerReID(Trainer): - def __init__(self, config, mode="train"): - super().__init__(config, mode) - - self.gallery_dataloader = build_dataloader(self.config["DataLoader"], - "Gallery", self.device) - - self.query_dataloader = build_dataloader(self.config["DataLoader"], - "Query", self.device) - - @paddle.no_grad() - def eval(self, epoch_id=0): - output_info = dict() - self.model.eval() - print_batch_step = self.config["Global"]["print_batch_step"] - - # step1. build gallery - gallery_feas, gallery_img_id, gallery_camera_id = self._cal_feature( - name='gallery') - query_feas, query_img_id, query_camera_id = self._cal_feature( - name='query') - - # step2. 
do evaluation - if "num_split" in self.config["Global"]: - num_split = self.config["Global"]["num_split"] - else: - num_split = 1 - fea_blocks = paddle.split(query_feas, num_or_sections=1) - - total_similarities_matrix = None - - for block_fea in fea_blocks: - similarities_matrix = paddle.matmul( - block_fea, gallery_feas, transpose_y=True) - if total_similarities_matrix is None: - total_similarities_matrix = similarities_matrix - else: - total_similarities_matrix = paddle.concat( - [total_similarities_matrix, similarities_matrix]) - - # distmat = (1 - total_similarities_matrix).numpy() - q_pids = query_img_id.numpy().reshape((query_img_id.shape[0])) - g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0])) - if query_camera_id is not None and gallery_camera_id is not None: - q_camids = query_camera_id.numpy().reshape( - (query_camera_id.shape[0])) - g_camids = gallery_camera_id.numpy().reshape( - (gallery_camera_id.shape[0])) - max_rank = 50 - - num_q, num_g = total_similarities_matrix.shape - if num_g < max_rank: - max_rank = num_g - print('Note: number of gallery samples is quite small, got {}'. - format(num_g)) - - # indices = np.argsort(distmat, axis=1) - indices = paddle.argsort( - total_similarities_matrix, axis=1, descending=True).numpy() - - matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) - - # compute cmc curve for each query - all_cmc = [] - all_AP = [] - all_INP = [] - num_valid_q = 0. 
# number of valid query - for q_idx in range(num_q): - # get query pid and camid - q_pid = q_pids[q_idx] - q_camid = q_camids[q_idx] - - # remove gallery samples that have the same pid and camid with query - order = indices[q_idx] - if query_camera_id is not None and gallery_camera_id is not None: - remove = (g_pids[order] == q_pid) & ( - g_camids[order] == q_camid) - else: - remove = g_pids[order] == q_pid - keep = np.invert(remove) - - # compute cmc curve - raw_cmc = matches[q_idx][ - keep] # binary vector, positions with value 1 are correct matches - if not np.any(raw_cmc): - # this condition is true when query identity does not appear in gallery - continue - - cmc = raw_cmc.cumsum() - - pos_idx = np.where(raw_cmc == 1) - max_pos_idx = np.max(pos_idx) - inp = cmc[max_pos_idx] / (max_pos_idx + 1.0) - all_INP.append(inp) - - cmc[cmc > 1] = 1 - - all_cmc.append(cmc[:max_rank]) - num_valid_q += 1. - - # compute average precision - # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision - num_rel = raw_cmc.sum() - tmp_cmc = raw_cmc.cumsum() - tmp_cmc = [x / (i + 1.) 
for i, x in enumerate(tmp_cmc)] - tmp_cmc = np.asarray(tmp_cmc) * raw_cmc - AP = tmp_cmc.sum() / num_rel - all_AP.append(AP) - assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' - - all_cmc = np.asarray(all_cmc).astype(np.float32) - all_cmc = all_cmc.sum(0) / num_valid_q - - mAP = np.mean(all_AP) - mINP = np.mean(all_INP) - logger.info( - "[Eval][Epoch {}]: mAP: {:.5f}, mINP: {:.5f},rank_1: {:.5f}, rank_5: {:.5f}" - .format(epoch_id, mAP, mINP, all_cmc[0], all_cmc[4])) - return mAP - - def _cal_feature(self, name='gallery'): - all_feas = None - all_image_id = None - all_camera_id = None - if name == 'gallery': - dataloader = self.gallery_dataloader - elif name == 'query': - dataloader = self.query_dataloader - else: - raise RuntimeError("Only support gallery or query dataset") - - has_cam_id = False - for idx, batch in enumerate(dataloader( - )): # load is very time-consuming - batch = [paddle.to_tensor(x) for x in batch] - batch[1] = batch[1].reshape([-1, 1]) - if len(batch) == 3: - has_cam_id = True - batch[2] = batch[2].reshape([-1, 1]) - out = self.model(batch[0], batch[1]) - batch_feas = out["features"] - - # do norm - if self.config["Global"].get("feature_normalize", True): - feas_norm = paddle.sqrt( - paddle.sum(paddle.square(batch_feas), axis=1, - keepdim=True)) - batch_feas = paddle.divide(batch_feas, feas_norm) - - batch_feas = batch_feas - batch_image_labels = batch[1] - if has_cam_id: - batch_camera_labels = batch[2] - - if all_feas is None: - all_feas = batch_feas - if has_cam_id: - all_camera_id = batch[2] - all_image_id = batch[1] - else: - all_feas = paddle.concat([all_feas, batch_feas]) - all_image_id = paddle.concat([all_image_id, batch[1]]) - if has_cam_id: - all_camera_id = paddle.concat([all_camera_id, batch[2]]) - - if paddle.distributed.get_world_size() > 1: - feat_list = [] - img_id_list = [] - cam_id_list = [] - paddle.distributed.all_gather(feat_list, all_feas) - paddle.distributed.all_gather(img_id_list, 
all_image_id) - all_feas = paddle.concat(feat_list, axis=0) - all_image_id = paddle.concat(img_id_list, axis=0) - if has_cam_id: - paddle.distributed.all_gather(cam_id_list, all_camera_id) - all_camera_id = paddle.concat(cam_id_list, axis=0) - - logger.info("Build {} done, all feat shape: {}, begin to eval..". - format(name, all_feas.shape)) - return all_feas, all_image_id, all_camera_id diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d566a962d1fc2769f3e6ade61c80d32348adc08a --- /dev/null +++ b/ppcls/metric/__init__.py @@ -0,0 +1,45 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
from paddle import nn
import copy
from collections import OrderedDict

from .metrics import TopkAcc, mAP, mINP, Recallk

# Metric classes that a config entry is allowed to name. Resolving the name
# through this table (instead of eval()) means a config string can only ever
# select one of these classes and can never execute arbitrary code.
_METRIC_REGISTRY = OrderedDict(
    (metric_cls.__name__, metric_cls)
    for metric_cls in (TopkAcc, mAP, mINP, Recallk))


def _parse_metric_config(config):
    """Split one metric config entry into ``(metric_name, metric_kwargs)``.

    Two entry shapes are accepted:

    * ``{"TopkAcc": {"topk": [1, 5]}}`` -- a single key naming the metric,
      whose value holds the constructor keyword arguments (``None`` is
      treated as "no arguments").
    * ``{"name": "Recallk", "topk": [1, 5]}`` -- a ``"name"`` key plus the
      keyword arguments inline.

    Raises:
        AssertionError: if ``config`` is not a dict, or has no ``"name"``
            key and does not contain exactly one key.
    """
    assert isinstance(config, dict), "yaml format error"
    if "name" in config:
        # Copy first so the caller's dict is not modified by pop().
        metric_params = dict(config)
        return metric_params.pop("name"), metric_params
    assert len(config) == 1, "yaml format error"
    metric_name = list(config)[0]
    metric_params = config[metric_name]
    # An entry with a key but no value would otherwise fail on `**None`.
    return metric_name, ({} if metric_params is None else metric_params)


class CombinedMetrics(nn.Layer):
    """Evaluate several metrics on the same inputs and merge their results.

    Args:
        config_list (list): one dict per metric; see
            ``_parse_metric_config`` for the accepted entry shapes.

    Raises:
        AssertionError: if ``config_list`` is not a list or an entry is
            malformed.
        ValueError: if an entry names a metric that is not one of
            ``TopkAcc``, ``mAP``, ``mINP`` or ``Recallk``.
    """

    def __init__(self, config_list):
        super().__init__()
        self.metric_func_list = []
        assert isinstance(config_list, list), (
            'operator config should be a list')
        for config in config_list:
            metric_name, metric_params = _parse_metric_config(config)
            metric_cls = _METRIC_REGISTRY.get(metric_name)
            if metric_cls is None:
                raise ValueError(
                    "unsupported metric {!r}, expected one of {}".format(
                        metric_name, list(_METRIC_REGISTRY)))
            self.metric_func_list.append(metric_cls(**metric_params))

    # NOTE(review): this defines __call__ directly rather than forward();
    # presumably that bypasses whatever nn.Layer.__call__ normally does --
    # confirm this is intended.
    def __call__(self, *args, **kwargs):
        """Call every metric with the same arguments and merge the dicts.

        Results are merged in configuration order; if two metrics report
        the same key, the later one overwrites the earlier one.

        Returns:
            OrderedDict: metric name -> value.
        """
        metric_dict = OrderedDict()
        for metric_func in self.metric_func_list:
            metric_dict.update(metric_func(*args, **kwargs))

        return metric_dict


def build_metrics(config):
    """Build a ``CombinedMetrics`` from a deep copy of ``config``.

    The copy keeps the caller's config object untouched by anything the
    metric constructors do with their arguments.
    """
    metrics_list = CombinedMetrics(copy.deepcopy(config))
    return metrics_list
import numpy as np
import paddle
import paddle.nn as nn


# TODO: fix the format
class TopkAcc(nn.Layer):
    """Top-k accuracy for each configured ``k``.

    Args:
        topk (int | list | tuple): the ``k`` value(s) to report.
    """

    def __init__(self, topk=(1, 5)):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        if isinstance(topk, int):
            topk = [topk]
        self.topk = topk

    def forward(self, x, label):
        """Return ``{"top<k>": accuracy}`` for every ``k`` in ``self.topk``.

        ``x`` may be the model output itself or a dict, in which case its
        ``"logits"`` entry is used.
        """
        if isinstance(x, dict):
            x = x["logits"]

        metric_dict = {}
        for k in self.topk:
            metric_dict["top{}".format(k)] = paddle.metric.accuracy(
                x, label, k=k)
        return metric_dict


def _retrieval_stats(similarities_matrix, query_img_id, gallery_img_id,
                     max_rank):
    """Rank the gallery for every query and return ``get_metrics`` output.

    Shared by ``mAP``, ``mINP`` and ``Recallk`` so the id flattening, the
    ``max_rank`` clamp and the descending sort are written once.

    Args:
        similarities_matrix: 2-D tensor unpacked as ``(num_q, num_g)``.
        query_img_id: tensor whose ``.numpy()`` is reshaped to a 1-D array
            of length ``query_img_id.shape[0]``.
        gallery_img_id: same as ``query_img_id``, for the gallery.
        max_rank (int): requested CMC length; clamped to ``num_g`` locally,
            without touching the caller's stored value.

    Returns:
        tuple: ``(all_cmc, all_AP, all_INP)`` as returned by ``get_metrics``.
    """
    num_q, num_g = similarities_matrix.shape
    q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
    g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
    if num_g < max_rank:
        max_rank = num_g
        print('Note: number of gallery samples is quite small, got {}'.
              format(num_g))
    # Highest similarity first: column 0 is each query's best gallery match.
    indices = paddle.argsort(
        similarities_matrix, axis=1, descending=True).numpy()
    return get_metrics(indices, num_q, num_g, q_pids, g_pids, max_rank)


class mAP(nn.Layer):
    """Mean of the per-query average precision.

    Args:
        max_rank (int): CMC length passed on to ``get_metrics``.
    """

    def __init__(self, max_rank=50):
        super().__init__()
        self.max_rank = max_rank

    def forward(self, similarities_matrix, query_img_id, gallery_img_id):
        """Return ``{"mAP": value}``.

        ``self.max_rank`` is left unchanged; a small gallery only clamps
        the value used for this call.
        """
        _, all_AP, _ = _retrieval_stats(similarities_matrix, query_img_id,
                                        gallery_img_id, self.max_rank)
        return {"mAP": np.mean(all_AP)}


class mINP(nn.Layer):
    """Mean of the per-query INP value computed by ``get_metrics``.

    Args:
        max_rank (int): CMC length passed on to ``get_metrics``.
    """

    def __init__(self, max_rank=50):
        super().__init__()
        self.max_rank = max_rank

    def forward(self, similarities_matrix, query_img_id, gallery_img_id):
        """Return ``{"mINP": value}``."""
        _, _, all_INP = _retrieval_stats(similarities_matrix, query_img_id,
                                         gallery_img_id, self.max_rank)
        return {"mINP": np.mean(all_INP)}


class Recallk(nn.Layer):
    """Recall at each configured ``k``, read from the averaged CMC curve.

    Args:
        max_rank (int): CMC length passed on to ``get_metrics``.
        topk (int | list | tuple): the ``k`` value(s) to report.
    """

    def __init__(self, max_rank=50, topk=(1, 5)):
        super().__init__()
        self.max_rank = max_rank
        # tuple must be accepted: the default value itself is a tuple.
        assert isinstance(topk, (int, list, tuple))
        if isinstance(topk, int):
            topk = [topk]
        self.topk = topk

    def forward(self, similarities_matrix, query_img_id, gallery_img_id):
        """Return ``{"recall<k>": value}`` for every ``k`` in ``self.topk``.

        Raises:
            IndexError: if some ``k`` exceeds the length of the CMC curve,
                i.e. the smaller of ``max_rank`` and the gallery size.
        """
        all_cmc, _, _ = _retrieval_stats(similarities_matrix, query_img_id,
                                         gallery_img_id, self.max_rank)
        metric_dict = {}
        for k in self.topk:
            metric_dict["recall{}".format(k)] = all_cmc[k - 1]
        return metric_dict


def get_metrics(indices, num_q, num_g, q_pids, g_pids, max_rank=50):
    """Compute the CMC curve, average precision and INP from ranked results.

    Args:
        indices (np.ndarray): per query row, gallery indices in ranked
            order (best match first).
        num_q (int): number of query rows to evaluate.
        num_g (int): unused; kept so existing callers keep working.
        q_pids (np.ndarray): 1-D ids of the queries.
        g_pids (np.ndarray): 1-D ids of the gallery entries.
        max_rank (int): number of leading ranks kept in the CMC curve.

    Returns:
        tuple: ``(all_cmc, all_AP, all_INP)``.
            ``all_cmc`` is a float32 array: for each rank, the share of
            valid queries with at least one match at or before it.
            ``all_AP`` and ``all_INP`` are lists with one value per valid
            query.

    Raises:
        AssertionError: if no query has a matching id in the gallery.
    """
    all_cmc = []
    all_AP = []
    all_INP = []
    # matches[q, r] == 1 when the gallery item at rank r has query q's id.
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
    for q_idx in range(num_q):
        raw_cmc = matches[q_idx]
        if not np.any(raw_cmc):
            # This query's id never appears in the gallery: not a valid query.
            continue
        hits = raw_cmc.cumsum()

        # INP: matches found up to the last match, over its 1-based rank.
        last_hit = np.flatnonzero(raw_cmc)[-1]
        all_INP.append(hits[last_hit] / (last_hit + 1.0))

        # CMC: 1 from the first matching rank onwards, 0 before it.
        all_cmc.append(np.minimum(hits, 1)[:max_rank])

        # AP: precision at each rank, averaged over the matching ranks only.
        precision = hits / np.arange(1.0, hits.size + 1)
        all_AP.append((precision * raw_cmc).sum() / raw_cmc.sum())

    num_valid_q = float(len(all_cmc))
    assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q

    return all_cmc, all_AP, all_INP