diff --git a/ppcls/configs/Vehicle/ResNet50_ReID.yaml b/ppcls/configs/Vehicle/ResNet50_ReID.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d29a5f883d8ee586b070892200f425115275a9e6 --- /dev/null +++ b/ppcls/configs/Vehicle/ResNet50_ReID.yaml @@ -0,0 +1,161 @@ +# global configs +Trainer: + name: TrainerReID +Global: + checkpoints: null + pretrained_model: null + output_dir: "./output/" + device: "gpu" + class_num: 30671 + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 160 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: "./inference" + num_split: 1 + feature_normalize: True + + +# model architecture +Arch: + name: "RecModel" + Backbone: + name: "ResNet50" + Stoplayer: + name: "flatten_0" + output_dim: 2048 + embedding_size: 512 + Head: + name: "ArcMargin" + embedding_size: 512 + class_num: 431 + margin: 0.15 + scale: 32 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + - TripletLossV2: + weight: 1.0 + margin: 0.5 + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: MultiStepDecay + learning_rate: 0.01 + milestones: [30, 60, 70, 80, 90, 100, 120, 140] + gamma: 0.5 + verbose: False + last_epoch: -1 + regularizer: + name: 'L2' + coeff: 0.0005 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: "VeriWild" + image_root: "/work/dataset/VeRI-Wild/images/" + cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_train.txt" + transform_ops: + - ResizeImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - AugMix: + prob: 0.5 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.5 + sl: 0.02 + sh: 0.4 + r1: 0.3 + mean: [0., 0., 0.] + + sampler: + name: DistributedRandomIdentitySampler + batch_size: 128 + num_instances: 2 + drop_last: False + shuffle: True + loader: + num_workers: 6 + use_shared_memory: False + + Query: + # TOTO: modify to the latest trainer + dataset: + name: "VeriWild" + image_root: "/work/dataset/VeRI-Wild/images" + cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_test_query.txt" + transform_ops: + - ResizeImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 6 + use_shared_memory: False + + Gallery: + # TOTO: modify to the latest trainer + dataset: + name: "VeriWild" + image_root: "/work/dataset/VeRI-Wild/images" + cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_test.txt" + transform_ops: + - ResizeImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 6 + use_shared_memory: False + +Infer: + infer_imgs: "docs/images/whl/demo.jpg" + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py index 661fb3324ed1f3e196287cd5bfbb95d5578ac350..fca8bf093259e4f42e4c5af5a3de125a2f81cf61 100644 --- a/ppcls/data/__init__.py +++ b/ppcls/data/__init__.py @@ -27,14 +27,13 @@ from ppcls.data.dataloader.common_dataset import create_operators from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild # sampler -from ppcls.data.dataloader import DistributedRandomIdentitySampler - +from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler from ppcls.data.preprocess import transform def build_dataloader(config, mode, device, seed=None): - assert mode in ['Train', 'Eval', 'Test' - ], "Mode should be Train, Eval or Test." + assert mode in ['Train', 'Eval', 'Test', 'Gallery', 'Query' + ], "Mode should be Train, Eval, Test, Gallery or Query" # build dataset config_dataset = config[mode]['dataset'] config_dataset = copy.deepcopy(config_dataset) diff --git a/ppcls/engine/trainer.py b/ppcls/engine/trainer.py index 6ac3b88b78ea7307430f69dd6dc6b96024ef4740..abbc8579aeb462ff999eb581635b65e4d53e85ff 100644 --- a/ppcls/engine/trainer.py +++ b/ppcls/engine/trainer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -109,8 +109,10 @@ class Trainer(object): def train(self): # build train loss and metric info loss_func = self._build_loss_info(self.config["Loss"]) - - metric_func = self._build_metric_info(self.config["Metric"]) + if "Metric" in self.config: + metric_func = self._build_metric_info(self.config["Metric"]) + else: + metric_func = None train_dataloader = build_dataloader(self.config["DataLoader"], "Train", self.device) @@ -156,7 +158,7 @@ class Trainer(object): else: out = self.model(batch[0], batch[1]) # calc loss - loss_dict = loss_func(out, batch[-1]) + loss_dict = loss_func(out, batch[1]) for key in loss_dict: if not key in output_info: output_info[key] = AverageMeter(key, '7.5f') diff --git a/ppcls/engine/trainer_reid.py b/ppcls/engine/trainer_reid.py new file mode 100644 index 0000000000000000000000000000000000000000..361de6c6a2a7263f6aaa9848457bafa0f72abafd --- /dev/null +++ b/ppcls/engine/trainer_reid.py @@ -0,0 +1,208 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import sys +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../../'))) + +import numpy as np +import paddle +from .trainer import Trainer +from ppcls.utils import logger +from ppcls.data import build_dataloader + + +class TrainerReID(Trainer): + def __init__(self, config, mode="train"): + super().__init__(config, mode) + + self.gallery_dataloader = build_dataloader(self.config["DataLoader"], + "Gallery", self.device) + + self.query_dataloader = build_dataloader(self.config["DataLoader"], + "Query", self.device) + + @paddle.no_grad() + def eval(self, epoch_id=0): + output_info = dict() + self.model.eval() + print_batch_step = self.config["Global"]["print_batch_step"] + + # step1. build gallery + gallery_feas, gallery_img_id, gallery_camera_id = self._cal_feature( + name='gallery') + query_feas, query_img_id, query_camera_id = self._cal_feature( + name='query') + + # step2. do evaluation + if "num_split" in self.config["Global"]: + num_split = self.config["Global"]["num_split"] + else: + num_split = 1 + fea_blocks = paddle.split(query_feas, num_or_sections=1) + + total_similarities_matrix = None + + for block_fea in fea_blocks: + similarities_matrix = paddle.matmul( + block_fea, gallery_feas, transpose_y=True) + if total_similarities_matrix is None: + total_similarities_matrix = similarities_matrix + else: + total_similarities_matrix = paddle.concat( + [total_similarities_matrix, similarities_matrix]) + + # distmat = (1 - total_similarities_matrix).numpy() + q_pids = query_img_id.numpy().reshape((query_img_id.shape[0])) + g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0])) + if query_camera_id is not None and gallery_camera_id is not None: + q_camids = query_camera_id.numpy().reshape( + (query_camera_id.shape[0])) + g_camids = gallery_camera_id.numpy().reshape( + (gallery_camera_id.shape[0])) + max_rank = 50 + + num_q, num_g = total_similarities_matrix.shape + if num_g < max_rank: + max_rank = num_g + print('Note: number of gallery samples is quite small, got {}'. + format(num_g)) + + # indices = np.argsort(distmat, axis=1) + indices = paddle.argsort( + total_similarities_matrix, axis=1, descending=True).numpy() + + matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) + + # compute cmc curve for each query + all_cmc = [] + all_AP = [] + all_INP = [] + num_valid_q = 0. # number of valid query + for q_idx in range(num_q): + # get query pid and camid + q_pid = q_pids[q_idx] + q_camid = q_camids[q_idx] + + # remove gallery samples that have the same pid and camid with query + order = indices[q_idx] + if query_camera_id is not None and gallery_camera_id is not None: + remove = (g_pids[order] == q_pid) & ( + g_camids[order] == q_camid) + else: + remove = g_pids[order] == q_pid + keep = np.invert(remove) + + # compute cmc curve + raw_cmc = matches[q_idx][ + keep] # binary vector, positions with value 1 are correct matches + if not np.any(raw_cmc): + # this condition is true when query identity does not appear in gallery + continue + + cmc = raw_cmc.cumsum() + + pos_idx = np.where(raw_cmc == 1) + max_pos_idx = np.max(pos_idx) + inp = cmc[max_pos_idx] / (max_pos_idx + 1.0) + all_INP.append(inp) + + cmc[cmc > 1] = 1 + + all_cmc.append(cmc[:max_rank]) + num_valid_q += 1. + + # compute average precision + # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision + num_rel = raw_cmc.sum() + tmp_cmc = raw_cmc.cumsum() + tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] + tmp_cmc = np.asarray(tmp_cmc) * raw_cmc + AP = tmp_cmc.sum() / num_rel + all_AP.append(AP) + assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' + + all_cmc = np.asarray(all_cmc).astype(np.float32) + all_cmc = all_cmc.sum(0) / num_valid_q + + mAP = np.mean(all_AP) + mINP = np.mean(all_INP) + logger.info( + "[Eval][Epoch {}]: mAP: {:.5f}, mINP: {:.5f},rank_1: {:.5f}, rank_5: {:.5f}" + .format(epoch_id, mAP, mINP, all_cmc[0], all_cmc[4])) + return mAP + + def _cal_feature(self, name='gallery'): + all_feas = None + all_image_id = None + all_camera_id = None + if name == 'gallery': + dataloader = self.gallery_dataloader + elif name == 'query': + dataloader = self.query_dataloader + else: + raise RuntimeError("Only support gallery or query dataset") + + has_cam_id = False + for idx, batch in enumerate(dataloader( + )): # load is very time-consuming + batch = [paddle.to_tensor(x) for x in batch] + batch[1] = batch[1].reshape([-1, 1]) + if len(batch) == 3: + has_cam_id = True + batch[2] = batch[2].reshape([-1, 1]) + out = self.model(batch[0], batch[1]) + batch_feas = out["features"] + + # do norm + if self.config["Global"].get("feature_normalize", True): + feas_norm = paddle.sqrt( + paddle.sum(paddle.square(batch_feas), axis=1, + keepdim=True)) + batch_feas = paddle.divide(batch_feas, feas_norm) + + batch_feas = batch_feas + batch_image_labels = batch[1] + if has_cam_id: + batch_camera_labels = batch[2] + + if all_feas is None: + all_feas = batch_feas + if has_cam_id: + all_camera_id = batch[2] + all_image_id = batch[1] + else: + all_feas = paddle.concat([all_feas, batch_feas]) + all_image_id = paddle.concat([all_image_id, batch[1]]) + if has_cam_id: + all_camera_id = paddle.concat([all_camera_id, batch[2]]) + + if paddle.distributed.get_world_size() > 1: + feat_list = [] + img_id_list = [] + cam_id_list = [] + paddle.distributed.all_gather(feat_list, all_feas) + paddle.distributed.all_gather(img_id_list, all_image_id) + all_feas = paddle.concat(feat_list, axis=0) + all_image_id = paddle.concat(img_id_list, axis=0) + if has_cam_id: + paddle.distributed.all_gather(cam_id_list, all_camera_id) + all_camera_id = paddle.concat(cam_id_list, axis=0) + + logger.info("Build {} done, all feat shape: {}, begin to eval..". + format(name, all_feas.shape)) + return all_feas, all_image_id, all_camera_id diff --git a/tools/train.py b/tools/train.py index aec796c71bf57ab0124844c07db865de916403f4..d1d014639dde029d07789249d5cca4181176009b 100644 --- a/tools/train.py +++ b/tools/train.py @@ -22,9 +22,13 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) from ppcls.utils import config from ppcls.engine.trainer import Trainer +from ppcls.engine.trainer_reid import TrainerReID if __name__ == "__main__": args = config.parse_args() config = config.get_config(args.config, overrides=args.override, show=True) - trainer = Trainer(config, mode="train") + if "Trainer" in config: + trainer = eval(config["Trainer"]["name"])(config, mode="train") + else: + trainer = Trainer(config, mode="train") trainer.train()