diff --git a/deploy/configs/inference_attr.yaml b/deploy/configs/inference_attr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b49e2af6482e72e01716faceefb8676d87c08347 --- /dev/null +++ b/deploy/configs/inference_attr.yaml @@ -0,0 +1,33 @@ +Global: + infer_imgs: "./images/Pedestrain_Attr.jpg" + inference_model_dir: "../inference/" + batch_size: 1 + use_gpu: True + enable_mkldnn: False + cpu_num_threads: 10 + enable_benchmark: True + use_fp16: False + ir_optim: True + use_tensorrt: False + gpu_mem: 8000 + enable_profile: False + +PreProcess: + transform_ops: + - ResizeImage: + size: [192, 256] + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + channel_num: 3 + - ToCHWImage: + +PostProcess: + main_indicator: Attribute + Attribute: + threshold: 0.5 #default threshold + glasses_threshold: 0.3 #threshold only for glasses + hold_threshold: 0.6 #threshold only for hold + \ No newline at end of file diff --git a/deploy/images/Pedestrain_Attr.jpg b/deploy/images/Pedestrain_Attr.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6a87e856af8c17a3b93617b93ea517b91c508619 Binary files /dev/null and b/deploy/images/Pedestrain_Attr.jpg differ diff --git a/deploy/python/postprocess.py b/deploy/python/postprocess.py index 4f4d005fdff2bf17e04265e136443d0cd837f10e..1107b805085531de74ca1c34d25c98a5d226d531 100644 --- a/deploy/python/postprocess.py +++ b/deploy/python/postprocess.py @@ -64,9 +64,17 @@ class ThreshOutput(object): for idx, probs in enumerate(x): score = probs[1] if score < self.threshold: - result = {"class_ids": [0], "scores": [1 - score], "label_names": [self.label_0]} + result = { + "class_ids": [0], + "scores": [1 - score], + "label_names": [self.label_0] + } else: - result = {"class_ids": [1], "scores": [score], "label_names": [self.label_1]} + result = { + "class_ids": [1], + "scores": [score], + "label_names": [self.label_1] + } if file_names is not None: result["file_name"] = file_names[idx] y.append(result) @@ -179,3 +187,96 @@ class Binarize(object): byte[:, i:i + 1] = np.dot(x[:, i * 8:(i + 1) * 8], self.unit) return byte + + +class Attribute(object): + def __init__(self, + threshold=0.5, + glasses_threshold=0.3, + hold_threshold=0.6): + self.threshold = threshold + self.glasses_threshold = glasses_threshold + self.hold_threshold = hold_threshold + + def __call__(self, batch_preds, file_names=None): + # postprocess output of predictor + age_list = ['AgeLess18', 'Age18-60', 'AgeOver60'] + direct_list = ['Front', 'Side', 'Back'] + bag_list = ['HandBag', 'ShoulderBag', 'Backpack'] + upper_list = ['UpperStride', 'UpperLogo', 'UpperPlaid', 'UpperSplice'] + lower_list = [ + 'LowerStripe', 'LowerPattern', 'LongCoat', 'Trousers', 'Shorts', + 'Skirt&Dress' + ] + batch_res = [] + for res in batch_preds: + res = res.tolist() + label_res = [] + # gender + gender = 'Female' if res[22] > self.threshold else 'Male' + label_res.append(gender) + # age + age = age_list[np.argmax(res[19:22])] + label_res.append(age) + # direction + direction = direct_list[np.argmax(res[23:])] + label_res.append(direction) + # glasses + glasses = 'Glasses: ' + if res[1] > self.glasses_threshold: + glasses += 'True' + else: + glasses += 'False' + label_res.append(glasses) + # hat + hat = 'Hat: ' + if res[0] > self.threshold: + hat += 'True' + else: + hat += 'False' + label_res.append(hat) + # hold obj + hold_obj = 'HoldObjectsInFront: ' + if res[18] > self.hold_threshold: + hold_obj += 'True' + else: + hold_obj += 
'False'
+            label_res.append(hold_obj)
+            # bag
+            bag = bag_list[np.argmax(res[15:18])]
+            bag_score = res[15 + np.argmax(res[15:18])]
+            bag_label = bag if bag_score > self.threshold else 'No bag'
+            label_res.append(bag_label)
+            # upper
+            upper_res = res[4:8]
+            upper_label = 'Upper:'
+            sleeve = 'LongSleeve' if res[3] > res[2] else 'ShortSleeve'
+            upper_label += ' {}'.format(sleeve)
+            for i, r in enumerate(upper_res):
+                if r > self.threshold:
+                    upper_label += ' {}'.format(upper_list[i])
+            label_res.append(upper_label)
+            # lower
+            lower_res = res[8:14]
+            lower_label = 'Lower:'
+            has_lower = False
+            for i, l in enumerate(lower_res):
+                if l > self.threshold:
+                    lower_label += ' {}'.format(lower_list[i])
+                    has_lower = True
+            if not has_lower:
+                lower_label += ' {}'.format(lower_list[np.argmax(lower_res)])
+
+            label_res.append(lower_label)
+            # shoe
+            shoe = 'Boots' if res[14] > self.threshold else 'No boots'
+            label_res.append(shoe)
+
+            threshold_list = [0.5] * len(res)
+            threshold_list[1] = self.glasses_threshold
+            threshold_list[18] = self.hold_threshold
+            pred_res = (np.array(res) > np.array(threshold_list)
+                        ).astype(np.int8).tolist()
+
+            batch_res.append([label_res, pred_res])
+        return batch_res
diff --git a/deploy/python/predict_cls.py b/deploy/python/predict_cls.py
index 64c07ea875eaa2c456393328183b7270080a64d1..41b46090a7f118f401beefd12a9e9d2513cb8bfb 100644
--- a/deploy/python/predict_cls.py
+++ b/deploy/python/predict_cls.py
@@ -138,13 +138,21 @@ def main(config):
             continue
         batch_results = cls_predictor.predict(batch_imgs)
         for number, result_dict in enumerate(batch_results):
-            filename = batch_names[number]
-            clas_ids = result_dict["class_ids"]
-            scores_str = "[{}]".format(", ".join("{:.2f}".format(
-                r) for r in result_dict["scores"]))
-            label_names = result_dict["label_names"]
-            print("{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".
-                  format(filename, clas_ids, scores_str, label_names))
+            if "Attribute" in config["PostProcess"]:
+                filename = batch_names[number]
+                attr_message = result_dict[0]
+                pred_res = result_dict[1]
+                print("{}:\t attributes: {}, \npredict output: {}".format(
+                    filename, attr_message, pred_res))
+            else:
+                filename = batch_names[number]
+                clas_ids = result_dict["class_ids"]
+                scores_str = "[{}]".format(", ".join("{:.2f}".format(
+                    r) for r in result_dict["scores"]))
+                label_names = result_dict["label_names"]
+                print(
+                    "{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".
+ format(filename, clas_ids, scores_str, label_names)) batch_imgs = [] batch_names = [] if cls_predictor.benchmark: diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index bc41454c7b54806270110987d59f1657ac95cafa..e957358479cb98d8bde3dac0d4b2785b8965c7bf 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -70,6 +70,7 @@ from ppcls.arch.backbone.model_zoo.van import VAN_tiny from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1 from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid from ppcls.arch.backbone.variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh +from ppcls.arch.backbone.model_zoo.adaface_ir_net import AdaFace_IR_18, AdaFace_IR_34, AdaFace_IR_50, AdaFace_IR_101, AdaFace_IR_152, AdaFace_IR_SE_50, AdaFace_IR_SE_101, AdaFace_IR_SE_152, AdaFace_IR_SE_200 # help whl get all the models' api (class type) and components' api (func type) diff --git a/ppcls/arch/backbone/legendary_models/resnet.py b/ppcls/arch/backbone/legendary_models/resnet.py index ca75c2eaa4f2d7f4a604a312ed591c10811105c4..4a3d40f37fb3ed0008777643469841ef3ac38b80 100644 --- a/ppcls/arch/backbone/legendary_models/resnet.py +++ b/ppcls/arch/backbone/legendary_models/resnet.py @@ -137,8 +137,11 @@ class ConvBNLayer(TheseusLayer): weight_attr = ParamAttr(learning_rate=lr_mult, trainable=True) bias_attr = ParamAttr(learning_rate=lr_mult, trainable=True) - self.bn = BatchNorm2D( - num_filters, weight_attr=weight_attr, bias_attr=bias_attr) + self.bn = BatchNorm( + num_filters, + param_attr=ParamAttr(learning_rate=lr_mult), + bias_attr=ParamAttr(learning_rate=lr_mult), + data_layout=data_format) self.relu = nn.ReLU() def forward(self, x): @@ -287,7 +290,8 @@ class ResNet(TheseusLayer): data_format="NCHW", input_image_channel=3, return_patterns=None, - return_stages=None): + return_stages=None, + **kargs): super().__init__() self.cfg = config diff --git a/ppcls/arch/backbone/model_zoo/adaface_ir_net.py b/ppcls/arch/backbone/model_zoo/adaface_ir_net.py new file mode 100644 index 0000000000000000000000000000000000000000..47de152b646e6f824e5a888692b770d9e146223b --- /dev/null +++ b/ppcls/arch/backbone/model_zoo/adaface_ir_net.py @@ -0,0 +1,529 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
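Review note on the `ppcls/arch/backbone/__init__.py` hunk above: exporting the `AdaFace_IR_*` factories at package level is what lets a YAML `Arch.Backbone.name` resolve to them by attribute lookup. A minimal sketch of that resolution, assuming only that the import above has landed (the real builder lives in `ppcls/arch/__init__.py`):

```python
# Sketch only: PaddleClas resolves backbone names via attribute lookup on the
# backbone package; this mirrors that mechanism, it is not the actual builder.
import ppcls.arch.backbone as backbone_zoo

arch_name = "AdaFace_IR_18"  # e.g. Arch.Backbone.name from a config file
builder = getattr(backbone_zoo, arch_name)  # found thanks to the new import
model = builder(input_size=(112, 112))
```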
+# this code is based on AdaFace(https://github.com/mk-minchul/AdaFace)
+from collections import namedtuple
+import paddle
+import paddle.nn as nn
+from paddle.nn import Dropout
+from paddle.nn import MaxPool2D
+from paddle.nn import Sequential
+from paddle.nn import Conv2D, Linear
+from paddle.nn import BatchNorm1D, BatchNorm2D
+from paddle.nn import ReLU, Sigmoid
+from paddle.nn import Layer
+from paddle.nn import PReLU
+
+# from ppcls.arch.backbone.legendary_models.resnet import _load_pretrained
+
+
+class Flatten(Layer):
+    """ Flatten the input tensor to [N, -1]
+    """
+
+    def forward(self, input):
+        return paddle.reshape(input, [input.shape[0], -1])
+
+
+class LinearBlock(Layer):
+    """ Convolution block without non-linear activation layer
+    """
+
+    def __init__(self,
+                 in_c,
+                 out_c,
+                 kernel=(1, 1),
+                 stride=(1, 1),
+                 padding=(0, 0),
+                 groups=1):
+        super(LinearBlock, self).__init__()
+        self.conv = Conv2D(
+            in_c,
+            out_c,
+            kernel,
+            stride,
+            padding,
+            groups=groups,
+            weight_attr=nn.initializer.KaimingNormal(),
+            bias_attr=None)
+        weight_attr = paddle.ParamAttr(
+            regularizer=None, initializer=nn.initializer.Constant(value=1.0))
+        bias_attr = paddle.ParamAttr(
+            regularizer=None, initializer=nn.initializer.Constant(value=0.0))
+        self.bn = BatchNorm2D(
+            out_c, weight_attr=weight_attr, bias_attr=bias_attr)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        return x
+
+
+class GNAP(Layer):
+    """ Global Norm-Aware Pooling block
+    """
+
+    def __init__(self, in_c):
+        super(GNAP, self).__init__()
+        self.bn1 = BatchNorm2D(in_c, weight_attr=False, bias_attr=False)
+        self.pool = nn.AdaptiveAvgPool2D((1, 1))
+        self.bn2 = BatchNorm1D(in_c, weight_attr=False, bias_attr=False)
+
+    def forward(self, x):
+        x = self.bn1(x)
+        x_norm = paddle.norm(x, 2, 1, True)
+        x_norm_mean = paddle.mean(x_norm)
+        weight = x_norm_mean / x_norm
+        x = x * weight
+        x = self.pool(x)
+        # paddle.Tensor has no .view(); flatten with reshape instead
+        x = paddle.reshape(x, [x.shape[0], -1])
+        feature = self.bn2(x)
+        return feature
+
+
+class GDC(Layer):
+    """ Global Depthwise Convolution block
+    """
+
+    def __init__(self, in_c, embedding_size):
+        super(GDC, self).__init__()
+        self.conv_6_dw = LinearBlock(
+            in_c,
+            in_c,
+            groups=in_c,
+            kernel=(7, 7),
+            stride=(1, 1),
+            padding=(0, 0))
+        self.conv_6_flatten = Flatten()
+        self.linear = Linear(
+            in_c,
+            embedding_size,
+            weight_attr=nn.initializer.KaimingNormal(),
+            bias_attr=False)
+        self.bn = BatchNorm1D(
+            embedding_size, weight_attr=False, bias_attr=False)
+
+    def forward(self, x):
+        x = self.conv_6_dw(x)
+        x = self.conv_6_flatten(x)
+        x = self.linear(x)
+        x = self.bn(x)
+        return x
+
+
+class SELayer(Layer):
+    """ SE block
+    """
+
+    def __init__(self, channels, reduction):
+        super(SELayer, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2D(1)
+        weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.XavierUniform())
+        self.fc1 = Conv2D(
+            channels,
+            channels // reduction,
+            kernel_size=1,
+            padding=0,
+            weight_attr=weight_attr,
+            bias_attr=False)
+
+        self.relu = ReLU()
+        self.fc2 = Conv2D(
+            channels // reduction,
+            channels,
+            kernel_size=1,
+            padding=0,
+            weight_attr=nn.initializer.KaimingNormal(),
+            bias_attr=False)
+
+        self.sigmoid = Sigmoid()
+
+    def forward(self, x):
+        module_input = x
+        x = self.avg_pool(x)
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.sigmoid(x)
+
+        return module_input * x
+
+
+class BasicBlockIR(Layer):
+    """ BasicBlock for IRNet
+    """
+
+    def __init__(self, in_channel, depth, stride):
+        super(BasicBlockIR, self).__init__()
+
+        weight_attr = paddle.ParamAttr(
+            regularizer=None,
initializer=nn.initializer.Constant(value=1.0)) + bias_attr = paddle.ParamAttr( + regularizer=None, initializer=nn.initializer.Constant(value=0.0)) + if in_channel == depth: + self.shortcut_layer = MaxPool2D(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2D( + in_channel, + depth, (1, 1), + stride, + weight_attr=nn.initializer.KaimingNormal(), + bias_attr=False), + BatchNorm2D( + depth, weight_attr=weight_attr, bias_attr=bias_attr)) + self.res_layer = Sequential( + BatchNorm2D( + in_channel, weight_attr=weight_attr, bias_attr=bias_attr), + Conv2D( + in_channel, + depth, (3, 3), (1, 1), + 1, + weight_attr=nn.initializer.KaimingNormal(), + bias_attr=False), + BatchNorm2D( + depth, weight_attr=weight_attr, bias_attr=bias_attr), + PReLU(depth), + Conv2D( + depth, + depth, (3, 3), + stride, + 1, + weight_attr=nn.initializer.KaimingNormal(), + bias_attr=False), + BatchNorm2D( + depth, weight_attr=weight_attr, bias_attr=bias_attr)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + + return res + shortcut + + +class BottleneckIR(Layer): + """ BasicBlock with bottleneck for IRNet + """ + + def __init__(self, in_channel, depth, stride): + super(BottleneckIR, self).__init__() + reduction_channel = depth // 4 + weight_attr = paddle.ParamAttr( + regularizer=None, initializer=nn.initializer.Constant(value=1.0)) + bias_attr = paddle.ParamAttr( + regularizer=None, initializer=nn.initializer.Constant(value=0.0)) + if in_channel == depth: + self.shortcut_layer = MaxPool2D(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2D( + in_channel, + depth, (1, 1), + stride, + weight_attr=nn.initializer.KaimingNormal(), + bias_attr=False), + BatchNorm2D( + depth, weight_attr=weight_attr, bias_attr=bias_attr)) + self.res_layer = Sequential( + BatchNorm2D( + in_channel, weight_attr=weight_attr, bias_attr=bias_attr), + Conv2D( + in_channel, + reduction_channel, (1, 1), (1, 1), + 0, + weight_attr=nn.initializer.KaimingNormal(), + bias_attr=False), + BatchNorm2D( + reduction_channel, + weight_attr=weight_attr, + bias_attr=bias_attr), + PReLU(reduction_channel), + Conv2D( + reduction_channel, + reduction_channel, (3, 3), (1, 1), + 1, + weight_attr=nn.initializer.KaimingNormal(), + bias_attr=False), + BatchNorm2D( + reduction_channel, + weight_attr=weight_attr, + bias_attr=bias_attr), + PReLU(reduction_channel), + Conv2D( + reduction_channel, + depth, (1, 1), + stride, + 0, + weight_attr=nn.initializer.KaimingNormal(), + bias_attr=False), + BatchNorm2D( + depth, weight_attr=weight_attr, bias_attr=bias_attr)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + + return res + shortcut + + +class BasicBlockIRSE(BasicBlockIR): + def __init__(self, in_channel, depth, stride): + super(BasicBlockIRSE, self).__init__(in_channel, depth, stride) + self.res_layer.add_sublayer("se_block", SELayer(depth, 16)) + + +class BottleneckIRSE(BottleneckIR): + def __init__(self, in_channel, depth, stride): + super(BottleneckIRSE, self).__init__(in_channel, depth, stride) + self.res_layer.add_sublayer("se_block", SELayer(depth, 16)) + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + '''A named tuple describing a ResNet block.''' + + +def get_block(in_channel, depth, num_units, stride=2): + + return [Bottleneck(in_channel, depth, stride)] +\ + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 18: + blocks = [ + get_block( + in_channel=64, depth=64, 
num_units=2),
+            get_block(in_channel=64, depth=128, num_units=2),
+            get_block(in_channel=128, depth=256, num_units=2),
+            get_block(in_channel=256, depth=512, num_units=2)
+        ]
+    elif num_layers == 34:
+        blocks = [
+            get_block(in_channel=64, depth=64, num_units=3),
+            get_block(in_channel=64, depth=128, num_units=4),
+            get_block(in_channel=128, depth=256, num_units=6),
+            get_block(in_channel=256, depth=512, num_units=3)
+        ]
+    elif num_layers == 50:
+        blocks = [
+            get_block(in_channel=64, depth=64, num_units=3),
+            get_block(in_channel=64, depth=128, num_units=4),
+            get_block(in_channel=128, depth=256, num_units=14),
+            get_block(in_channel=256, depth=512, num_units=3)
+        ]
+    elif num_layers == 100:
+        blocks = [
+            get_block(in_channel=64, depth=64, num_units=3),
+            get_block(in_channel=64, depth=128, num_units=13),
+            get_block(in_channel=128, depth=256, num_units=30),
+            get_block(in_channel=256, depth=512, num_units=3)
+        ]
+    elif num_layers == 152:
+        blocks = [
+            get_block(in_channel=64, depth=256, num_units=3),
+            get_block(in_channel=256, depth=512, num_units=8),
+            get_block(in_channel=512, depth=1024, num_units=36),
+            get_block(in_channel=1024, depth=2048, num_units=3)
+        ]
+    elif num_layers == 200:
+        blocks = [
+            get_block(in_channel=64, depth=256, num_units=3),
+            get_block(in_channel=256, depth=512, num_units=24),
+            get_block(in_channel=512, depth=1024, num_units=36),
+            get_block(in_channel=1024, depth=2048, num_units=3)
+        ]
+
+    return blocks
+
+
+class Backbone(Layer):
+    def __init__(self, input_size, num_layers, mode='ir'):
+        """ Args:
+            input_size: input_size of backbone
+            num_layers: num_layers of backbone
+            mode: support ir or ir_se
+        """
+        super(Backbone, self).__init__()
+        assert input_size[0] in [112, 224], \
+            "input_size should be [112, 112] or [224, 224]"
+        assert num_layers in [18, 34, 50, 100, 152, 200], \
+            "num_layers should be 18, 34, 50, 100, 152 or 200"
+        assert mode in ['ir', 'ir_se'], \
+            "mode should be ir or ir_se"
+        weight_attr = paddle.ParamAttr(
+            regularizer=None, initializer=nn.initializer.Constant(value=1.0))
+        bias_attr = paddle.ParamAttr(
+            regularizer=None, initializer=nn.initializer.Constant(value=0.0))
+        self.input_layer = Sequential(
+            Conv2D(
+                3,
+                64, (3, 3),
+                1,
+                1,
+                weight_attr=nn.initializer.KaimingNormal(),
+                bias_attr=False),
+            BatchNorm2D(
+                64, weight_attr=weight_attr, bias_attr=bias_attr),
+            PReLU(64))
+        blocks = get_blocks(num_layers)
+        if num_layers <= 100:
+            if mode == 'ir':
+                unit_module = BasicBlockIR
+            elif mode == 'ir_se':
+                unit_module = BasicBlockIRSE
+            output_channel = 512
+        else:
+            if mode == 'ir':
+                unit_module = BottleneckIR
+            elif mode == 'ir_se':
+                unit_module = BottleneckIRSE
+            output_channel = 2048
+
+        if input_size[0] == 112:
+            self.output_layer = Sequential(
+                BatchNorm2D(
+                    output_channel,
+                    weight_attr=weight_attr,
+                    bias_attr=bias_attr),
+                Dropout(0.4),
+                Flatten(),
+                Linear(
+                    output_channel * 7 * 7,
+                    512,
+                    weight_attr=nn.initializer.KaimingNormal()),
+                BatchNorm1D(
+                    512, weight_attr=False, bias_attr=False))
+        else:
+            self.output_layer = Sequential(
+                BatchNorm2D(
+                    output_channel,
+                    weight_attr=weight_attr,
+                    bias_attr=bias_attr),
+                Dropout(0.4),
+                Flatten(),
+                Linear(
+                    output_channel * 14 * 14,
+                    512,
+                    weight_attr=nn.initializer.KaimingNormal()),
+                BatchNorm1D(
+                    512, weight_attr=False, bias_attr=False))
+
+        modules = []
+        for block in blocks:
+            for bottleneck in block:
+                modules.append(
+                    unit_module(bottleneck.in_channel, bottleneck.depth,
+                                bottleneck.stride))
+        self.body = Sequential(*modules)
+
+        # initialize_weights(self.modules())
+
+    def forward(self, x):
+        x = self.input_layer(x)
+
+        for idx, module in enumerate(self.body):
+            x = module(x)
+
+        x = self.output_layer(x)
+        # norm = paddle.norm(x, 2, 1, True)
+        # output = paddle.divide(x, norm)
+        # return output, norm
+        return x
+
+
+def AdaFace_IR_18(input_size=(112, 112)):
+    """ Constructs an ir-18 model.
+    """
+    model = Backbone(input_size, 18, 'ir')
+    return model
+
+
+def AdaFace_IR_34(input_size=(112, 112)):
+    """ Constructs an ir-34 model.
+    """
+    model = Backbone(input_size, 34, 'ir')
+    return model
+
+
+def AdaFace_IR_50(input_size=(112, 112)):
+    """ Constructs an ir-50 model.
+    """
+    model = Backbone(input_size, 50, 'ir')
+    return model
+
+
+def AdaFace_IR_101(input_size=(112, 112)):
+    """ Constructs an ir-101 model.
+    """
+    model = Backbone(input_size, 100, 'ir')
+    return model
+
+
+def AdaFace_IR_152(input_size=(112, 112)):
+    """ Constructs an ir-152 model.
+    """
+    model = Backbone(input_size, 152, 'ir')
+    return model
+
+
+def AdaFace_IR_200(input_size=(112, 112)):
+    """ Constructs an ir-200 model.
+    """
+    model = Backbone(input_size, 200, 'ir')
+    return model
+
+
+def AdaFace_IR_SE_50(input_size=(112, 112)):
+    """ Constructs an ir_se-50 model.
+    """
+    model = Backbone(input_size, 50, 'ir_se')
+    return model
+
+
+def AdaFace_IR_SE_101(input_size=(112, 112)):
+    """ Constructs an ir_se-101 model.
+    """
+    model = Backbone(input_size, 100, 'ir_se')
+    return model
+
+
+def AdaFace_IR_SE_152(input_size=(112, 112)):
+    """ Constructs an ir_se-152 model.
+    """
+    model = Backbone(input_size, 152, 'ir_se')
+    return model
+
+
+def AdaFace_IR_SE_200(input_size=(112, 112)):
+    """ Constructs an ir_se-200 model.
+    """
+    model = Backbone(input_size, 200, 'ir_se')
+    return model
diff --git a/ppcls/arch/gears/__init__.py b/ppcls/arch/gears/__init__.py
index 8757aa4aeb4a510857ca4dc1c60696b1d6e86a0b..871967804e21c362935915942aa3f621207b934e 100644
--- a/ppcls/arch/gears/__init__.py
+++ b/ppcls/arch/gears/__init__.py
@@ -19,6 +19,7 @@ from .fc import FC
 from .vehicle_neck import VehicleNeck
 from paddle.nn import Tanh
 from .bnneck import BNNeck
+from .adamargin import AdaMargin
 
 __all__ = ['build_gear']
 
@@ -26,7 +27,7 @@ __all__ = ['build_gear']
 def build_gear(config):
     support_dict = [
         'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh',
-        'BNNeck'
+        'BNNeck', 'AdaMargin'
     ]
     module_name = config.pop('name')
     assert module_name in support_dict, Exception(
diff --git a/ppcls/arch/gears/adamargin.py b/ppcls/arch/gears/adamargin.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b0f5f245dbbe2c282f726b7d5be3634d6df912c
--- /dev/null
+++ b/ppcls/arch/gears/adamargin.py
@@ -0,0 +1,111 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
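Not part of the diff: a quick smoke test for the new backbones. The 112x112 input and 512-d embedding follow directly from `Backbone.output_layer` above; the batch size here is arbitrary.

```python
import paddle
from ppcls.arch.backbone.model_zoo.adaface_ir_net import AdaFace_IR_18

model = AdaFace_IR_18(input_size=(112, 112))
model.eval()
x = paddle.randn([2, 3, 112, 112])  # NCHW face crops
with paddle.no_grad():
    feat = model(x)
print(feat.shape)  # expected: [2, 512]
```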
+ +# This code is based on AdaFace(https://github.com/mk-minchul/AdaFace) +# Paper: AdaFace: Quality Adaptive Margin for Face Recognition +from paddle.nn import Layer +import math +import paddle + + +def l2_norm(input, axis=1): + norm = paddle.norm(input, 2, axis, True) + output = paddle.divide(input, norm) + return output + + +class AdaMargin(Layer): + def __init__( + self, + embedding_size=512, + class_num=70722, + m=0.4, + h=0.333, + s=64., + t_alpha=1.0, ): + super(AdaMargin, self).__init__() + self.classnum = class_num + kernel_weight = paddle.uniform( + [embedding_size, class_num], min=-1, max=1) + kernel_weight_norm = paddle.norm( + kernel_weight, p=2, axis=0, keepdim=True) + kernel_weight_norm = paddle.where(kernel_weight_norm > 1e-5, + kernel_weight_norm, + paddle.ones_like(kernel_weight_norm)) + kernel_weight = kernel_weight / kernel_weight_norm + self.kernel = self.create_parameter( + [embedding_size, class_num], + attr=paddle.nn.initializer.Assign(kernel_weight)) + + # initial kernel + # self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5) + self.m = m + self.eps = 1e-3 + self.h = h + self.s = s + + # ema prep + self.t_alpha = t_alpha + self.register_buffer('t', paddle.zeros([1]), persistable=True) + self.register_buffer( + 'batch_mean', paddle.ones([1]) * 20, persistable=True) + self.register_buffer( + 'batch_std', paddle.ones([1]) * 100, persistable=True) + + def forward(self, embbedings, label): + + norms = paddle.norm(embbedings, 2, 1, True) + embbedings = paddle.divide(embbedings, norms) + kernel_norm = l2_norm(self.kernel, axis=0) + cosine = paddle.mm(embbedings, kernel_norm) + cosine = paddle.clip(cosine, -1 + self.eps, + 1 - self.eps) # for stability + + safe_norms = paddle.clip(norms, min=0.001, max=100) # for stability + safe_norms = safe_norms.clone().detach() + + # update batchmean batchstd + with paddle.no_grad(): + mean = safe_norms.mean().detach() + std = safe_norms.std().detach() + self.batch_mean = mean * self.t_alpha + (1 - self.t_alpha + ) * self.batch_mean + self.batch_std = std * self.t_alpha + (1 - self.t_alpha + ) * self.batch_std + + margin_scaler = (safe_norms - self.batch_mean) / ( + self.batch_std + self.eps) # 66% between -1, 1 + margin_scaler = margin_scaler * self.h # 68% between -0.333 ,0.333 when h:0.333 + margin_scaler = paddle.clip(margin_scaler, -1, 1) + + # g_angular + m_arc = paddle.nn.functional.one_hot( + label.reshape([-1]), self.classnum) + g_angular = self.m * margin_scaler * -1 + m_arc = m_arc * g_angular + theta = paddle.acos(cosine) + theta_m = paddle.clip( + theta + m_arc, min=self.eps, max=math.pi - self.eps) + cosine = paddle.cos(theta_m) + + # g_additive + m_cos = paddle.nn.functional.one_hot( + label.reshape([-1]), self.classnum) + g_add = self.m + (self.m * margin_scaler) + m_cos = m_cos * g_add + cosine = cosine - m_cos + + # scale + scaled_cosine_m = cosine * self.s + return scaled_cosine_m diff --git a/ppcls/configs/Attr/StrongBaselineAttr.yaml b/ppcls/configs/Attr/StrongBaselineAttr.yaml index 7501669bc5707fa2577c7d0b573a3b23cd2a0213..2324015d667a09a56570677713792b16f1b2ed03 100644 --- a/ppcls/configs/Attr/StrongBaselineAttr.yaml +++ b/ppcls/configs/Attr/StrongBaselineAttr.yaml @@ -20,6 +20,7 @@ Arch: name: "ResNet50" pretrained: True class_num: 26 + infer_add_softmax: False # loss function config for traing/eval process Loss: @@ -110,5 +111,3 @@ DataLoader: Metric: Eval: - ATTRMetric: - - diff --git a/ppcls/configs/metric_learning/adaface_ir18.yaml b/ppcls/configs/metric_learning/adaface_ir18.yaml new file mode 
100644
index 0000000000000000000000000000000000000000..2cbfe5da43763701b244b2422bf9ad82b19ef4d6
--- /dev/null
+++ b/ppcls/configs/metric_learning/adaface_ir18.yaml
@@ -0,0 +1,105 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: "./output/"
+  device: "gpu"
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 26
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 112, 112]
+  save_inference_dir: "./inference"
+  eval_mode: "adaface"
+
+# model architecture
+Arch:
+  name: "RecModel"
+  infer_output_key: "features"
+  infer_add_softmax: False
+  Backbone:
+    name: "AdaFace_IR_18"
+    input_size: [112, 112]
+  Head:
+    name: "AdaMargin"
+    embedding_size: 512
+    class_num: 70722
+    m: 0.4
+    s: 64
+    h: 0.333
+    t_alpha: 0.01
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  lr:
+    name: Piecewise
+    learning_rate: 0.1
+    decay_epochs: [12, 20, 24]
+    values: [0.1, 0.01, 0.001, 0.0001]
+  regularizer:
+    name: 'L2'
+    coeff: 0.0005
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: "AdaFaceDataset"
+      root_dir: "dataset/face/"
+      label_path: "dataset/face/train_filter_label.txt"
+      transform:
+        - CropWithPadding:
+            prob: 0.2
+            padding_num: 0
+            size: [112, 112]
+            scale: [0.2, 1.0]
+            ratio: [0.75, 1.3333333333333333]
+        - RandomInterpolationAugment:
+            prob: 0.2
+        - ColorJitter:
+            prob: 0.2
+            brightness: 0.5
+            contrast: 0.5
+            saturation: 0.5
+            hue: 0
+        - RandomHorizontalFlip:
+        - ToTensor:
+        - Normalize:
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 256
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 6
+      use_shared_memory: True
+
+  Eval:
+    dataset:
+      name: FiveValidationDataset
+      val_data_path: dataset/face/faces_emore
+      concat_mem_file_name: dataset/face/faces_emore/concat_validation_memfile
+    sampler:
+      name: BatchSampler
+      batch_size: 256
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 6
+      use_shared_memory: True
+Metric:
+  Train:
+    - TopkAcc:
+        topk: [1, 5]
\ No newline at end of file
diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py
index 9fc4d760be545ffa93652c80d285e17ad0c8ae57..80cf3bc9af826e935fe0fe6ccf8cad8d6924d370 100644
--- a/ppcls/data/__init__.py
+++ b/ppcls/data/__init__.py
@@ -30,6 +30,7 @@ from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
 from ppcls.data.dataloader.mix_dataset import MixDataset
 from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
 from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
+from ppcls.data.dataloader.face_dataset import FiveValidationDataset, AdaFaceDataset
 
 # sampler
 
@@ -88,7 +89,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
 
     # build sampler
     config_sampler = config[mode]['sampler']
-    if "name" not in config_sampler:
+    if config_sampler and "name" not in config_sampler:
         batch_sampler = None
         batch_size = config_sampler["batch_size"]
         drop_last = config_sampler["drop_last"]
diff --git a/ppcls/data/dataloader/__init__.py b/ppcls/data/dataloader/__init__.py
index 2b1d92b76bd202e36086f21a3a092c3673277690..796f4b458410e5b4b8540b72dd663711c4ad9f46 100644
--- a/ppcls/data/dataloader/__init__.py
+++ b/ppcls/data/dataloader/__init__.py
@@ -10,3 +10,4 @@ from ppcls.data.dataloader.mix_sampler import MixSampler
 from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
 from 
ppcls.data.dataloader.pk_sampler import PKSampler from ppcls.data.dataloader.person_dataset import Market1501, MSMT17 +from ppcls.data.dataloader.face_dataset import AdaFaceDataset, FiveValidationDataset diff --git a/ppcls/data/dataloader/face_dataset.py b/ppcls/data/dataloader/face_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a32cc2c5f89aa8c8e4904e7decc6ec5fb996aab3 --- /dev/null +++ b/ppcls/data/dataloader/face_dataset.py @@ -0,0 +1,163 @@ +import os +import json +import numpy as np +from PIL import Image +import cv2 +import paddle +import paddle.vision.datasets as datasets +from paddle.vision import transforms +from paddle.vision.transforms import functional as F +from paddle.io import Dataset +from .common_dataset import create_operators +from ppcls.data.preprocess import transform as transform_func + +# code is based on AdaFace: https://github.com/mk-minchul/AdaFace + + +class AdaFaceDataset(Dataset): + def __init__(self, root_dir, label_path, transform=None): + self.root_dir = root_dir + self.transform = create_operators(transform) + + with open(label_path) as fd: + lines = fd.readlines() + self.samples = [] + for l in lines: + l = l.strip().split() + self.samples.append([os.path.join(root_dir, l[0]), int(l[1])]) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (sample, target) where target is class_index of the target class. + """ + [path, target] = self.samples[index] + with open(path, 'rb') as f: + img = Image.open(f) + sample = img.convert('RGB') + + # if 'WebFace' in self.root: + # # swap rgb to bgr since image is in rgb for webface + # sample = Image.fromarray(np.asarray(sample)[:, :, ::-1] + if self.transform is not None: + sample = transform_func(sample, self.transform) + return sample, target + + +class FiveValidationDataset(Dataset): + def __init__(self, val_data_path, concat_mem_file_name): + ''' + concatenates all validation datasets from emore + val_data_dict = { + 'agedb_30': (agedb_30, agedb_30_issame), + "cfp_fp": (cfp_fp, cfp_fp_issame), + "lfw": (lfw, lfw_issame), + "cplfw": (cplfw, cplfw_issame), + "calfw": (calfw, calfw_issame), + } + agedb_30: 0 + cfp_fp: 1 + lfw: 2 + cplfw: 3 + calfw: 4 + ''' + val_data = get_val_data(val_data_path) + age_30, cfp_fp, lfw, age_30_issame, cfp_fp_issame, lfw_issame, cplfw, cplfw_issame, calfw, calfw_issame = val_data + val_data_dict = { + 'agedb_30': (age_30, age_30_issame), + "cfp_fp": (cfp_fp, cfp_fp_issame), + "lfw": (lfw, lfw_issame), + "cplfw": (cplfw, cplfw_issame), + "calfw": (calfw, calfw_issame), + } + self.dataname_to_idx = { + "agedb_30": 0, + "cfp_fp": 1, + "lfw": 2, + "cplfw": 3, + "calfw": 4 + } + + self.val_data_dict = val_data_dict + # concat all dataset + all_imgs = [] + all_issame = [] + all_dataname = [] + key_orders = [] + for key, (imgs, issame) in val_data_dict.items(): + all_imgs.append(imgs) + dup_issame = [ + ] # hacky way to make the issame length same as imgs. [1, 1, 0, 0, ...] 
+        for same in issame:
+            dup_issame.append(same)
+            dup_issame.append(same)
+        all_issame.append(dup_issame)
+        all_dataname.append([self.dataname_to_idx[key]] * len(imgs))
+        key_orders.append(key)
+    assert key_orders == ['agedb_30', 'cfp_fp', 'lfw', 'cplfw', 'calfw']
+
+        if isinstance(all_imgs[0], np.memmap):
+            self.all_imgs = read_memmap(concat_mem_file_name)
+        else:
+            self.all_imgs = np.concatenate(all_imgs)
+
+        self.all_issame = np.concatenate(all_issame)
+        self.all_dataname = np.concatenate(all_dataname)
+
+    def __getitem__(self, index):
+        x_np = self.all_imgs[index].copy()
+        x = paddle.to_tensor(x_np)
+        y = self.all_issame[index]
+        dataname = self.all_dataname[index]
+        return x, y, dataname, index
+
+    def __len__(self):
+        return len(self.all_imgs)
+
+
+def read_memmap(mem_file_name):
+    # r+ mode: Open existing file for reading and writing
+    with open(mem_file_name + '.conf', 'r') as file:
+        memmap_configs = json.load(file)
+        return np.memmap(mem_file_name, mode='r+', \
+                         shape=tuple(memmap_configs['shape']), \
+                         dtype=memmap_configs['dtype'])
+
+
+def make_memmap(mem_file_name, np_to_copy):
+    # restored from the upstream AdaFace implementation: write the memfile and
+    # its shape/dtype config once, so read_memmap can find it on later runs
+    memmap_configs = {
+        'shape': tuple(np_to_copy.shape),
+        'dtype': str(np_to_copy.dtype)
+    }
+    with open(mem_file_name + '.conf', 'w') as file:
+        json.dump(memmap_configs, file)
+    mm = np.memmap(
+        mem_file_name,
+        mode='w+',
+        shape=memmap_configs['shape'],
+        dtype=np_to_copy.dtype)
+    mm[:] = np_to_copy[:]
+    mm.flush()
+    return mm
+
+
+def get_val_pair(path, name, use_memfile=True):
+    # installing bcolz should set proxy to access internet
+    import bcolz
+    if use_memfile:
+        mem_file_dir = os.path.join(path, name, 'memfile')
+        mem_file_name = os.path.join(mem_file_dir, 'mem_file.dat')
+        if os.path.isdir(mem_file_dir):
+            print('loading validation data memfile')
+            np_array = read_memmap(mem_file_name)
+        else:
+            os.makedirs(mem_file_dir)
+            carray = bcolz.carray(rootdir=os.path.join(path, name), mode='r')
+            np_array = np.array(carray)
+            # cache the array to disk once, then reopen it as a memmap
+            mem_array = make_memmap(mem_file_name, np_array)
+            del np_array, mem_array
+            np_array = read_memmap(mem_file_name)
+    else:
+        np_array = bcolz.carray(rootdir=os.path.join(path, name), mode='r')
+
+    issame = np.load(os.path.join(path, '{}_list.npy'.format(name)))
+    return np_array, issame
+
+
+def get_val_data(data_path):
+    agedb_30, agedb_30_issame = get_val_pair(data_path, 'agedb_30')
+    cfp_fp, cfp_fp_issame = get_val_pair(data_path, 'cfp_fp')
+    lfw, lfw_issame = get_val_pair(data_path, 'lfw')
+    cplfw, cplfw_issame = get_val_pair(data_path, 'cplfw')
+    calfw, calfw_issame = get_val_pair(data_path, 'calfw')
+    return agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame, cplfw, cplfw_issame, calfw, calfw_issame
diff --git a/ppcls/data/preprocess/__init__.py b/ppcls/data/preprocess/__init__.py
index 6822e2081f26ff033239b31edf8d5bdeffe85ce0..d0cfcf2409d2d890adcf03ef0e03b2475625ead8 100644
--- a/ppcls/data/preprocess/__init__.py
+++ b/ppcls/data/preprocess/__init__.py
@@ -33,6 +33,10 @@ from ppcls.data.preprocess.ops.operators import AugMix
 from ppcls.data.preprocess.ops.operators import Pad
 from ppcls.data.preprocess.ops.operators import ToTensor
 from ppcls.data.preprocess.ops.operators import Normalize
+from ppcls.data.preprocess.ops.operators import RandomHorizontalFlip
+from ppcls.data.preprocess.ops.operators import CropWithPadding
+from ppcls.data.preprocess.ops.operators import RandomInterpolationAugment
+from ppcls.data.preprocess.ops.operators import ColorJitter
 from ppcls.data.preprocess.ops.operators import RandomCropImage
 from ppcls.data.preprocess.ops.operators import Padv2
 
diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py
index e5732d3925a6ea452c028c057b56bf9b335aee90..d31ec4b8c4f40dcaa4d53b864996725c7138a393 100644
--- a/ppcls/data/preprocess/ops/operators.py
+++ b/ppcls/data/preprocess/ops/operators.py
@@ -25,8 +25,8 @@
 import cv2
 import numpy as np
 from PIL import Image, ImageOps, __version__ as PILLOW_VERSION
 from paddle.vision.transforms import ColorJitter as RawColorJitter
-from paddle.vision.transforms import ToTensor, Normalize
-
+from paddle.vision.transforms import ToTensor, Normalize, RandomHorizontalFlip, RandomResizedCrop
+from paddle.vision.transforms import functional as F
 from .autoaugment import ImageNetPolicy
 from .functional import augmentations
 from ppcls.utils import logger
@@ -93,6 +93,42 @@ class UnifiedResize(object):
         return self.resize_func(src, size)
 
 
+class RandomInterpolationAugment(object):
+    def __init__(self, prob):
+        self.prob = prob
+
+    def _aug(self, img):
+        img_shape = img.shape
+        side_ratio = np.random.uniform(0.2, 1.0)
+        small_side = int(side_ratio * img_shape[0])
+        interpolation = np.random.choice([
+            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA,
+            cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        ])
+        small_img = cv2.resize(
+            img, (small_side, small_side), interpolation=interpolation)
+        interpolation = np.random.choice([
+            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA,
+            cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        ])
+        aug_img = cv2.resize(
+            small_img, (img_shape[1], img_shape[0]),
+            interpolation=interpolation)
+        return aug_img
+
+    def __call__(self, img):
+        if np.random.random() < self.prob:
+            if isinstance(img, np.ndarray):
+                return self._aug(img)
+            else:
+                pil_img = np.array(img)
+                aug_img = self._aug(pil_img)
+                img = Image.fromarray(aug_img.astype(np.uint8))
+                return img
+        else:
+            return img
+
+
 class OperatorParamError(ValueError):
     """ OperatorParamError
     """
@@ -170,6 +206,52 @@ class ResizeImage(object):
         return self._resize_func(img, (w, h))
 
 
+class CropWithPadding(RandomResizedCrop):
+    """
+    crop image and padding to original size
+    """
+
+    def __init__(self,
+                 prob=1,
+                 padding_num=0,
+                 size=224,
+                 scale=(0.08, 1.0),
+                 ratio=(3. / 4, 4. / 3),
+                 interpolation='bilinear',
+                 key=None):
+        super().__init__(size, scale, ratio, interpolation, key)
+        self.prob = prob
+        self.padding_num = padding_num
+
+    def __call__(self, img):
+        is_cv2_img = False
+        if isinstance(img, np.ndarray):
+            is_cv2_img = True
+        if np.random.random() < self.prob:
+            # RandomResizedCrop augmentation: paste the crop back into a
+            # constant canvas of the original size
+            new = np.zeros_like(np.array(img)) + self.padding_num
+            i, j, h, w = self._get_param(img)
+            cropped = F.crop(img, i, j, h, w)
+            new[i:i + h, j:j + w, :] = np.array(cropped)
+            if not is_cv2_img:
+                new = Image.fromarray(new.astype(np.uint8))
+            return new
+        else:
+            return img
+
+    def _get_image_size(self, img):
+        if F._is_pil_image(img):
+            return img.size
+        elif F._is_numpy_image(img):
+            return img.shape[:2][::-1]
+        elif F._is_tensor_image(img):
+            return img.shape[1:][::-1]  # chw
+        else:
+            raise TypeError("Unexpected type {}".format(type(img)))
+
+
 class CropImage(object):
     """ crop image """
@@ -533,16 +615,18 @@ class ColorJitter(RawColorJitter):
     """ColorJitter.
""" - def __init__(self, *args, **kwargs): + def __init__(self, prob=2, *args, **kwargs): super().__init__(*args, **kwargs) + self.prob = prob def __call__(self, img): - if not isinstance(img, Image.Image): - img = np.ascontiguousarray(img) - img = Image.fromarray(img) - img = super()._apply_image(img) - if isinstance(img, Image.Image): - img = np.asarray(img) + if np.random.random() < self.prob: + if not isinstance(img, Image.Image): + img = np.ascontiguousarray(img) + img = Image.fromarray(img) + img = super()._apply_image(img) + if isinstance(img, Image.Image): + img = np.asarray(img) return img diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index ef24094c2b5214b1aa5811e3c1b28f33a6452c67..2c0ab83f4d4a875901b6655e9ccf91af1737cc73 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -75,8 +75,9 @@ class Engine(object): print_config(config) # init train_func and eval_func - assert self.eval_mode in ["classification", "retrieval"], logger.error( - "Invalid eval mode: {}".format(self.eval_mode)) + assert self.eval_mode in [ + "classification", "retrieval", "adaface" + ], logger.error("Invalid eval mode: {}".format(self.eval_mode)) self.train_epoch_func = train_epoch self.eval_func = getattr(evaluation, self.eval_mode + "_eval") @@ -115,7 +116,7 @@ class Engine(object): self.config["DataLoader"], "Train", self.device, self.use_dali) if self.mode == "eval" or (self.mode == "train" and self.config["Global"]["eval_during_train"]): - if self.eval_mode == "classification": + if self.eval_mode in ["classification", "adaface"]: self.eval_dataloader = build_dataloader( self.config["DataLoader"], "Eval", self.device, self.use_dali) @@ -457,7 +458,9 @@ class Engine(object): def export(self): assert self.mode == "export" - use_multilabel = self.config["Global"].get("use_multilabel", False) + use_multilabel = self.config["Global"].get( + "use_multilabel", + False) and not "ATTRMetric" in self.config["Metric"]["Eval"][0] model = ExportModel(self.config["Arch"], self.model, use_multilabel) if self.config["Global"]["pretrained_model"] is not None: load_dygraph_pretrain(model.base_model, diff --git a/ppcls/engine/evaluation/__init__.py b/ppcls/engine/evaluation/__init__.py index e0cd778887bf6f0e7ce05c18b587e5b54bcf6b3f..a301ad7fda34b87a959b59251b6dd0fffe9eb3e9 100644 --- a/ppcls/engine/evaluation/__init__.py +++ b/ppcls/engine/evaluation/__init__.py @@ -14,3 +14,4 @@ from ppcls.engine.evaluation.classification import classification_eval from ppcls.engine.evaluation.retrieval import retrieval_eval +from ppcls.engine.evaluation.adaface import adaface_eval \ No newline at end of file diff --git a/ppcls/engine/evaluation/adaface.py b/ppcls/engine/evaluation/adaface.py new file mode 100644 index 0000000000000000000000000000000000000000..e62144b5cb374a14a93616c33e56ee74bef0eb01 --- /dev/null +++ b/ppcls/engine/evaluation/adaface.py @@ -0,0 +1,260 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import time
+import numpy as np
+import platform
+import paddle
+import sklearn
+from sklearn.model_selection import KFold
+from sklearn.decomposition import PCA
+
+from ppcls.utils.misc import AverageMeter
+from ppcls.utils import logger
+
+
+def fuse_features_with_norm(stacked_embeddings, stacked_norms):
+    assert stacked_embeddings.ndim == 3  # (n_features_to_fuse, batch_size, channel)
+    assert stacked_norms.ndim == 3  # (n_features_to_fuse, batch_size, 1)
+    pre_norm_embeddings = stacked_embeddings * stacked_norms
+    fused = pre_norm_embeddings.sum(axis=0)
+    norm = paddle.norm(fused, 2, 1, True)
+    fused = paddle.divide(fused, norm)
+    return fused, norm
+
+
+def adaface_eval(engine, epoch_id=0):
+    output_info = dict()
+    time_info = {
+        "batch_cost": AverageMeter(
+            "batch_cost", '.5f', postfix=" s,"),
+        "reader_cost": AverageMeter(
+            "reader_cost", ".5f", postfix=" s,"),
+    }
+    print_batch_step = engine.config["Global"]["print_batch_step"]
+
+    metric_key = None
+    tic = time.time()
+    unique_dict = {}
+    for iter_id, batch in enumerate(engine.eval_dataloader):
+        images, labels, dataname, image_index = batch
+        if iter_id == 5:
+            for key in time_info:
+                time_info[key].reset()
+        time_info["reader_cost"].update(time.time() - tic)
+        batch_size = images.shape[0]
+        batch[0] = paddle.to_tensor(images)
+        embeddings = engine.model(images, labels)['features']
+        # norms must be the per-row L2 norms so that fuse_features_with_norm
+        # can undo the normalization when fusing flipped and unflipped features
+        norms = paddle.norm(embeddings, 2, 1, True)
+        embeddings = paddle.divide(embeddings, norms)
+        flipped_images = paddle.flip(images, axis=[3])
+        flipped_embeddings = engine.model(flipped_images, labels)['features']
+        flipped_norms = paddle.norm(flipped_embeddings, 2, 1, True)
+        flipped_embeddings = paddle.divide(flipped_embeddings, flipped_norms)
+        stacked_embeddings = paddle.stack(
+            [embeddings, flipped_embeddings], axis=0)
+        stacked_norms = paddle.stack([norms, flipped_norms], axis=0)
+        embeddings, norms = fuse_features_with_norm(stacked_embeddings,
+                                                    stacked_norms)
+
+        for out, nor, label, data, idx in zip(embeddings, norms, labels,
+                                              dataname, image_index):
+            unique_dict[int(idx.numpy())] = {
+                'output': out,
+                'norm': nor,
+                'target': label,
+                'dataname': data
+            }
+        # calc metric
+        time_info["batch_cost"].update(time.time() - tic)
+        if iter_id % print_batch_step == 0:
+            time_msg = "s, ".join([
+                "{}: {:.5f}".format(key, time_info[key].avg)
+                for key in time_info
+            ])
+
+            ips_msg = "ips: {:.5f} images/sec".format(
+                batch_size / time_info["batch_cost"].avg)
+
+            metric_msg = ", ".join([
+                "{}: {:.5f}".format(key, output_info[key].val)
+                for key in output_info
+            ])
+            logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
+                epoch_id, iter_id,
+                len(engine.eval_dataloader), metric_msg, time_msg, ips_msg))
+
+        tic = time.time()
+
+    unique_keys = sorted(unique_dict.keys())
+    all_output_tensor = paddle.stack(
+        [unique_dict[key]['output'] for key in unique_keys], axis=0)
+    all_norm_tensor = paddle.stack(
+        [unique_dict[key]['norm'] for key in unique_keys], axis=0)
+    all_target_tensor = paddle.stack(
+        [unique_dict[key]['target'] for key in unique_keys], axis=0)
+    all_dataname_tensor = paddle.stack(
+        [unique_dict[key]['dataname'] for key in unique_keys], axis=0)
+
+    eval_result = cal_metric(all_output_tensor, all_norm_tensor,
+                             all_target_tensor, all_dataname_tensor)
+
+    metric_msg = ", ".join([
+        "{}: {:.5f}".format(key, output_info[key].avg) for key in output_info
+    ])
+    
face_msg = ", ".join([ + "{}: {:.5f}".format(key, eval_result[key]) + for key in eval_result.keys() + ]) + logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg + ", " + + face_msg)) + + # return 1st metric in the dict + return eval_result['all_test_acc'] + + +def cal_metric(all_output_tensor, all_norm_tensor, all_target_tensor, + all_dataname_tensor): + all_target_tensor = all_target_tensor.reshape([-1]) + all_dataname_tensor = all_dataname_tensor.reshape([-1]) + dataname_to_idx = { + "agedb_30": 0, + "cfp_fp": 1, + "lfw": 2, + "cplfw": 3, + "calfw": 4 + } + idx_to_dataname = {val: key for key, val in dataname_to_idx.items()} + test_logs = {} + # _, indices = paddle.unique(all_dataname_tensor, return_index=True, return_inverse=False, return_counts=False) + for dataname_idx in all_dataname_tensor.unique(): + dataname = idx_to_dataname[dataname_idx.item()] + # per dataset evaluation + embeddings = all_output_tensor[all_dataname_tensor == + dataname_idx].numpy() + labels = all_target_tensor[all_dataname_tensor == dataname_idx].numpy() + issame = labels[0::2] + tpr, fpr, accuracy, best_thresholds = evaluate_face( + embeddings, issame, nrof_folds=10) + acc, best_threshold = accuracy.mean(), best_thresholds.mean() + + num_test_samples = len(embeddings) + test_logs[f'{dataname}_test_acc'] = acc + test_logs[f'{dataname}_test_best_threshold'] = best_threshold + test_logs[f'{dataname}_num_test_samples'] = num_test_samples + + test_acc = np.mean([ + test_logs[f'{dataname}_test_acc'] + for dataname in dataname_to_idx.keys() + if f'{dataname}_test_acc' in test_logs + ]) + + test_logs['all_test_acc'] = test_acc + return test_logs + + +def evaluate_face(embeddings, actual_issame, nrof_folds=10, pca=0): + # Calculate evaluation metrics + thresholds = np.arange(0, 4, 0.01) + embeddings1 = embeddings[0::2] + embeddings2 = embeddings[1::2] + tpr, fpr, accuracy, best_thresholds = calculate_roc( + thresholds, + embeddings1, + embeddings2, + np.asarray(actual_issame), + nrof_folds=nrof_folds, + pca=pca) + return tpr, fpr, accuracy, best_thresholds + + +def calculate_roc(thresholds, + embeddings1, + embeddings2, + actual_issame, + nrof_folds=10, + pca=0): + assert (embeddings1.shape[0] == embeddings2.shape[0]) + assert (embeddings1.shape[1] == embeddings2.shape[1]) + nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) + nrof_thresholds = len(thresholds) + k_fold = KFold(n_splits=nrof_folds, shuffle=False) + + tprs = np.zeros((nrof_folds, nrof_thresholds)) + fprs = np.zeros((nrof_folds, nrof_thresholds)) + accuracy = np.zeros((nrof_folds)) + best_thresholds = np.zeros((nrof_folds)) + indices = np.arange(nrof_pairs) + # print('pca', pca) + dist = None + + if pca == 0: + diff = np.subtract(embeddings1, embeddings2) + dist = np.sum(np.square(diff), 1) + + for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): + # print('train_set', train_set) + # print('test_set', test_set) + if pca > 0: + print('doing pca on', fold_idx) + embed1_train = embeddings1[train_set] + embed2_train = embeddings2[train_set] + _embed_train = np.concatenate((embed1_train, embed2_train), axis=0) + # print(_embed_train.shape) + pca_model = PCA(n_components=pca) + pca_model.fit(_embed_train) + embed1 = pca_model.transform(embeddings1) + embed2 = pca_model.transform(embeddings2) + embed1 = sklearn.preprocessing.normalize(embed1) + embed2 = sklearn.preprocessing.normalize(embed2) + # print(embed1.shape, embed2.shape) + diff = np.subtract(embed1, embed2) + dist = np.sum(np.square(diff), 1) + + # Find the best 
threshold for the fold + acc_train = np.zeros((nrof_thresholds)) + for threshold_idx, threshold in enumerate(thresholds): + _, _, acc_train[threshold_idx] = calculate_accuracy( + threshold, dist[train_set], actual_issame[train_set]) + best_threshold_index = np.argmax(acc_train) + best_thresholds[fold_idx] = thresholds[best_threshold_index] + for threshold_idx, threshold in enumerate(thresholds): + tprs[fold_idx, threshold_idx], fprs[ + fold_idx, threshold_idx], _ = calculate_accuracy( + threshold, dist[test_set], actual_issame[test_set]) + _, _, accuracy[fold_idx] = calculate_accuracy( + thresholds[best_threshold_index], dist[test_set], + actual_issame[test_set]) + + tpr = np.mean(tprs, 0) + fpr = np.mean(fprs, 0) + return tpr, fpr, accuracy, best_thresholds + + +def calculate_accuracy(threshold, dist, actual_issame): + predict_issame = np.less(dist, threshold) + tp = np.sum(np.logical_and(predict_issame, actual_issame)) + fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) + tn = np.sum( + np.logical_and( + np.logical_not(predict_issame), np.logical_not(actual_issame))) + fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame)) + + tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn) + fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn) + acc = float(tp + tn) / dist.size + return tpr, fpr, acc diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 4087cd4d4fd4eca0830d0ce253082dbbbbf16ec0..2161ca86ae51c1c1aa551dd08c1924adc3d9c59b 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -390,6 +390,7 @@ class AccuracyScore(MultiLabelMetric): def get_attr_metrics(gt_label, preds_probs, threshold): """ index: evaluated label index + adapted from "https://github.com/valencebond/Rethinking_of_PAR/blob/master/metrics/pedestrian_metrics.py" """ pred_label = (preds_probs > threshold).astype(int) diff --git a/requirements.txt b/requirements.txt index 5e927756a4a2341b91ca2e23065657bb09a4e514..4787aa84805e84c26a1030f773fbd89826e1aa56 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,9 +4,9 @@ opencv-python==4.4.0.46 pillow tqdm PyYAML -visualdl >= 2.2.0 +visualdl>=2.2.0 scipy -scikit-learn==0.23.2 +scikit-learn>=0.21.0 gast==0.3.3 faiss-cpu==1.7.1.post2 easydict
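As a last sanity check (not in the diff): `calculate_accuracy` thresholds squared L2 distances between consecutive embedding pairs (even index = first image of a pair, odd index = second, matching `evaluate_face` above). A toy run with two genuine and two impostor pairs, assuming the module path `ppcls.engine.evaluation.adaface`:

```python
import numpy as np

from ppcls.engine.evaluation.adaface import calculate_accuracy

emb1 = np.array([[1., 0.], [0., 1.], [1., 0.], [0., 1.]])  # first of each pair
emb2 = np.array([[1., 0.], [0., 1.], [0., 1.], [1., 0.]])  # second of each pair
issame = np.array([True, True, False, False])

dist = np.sum(np.square(emb1 - emb2), axis=1)  # [0., 0., 2., 2.]
tpr, fpr, acc = calculate_accuracy(1.0, dist, issame)
print(tpr, fpr, acc)  # 1.0 0.0 1.0 at threshold 1.0
```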