From 32c99be6aa318d063214cac600349938bbd29260 Mon Sep 17 00:00:00 2001 From: dongshuilong Date: Mon, 16 May 2022 03:50:35 +0000 Subject: [PATCH] add adaface --- ppcls/arch/backbone/model_zoo/ir_net.py | 163 +++++++++++++----- ppcls/arch/gears/adamargin.py | 42 +++-- .../configs/metric_learning/ir18_adaface.yaml | 26 +-- ppcls/data/__init__.py | 3 +- ppcls/data/dataloader/face_dataset.py | 46 +---- ppcls/data/preprocess/__init__.py | 1 + ppcls/data/preprocess/ops/operators.py | 6 +- ppcls/engine/engine.py | 2 +- ppcls/engine/evaluation/adaface.py | 21 ++- 9 files changed, 186 insertions(+), 124 deletions(-) diff --git a/ppcls/arch/backbone/model_zoo/ir_net.py b/ppcls/arch/backbone/model_zoo/ir_net.py index edde2c2b..ecec5811 100644 --- a/ppcls/arch/backbone/model_zoo/ir_net.py +++ b/ppcls/arch/backbone/model_zoo/ir_net.py @@ -1,3 +1,16 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # this code is based on AdaFace(https://github.com/mk-minchul/AdaFace) from collections import namedtuple import paddle @@ -10,28 +23,8 @@ from paddle.nn import BatchNorm1D, BatchNorm2D from paddle.nn import ReLU, Sigmoid from paddle.nn import Layer from paddle.nn import PReLU -from ppcls.arch.backbone.legendary_models.resnet import _load_pretrained -import os - -# def initialize_weights(modules): -# """ Weight initilize, conv2d and linear is initialized with kaiming_normal -# """ -# for m in modules: -# if isinstance(m, nn.Conv2D): -# nn.init.kaiming_normal_(m.weight, -# mode='fan_out', -# nonlinearity='relu') -# if m.bias is not None: -# m.bias.data.zero_() -# elif isinstance(m, nn.BatchNorm2D): -# m.weight.data.fill_(1) -# m.bias.data.zero_() -# elif isinstance(m, nn.Linear): -# nn.init.kaiming_normal_(m.weight, -# mode='fan_out', -# nonlinearity='relu') -# if m.bias is not None: -# m.bias.data.zero_() + +# from ppcls.arch.backbone.legendary_models.resnet import _load_pretrained class Flatten(Layer): @@ -61,8 +54,14 @@ class LinearBlock(Layer): stride, padding, groups=groups, + weight_attr=nn.initializer.KaimingNormal(), bias_attr=None) - self.bn = BatchNorm2D(out_c) + weight_attr = paddle.ParamAttr( + regularizer=None, initializer=nn.initializer.Constant(value=1.0)) + bias_attr = paddle.ParamAttr( + regularizer=None, initializer=nn.initializer.Constant(value=0.0)) + self.bn = BatchNorm2D( + out_c, weight_attr=weight_attr, bias_attr=bias_attr) def forward(self, x): x = self.conv(x) @@ -106,7 +105,11 @@ class GDC(Layer): stride=(1, 1), padding=(0, 0)) self.conv_6_flatten = Flatten() - self.linear = Linear(in_c, embedding_size, bias_attr=False) + self.linear = Linear( + in_c, + embedding_size, + weight_attr=nn.initializer.KaimingNormal(), + bias_attr=False) self.bn = BatchNorm1D( embedding_size, weight_attr=False, bias_attr=False) @@ -125,8 +128,7 @@ class SELayer(Layer): def __init__(self, channels, reduction): super(SELayer, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2D(1) - weight_attr = paddle.framework.ParamAttr( - name="linear_weight", + weight_attr = paddle.ParamAttr( initializer=paddle.nn.initializer.XavierUniform()) self.fc1 = Conv2D( channels, @@ -142,6 +144,7 @@ class SELayer(Layer): channels, kernel_size=1, padding=0, + weight_attr=nn.initializer.KaimingNormal(), bias_attr=False) self.sigmoid = Sigmoid() @@ -163,22 +166,44 @@ class BasicBlockIR(Layer): def __init__(self, in_channel, depth, stride): super(BasicBlockIR, self).__init__() + + weight_attr = paddle.ParamAttr( + regularizer=None, initializer=nn.initializer.Constant(value=1.0)) + bias_attr = paddle.ParamAttr( + regularizer=None, initializer=nn.initializer.Constant(value=0.0)) if in_channel == depth: self.shortcut_layer = MaxPool2D(1, stride) else: self.shortcut_layer = Sequential( Conv2D( - in_channel, depth, (1, 1), stride, bias_attr=False), - BatchNorm2D(depth)) + in_channel, + depth, (1, 1), + stride, + weight_attr=nn.initializer.KaimingNormal(), + bias_attr=False), + BatchNorm2D( + depth, weight_attr=weight_attr, bias_attr=bias_attr)) self.res_layer = Sequential( - BatchNorm2D(in_channel), + BatchNorm2D( + in_channel, weight_attr=weight_attr, bias_attr=bias_attr), Conv2D( - in_channel, depth, (3, 3), (1, 1), 1, bias_attr=False), - BatchNorm2D(depth), + in_channel, + depth, (3, 3), (1, 1), + 1, + weight_attr=nn.initializer.KaimingNormal(), + bias_attr=False), + BatchNorm2D( + depth, weight_attr=weight_attr, bias_attr=bias_attr), PReLU(depth), Conv2D( - depth, depth, (3, 3), stride, 1, bias_attr=False), - BatchNorm2D(depth)) + depth, + depth, (3, 3), + stride, + 1, + weight_attr=nn.initializer.KaimingNormal(), + bias_attr=False), + BatchNorm2D( + depth, weight_attr=weight_attr, bias_attr=bias_attr)) def forward(self, x): shortcut = self.shortcut_layer(x) @@ -194,32 +219,56 @@ class BottleneckIR(Layer): def __init__(self, in_channel, depth, stride): super(BottleneckIR, self).__init__() reduction_channel = depth // 4 + weight_attr = paddle.ParamAttr( + regularizer=None, initializer=nn.initializer.Constant(value=1.0)) + bias_attr = paddle.ParamAttr( + regularizer=None, initializer=nn.initializer.Constant(value=0.0)) if in_channel == depth: self.shortcut_layer = MaxPool2D(1, stride) else: self.shortcut_layer = Sequential( Conv2D( - in_channel, depth, (1, 1), stride, bias_attr=False), - BatchNorm2D(depth)) + in_channel, + depth, (1, 1), + stride, + weight_attr=nn.initializer.KaimingNormal(), + bias_attr=False), + BatchNorm2D( + depth, weight_attr=weight_attr, bias_attr=bias_attr)) self.res_layer = Sequential( - BatchNorm2D(in_channel), + BatchNorm2D( + in_channel, weight_attr=weight_attr, bias_attr=bias_attr), Conv2D( in_channel, reduction_channel, (1, 1), (1, 1), 0, + weight_attr=nn.initializer.KaimingNormal(), bias_attr=False), - BatchNorm2D(reduction_channel), + BatchNorm2D( + reduction_channel, + weight_attr=weight_attr, + bias_attr=bias_attr), PReLU(reduction_channel), Conv2D( reduction_channel, reduction_channel, (3, 3), (1, 1), 1, + weight_attr=nn.initializer.KaimingNormal(), bias_attr=False), - BatchNorm2D(reduction_channel), + BatchNorm2D( + reduction_channel, + weight_attr=weight_attr, + bias_attr=bias_attr), PReLU(reduction_channel), Conv2D( - reduction_channel, depth, (1, 1), stride, 0, bias_attr=False), - BatchNorm2D(depth)) + reduction_channel, + depth, (1, 1), + stride, + 0, + weight_attr=nn.initializer.KaimingNormal(), + bias_attr=False), + BatchNorm2D( + depth, weight_attr=weight_attr, bias_attr=bias_attr)) def forward(self, x): shortcut = self.shortcut_layer(x) @@ -317,10 +366,20 @@ class Backbone(Layer): "num_layers should be 18, 34, 50, 100 or 152" assert mode in ['ir', 'ir_se'], \ "mode should be ir or ir_se" + weight_attr = paddle.ParamAttr( + regularizer=None, initializer=nn.initializer.Constant(value=1.0)) + bias_attr = paddle.ParamAttr( + regularizer=None, initializer=nn.initializer.Constant(value=0.0)) self.input_layer = Sequential( Conv2D( - 3, 64, (3, 3), 1, 1, bias_attr=False), - BatchNorm2D(64), + 3, + 64, (3, 3), + 1, + 1, + weight_attr=nn.initializer.KaimingNormal(), + bias_attr=False), + BatchNorm2D( + 64, weight_attr=weight_attr, bias_attr=bias_attr), PReLU(64)) blocks = get_blocks(num_layers) if num_layers <= 100: @@ -338,18 +397,30 @@ class Backbone(Layer): if input_size[0] == 112: self.output_layer = Sequential( - BatchNorm2D(output_channel), + BatchNorm2D( + output_channel, + weight_attr=weight_attr, + bias_attr=bias_attr), Dropout(0.4), Flatten(), - Linear(output_channel * 7 * 7, 512), + Linear( + output_channel * 7 * 7, + 512, + weight_attr=nn.initializer.KaimingNormal()), BatchNorm1D( 512, weight_attr=False, bias_attr=False)) else: self.output_layer = Sequential( - BatchNorm2D(output_channel), + BatchNorm2D( + output_channel, + weight_attr=weight_attr, + bias_attr=bias_attr), Dropout(0.4), Flatten(), - Linear(output_channel * 14 * 14, 512), + Linear( + output_channel * 14 * 14, + 512, + weight_attr=nn.initializer.KaimingNormal()), BatchNorm1D( 512, weight_attr=False, bias_attr=False)) diff --git a/ppcls/arch/gears/adamargin.py b/ppcls/arch/gears/adamargin.py index 2341afd7..1b0f5f24 100644 --- a/ppcls/arch/gears/adamargin.py +++ b/ppcls/arch/gears/adamargin.py @@ -1,4 +1,19 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # This code is based on AdaFace(https://github.com/mk-minchul/AdaFace) +# Paper: AdaFace: Quality Adaptive Margin for Face Recognition from paddle.nn import Layer import math import paddle @@ -21,8 +36,17 @@ class AdaMargin(Layer): t_alpha=1.0, ): super(AdaMargin, self).__init__() self.classnum = class_num + kernel_weight = paddle.uniform( + [embedding_size, class_num], min=-1, max=1) + kernel_weight_norm = paddle.norm( + kernel_weight, p=2, axis=0, keepdim=True) + kernel_weight_norm = paddle.where(kernel_weight_norm > 1e-5, + kernel_weight_norm, + paddle.ones_like(kernel_weight_norm)) + kernel_weight = kernel_weight / kernel_weight_norm self.kernel = self.create_parameter( - [embedding_size, class_num], attr=paddle.nn.initializer.Uniform()) + [embedding_size, class_num], + attr=paddle.nn.initializer.Assign(kernel_weight)) # initial kernel # self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5) @@ -39,14 +63,10 @@ class AdaMargin(Layer): self.register_buffer( 'batch_std', paddle.ones([1]) * 100, persistable=True) - print('\n\AdaFace with the following property') - print('self.m', self.m) - print('self.h', self.h) - print('self.s', self.s) - print('self.t_alpha', self.t_alpha) - - def forward(self, embbedings, norms, label): + def forward(self, embbedings, label): + norms = paddle.norm(embbedings, 2, 1, True) + embbedings = paddle.divide(embbedings, norms) kernel_norm = l2_norm(self.kernel, axis=0) cosine = paddle.mm(embbedings, kernel_norm) cosine = paddle.clip(cosine, -1 + self.eps, @@ -70,7 +90,8 @@ class AdaMargin(Layer): margin_scaler = paddle.clip(margin_scaler, -1, 1) # g_angular - m_arc = paddle.nn.functional.one_hot(label, self.classnum) + m_arc = paddle.nn.functional.one_hot( + label.reshape([-1]), self.classnum) g_angular = self.m * margin_scaler * -1 m_arc = m_arc * g_angular theta = paddle.acos(cosine) @@ -79,7 +100,8 @@ class AdaMargin(Layer): cosine = paddle.cos(theta_m) # g_additive - m_cos = paddle.nn.functional.one_hot(label, self.classnum) + m_cos = paddle.nn.functional.one_hot( + label.reshape([-1]), self.classnum) g_add = self.m + (self.m * margin_scaler) m_cos = m_cos * g_add cosine = cosine - m_cos diff --git a/ppcls/configs/metric_learning/ir18_adaface.yaml b/ppcls/configs/metric_learning/ir18_adaface.yaml index 9079f52c..008aed42 100644 --- a/ppcls/configs/metric_learning/ir18_adaface.yaml +++ b/ppcls/configs/metric_learning/ir18_adaface.yaml @@ -22,14 +22,14 @@ Arch: infer_add_softmax: False Backbone: name: "IR_18" - pretrained: False + input_size: [112, 112] Head: name: "AdaMargin" embedding_size: 512 class_num: 70722 m: 0.4 - scale: 32 - h: 0.3333 + s: 64 + h: 0.333 t_alpha: 0.01 # loss function config for traing/eval process @@ -48,15 +48,15 @@ Optimizer: values: [0.1, 0.01, 0.001, 0.0001] regularizer: name: 'L2' - coeff: 0.0001 + coeff: 0.0005 # data loader for train and eval DataLoader: Train: dataset: name: "AdaFaceDataset" - root_dir: "/work/dataset/face/" - label_path: "/work/dataset/face/train_filter_label.txt" + root_dir: "dataset/face/" + label_path: "dataset/face/train_filter_label.txt" low_res_augmentation_prob: 0.2 crop_augmentation_prob: 0.2 photometric_augmentation_prob: 0.2 @@ -66,7 +66,6 @@ DataLoader: - Normalize: mean: [0.5, 0.5, 0.5] std: [0.5, 0.5, 0.5] - sampler: name: DistributedBatchSampler batch_size: 256 @@ -75,16 +74,21 @@ DataLoader: loader: num_workers: 6 use_shared_memory: True + Eval: dataset: name: FiveValidationDataset - val_data_path: /work/dataset/face/faces_emore - concat_mem_file_name: /work/dataset/face/faces_emore/concat_validation_memfile + val_data_path: dataset/face/faces_emore + concat_mem_file_name: dataset/face/faces_emore/concat_validation_memfile sampler: - name: DistributedBatchSampler + name: BatchSampler batch_size: 256 drop_last: False shuffle: True loader: num_workers: 6 - use_shared_memory: True \ No newline at end of file + use_shared_memory: True +Metric: + Train: + - TopkAcc: + topk: [1, 5] \ No newline at end of file diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py index 9722bfb8..3109ec8e 100644 --- a/ppcls/data/__init__.py +++ b/ppcls/data/__init__.py @@ -29,6 +29,7 @@ from ppcls.data.dataloader.logo_dataset import LogoDataset from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset from ppcls.data.dataloader.mix_dataset import MixDataset from ppcls.data.dataloader.person_dataset import Market1501, MSMT17 +from ppcls.data.dataloader.face_dataset import FiveValidationDataset, AdaFaceDataset # sampler from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler @@ -85,7 +86,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None): # build sampler config_sampler = config[mode]['sampler'] - if "name" not in config_sampler: + if config_sampler and "name" not in config_sampler: batch_sampler = None batch_size = config_sampler["batch_size"] drop_last = config_sampler["drop_last"] diff --git a/ppcls/data/dataloader/face_dataset.py b/ppcls/data/dataloader/face_dataset.py index ebb25b29..9ee6fdf0 100644 --- a/ppcls/data/dataloader/face_dataset.py +++ b/ppcls/data/dataloader/face_dataset.py @@ -10,28 +10,11 @@ from paddle.vision import transforms from paddle.vision.transforms import functional as F from paddle.io import Dataset from .common_dataset import create_operators +from ppcls.data.preprocess import transform as transform_func # code is based on AdaFace: https://github.com/mk-minchul/AdaFace -def train_dataset(train_dir, label_path, low_res_augmentation_prob, - crop_augmentation_prob, photometric_augmentation_prob): - - # train_dir = os.path.join(data_root, train_data_path) - train_dataset = AdaFaceDataset( - root_dir=train_dir, - label_path=label_path, - transform=transforms.Compose([ - transforms.RandomHorizontalFlip(), transforms.ToTensor(), - transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) - ]), - low_res_augmentation_prob=low_res_augmentation_prob, - crop_augmentation_prob=crop_augmentation_prob, - photometric_augmentation_prob=photometric_augmentation_prob, ) - - return train_dataset - - def _get_image_size(img): if F._is_pil_image(img): return img.size @@ -95,7 +78,7 @@ class AdaFaceDataset(Dataset): sample, _ = self.augment(sample) if self.transform is not None: - sample = self.transform(sample) + sample = transform_func(sample, self.transform) return sample, target @@ -125,16 +108,6 @@ class AdaFaceDataset(Dataset): # photometric augmentation if np.random.random() < self.photometric_augmentation_prob: - # fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ - # self.photometric._get_params(self.photometric.brightness, self.photometric.contrast, - # self.photometric.saturation, self.photometric.hue) - # for fn_id in fn_idx: - # if fn_id == 0 and brightness_factor is not None: - # sample = F.adjust_brightness(sample, brightness_factor) - # elif fn_id == 1 and contrast_factor is not None: - # sample = F.adjust_contrast(sample, contrast_factor) - # elif fn_id == 2 and saturation_factor is not None: - # sample = F.adjust_saturation(sample, saturation_factor) sample = self.photometric(sample) information_score = resize_ratio * crop_ratio return sample, information_score @@ -269,17 +242,4 @@ def get_val_data(data_path): lfw, lfw_issame = get_val_pair(data_path, 'lfw') cplfw, cplfw_issame = get_val_pair(data_path, 'cplfw') calfw, calfw_issame = get_val_pair(data_path, 'calfw') - return agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame, cplfw, cplfw_issame, calfw, calfw_issame - - -if __name__ == "__main__": - t_dataset = train_dataset('/work/dataset/face/', - '/work/dataset/face/train_filter_label.txt', 1, - 1, 1) - img = t_dataset.__getitem__(100) - print(len(t_dataset)) - - val = FiveValidationDataset( - '/work/dataset/face/faces_emore', - '/work/dataset/face/faces_emore/concat_validation_memfile') - a = 1 + return agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame, cplfw, cplfw_issame, calfw, calfw_issame \ No newline at end of file diff --git a/ppcls/data/preprocess/__init__.py b/ppcls/data/preprocess/__init__.py index 62066016..353db19c 100644 --- a/ppcls/data/preprocess/__init__.py +++ b/ppcls/data/preprocess/__init__.py @@ -33,6 +33,7 @@ from ppcls.data.preprocess.ops.operators import AugMix from ppcls.data.preprocess.ops.operators import Pad from ppcls.data.preprocess.ops.operators import ToTensor from ppcls.data.preprocess.ops.operators import Normalize +from ppcls.data.preprocess.ops.operators import RandomHorizontalFlip from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py index 157f44f1..e4996abc 100644 --- a/ppcls/data/preprocess/ops/operators.py +++ b/ppcls/data/preprocess/ops/operators.py @@ -25,7 +25,7 @@ import cv2 import numpy as np from PIL import Image, ImageOps, __version__ as PILLOW_VERSION from paddle.vision.transforms import ColorJitter as RawColorJitter -from paddle.vision.transforms import ToTensor, Normalize +from paddle.vision.transforms import ToTensor, Normalize, RandomHorizontalFlip from .autoaugment import ImageNetPolicy from .functional import augmentations @@ -463,8 +463,8 @@ class Pad(object): # Process fill color for affine transforms major_found, minor_found = (int(v) for v in PILLOW_VERSION.split('.')[:2]) - major_required, minor_required = ( - int(v) for v in min_pil_version.split('.')[:2]) + major_required, minor_required = (int(v) for v in + min_pil_version.split('.')[:2]) if major_found < major_required or (major_found == major_required and minor_found < minor_required): if fill is None: diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 72e2cf25..ced953d5 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -116,7 +116,7 @@ class Engine(object): self.config["DataLoader"], "Train", self.device, self.use_dali) if self.mode == "eval" or (self.mode == "train" and self.config["Global"]["eval_during_train"]): - if self.eval_mode == "classification": + if self.eval_mode in ["classification", "adaface"]: self.eval_dataloader = build_dataloader( self.config["DataLoader"], "Eval", self.device, self.use_dali) diff --git a/ppcls/engine/evaluation/adaface.py b/ppcls/engine/evaluation/adaface.py index 266b337a..e62144b5 100644 --- a/ppcls/engine/evaluation/adaface.py +++ b/ppcls/engine/evaluation/adaface.py @@ -30,7 +30,7 @@ def fuse_features_with_norm(stacked_embeddings, stacked_norms): assert stacked_embeddings.ndim == 3 # (n_features_to_fuse, batch_size, channel) assert stacked_norms.ndim == 3 # (n_features_to_fuse, batch_size, 1) pre_norm_embeddings = stacked_embeddings * stacked_norms - fused = pre_norm_embeddings.sum(dim=0) + fused = pre_norm_embeddings.sum(axis=0) norm = paddle.norm(fused, 2, 1, True) fused = paddle.divide(fused, norm) return fused, norm @@ -57,12 +57,14 @@ def adaface_eval(engine, epoch_id=0): time_info["reader_cost"].update(time.time() - tic) batch_size = images.shape[0] batch[0] = paddle.to_tensor(images) - embeddings = engine.model(images)["features"] + embeddings = engine.model(images, labels)['features'] norms = paddle.divide(embeddings, paddle.norm(embeddings, 2, 1, True)) + embeddings = paddle.divide(embeddings, norms) fliped_images = paddle.flip(images, axis=[3]) - flipped_embeddings = engine.model(fliped_images)["features"] + flipped_embeddings = engine.model(fliped_images, labels)['features'] flipped_norms = paddle.divide( flipped_embeddings, paddle.norm(flipped_embeddings, 2, 1, True)) + flipped_embeddings = paddle.divide(flipped_embeddings, flipped_norms) stacked_embeddings = paddle.stack( [embeddings, flipped_embeddings], axis=0) stacked_norms = paddle.stack([norms, flipped_norms], axis=0) @@ -114,20 +116,21 @@ def adaface_eval(engine, epoch_id=0): metric_msg = ", ".join([ "{}: {:.5f}".format(key, output_info[key].avg) for key in output_info ]) - face_msg = ", ".join( - ["{}: {:.5f}".format(key, output_info[key]) for key in eval_result]) + face_msg = ", ".join([ + "{}: {:.5f}".format(key, eval_result[key]) + for key in eval_result.keys() + ]) logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg + ", " + face_msg)) - # do not try to save best eval.model - if engine.eval_metric_func is None: - return -1 # return 1st metric in the dict - return output_info[metric_key].avg + return eval_result['all_test_acc'] def cal_metric(all_output_tensor, all_norm_tensor, all_target_tensor, all_dataname_tensor): + all_target_tensor = all_target_tensor.reshape([-1]) + all_dataname_tensor = all_dataname_tensor.reshape([-1]) dataname_to_idx = { "agedb_30": 0, "cfp_fp": 1, -- GitLab