diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index 1923b09932fc30389d6c93abcf70c2bdef9b0f95..a1727613a6eef3f167d74f371bf4f25639c85b95 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -58,6 +58,7 @@ from ppcls.arch.backbone.model_zoo.rednet import RedNet26, RedNet38, RedNet50, R from ppcls.arch.backbone.model_zoo.tnt import TNT_small from ppcls.arch.backbone.model_zoo.hardnet import HarDNet68, HarDNet85, HarDNet39_ds, HarDNet68_ds from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1 +from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid def get_apis(): diff --git a/ppcls/arch/backbone/variant_models/__init__.py b/ppcls/arch/backbone/variant_models/__init__.py index e87a0e098cf19e4066c39c060d174512b64451a2..34fcdeb2c6d60bf07bd54687d472dabb5b4ae825 100644 --- a/ppcls/arch/backbone/variant_models/__init__.py +++ b/ppcls/arch/backbone/variant_models/__init__.py @@ -1 +1,2 @@ from .resnet_variant import ResNet50_last_stage_stride1 +from .vgg_variant import VGG19Sigmoid diff --git a/ppcls/arch/backbone/variant_models/vgg_variant.py b/ppcls/arch/backbone/variant_models/vgg_variant.py new file mode 100644 index 0000000000000000000000000000000000000000..b73ad35929b5e14cb725867acdd16cb0c3216fa9 --- /dev/null +++ b/ppcls/arch/backbone/variant_models/vgg_variant.py @@ -0,0 +1,28 @@ +import paddle +from paddle.nn import Sigmoid +from ppcls.arch.backbone.legendary_models.vgg import VGG19 + +__all__ = ["VGG19Sigmoid"] + + +class SigmoidSuffix(paddle.nn.Layer): + def __init__(self, origin_layer): + super(SigmoidSuffix, self).__init__() + self.origin_layer = origin_layer + self.sigmoid = Sigmoid() + + def forward(self, input, res_dict=None, **kwargs): + x = self.origin_layer(input) + x = self.sigmoid(x) + return x + + +def VGG19Sigmoid(pretrained=False, use_ssld=False, **kwargs): + def replace_function(origin_layer): + new_layer = SigmoidSuffix(origin_layer) + return new_layer + + match_re = "linear_2" + model = VGG19(pretrained=pretrained, use_ssld=use_ssld, **kwargs) + model.replace_sub(match_re, replace_function, True) + return model diff --git a/ppcls/arch/gears/circlemargin.py b/ppcls/arch/gears/circlemargin.py index 65967d0d7b3f11c54a792f3937525411a840ccc6..87baee839aff50713973b9e0e8d7b7f9f09c187f 100644 --- a/ppcls/arch/gears/circlemargin.py +++ b/ppcls/arch/gears/circlemargin.py @@ -28,7 +28,7 @@ class CircleMargin(nn.Layer): weight_attr = paddle.ParamAttr( initializer=paddle.nn.initializer.XavierNormal()) - self.fc0 = paddle.nn.Linear( + self.fc = paddle.nn.Linear( self.embedding_size, self.class_num, weight_attr=weight_attr) def forward(self, input, label): @@ -36,19 +36,22 @@ class CircleMargin(nn.Layer): paddle.sum(paddle.square(input), axis=1, keepdim=True)) input = paddle.divide(input, feat_norm) - weight = self.fc0.weight + weight = self.fc.weight weight_norm = paddle.sqrt( paddle.sum(paddle.square(weight), axis=0, keepdim=True)) weight = paddle.divide(weight, weight_norm) logits = paddle.matmul(input, weight) + if not self.training or label is None: + return logits alpha_p = paddle.clip(-logits.detach() + 1 + self.margin, min=0.) alpha_n = paddle.clip(logits.detach() + self.margin, min=0.) delta_p = 1 - self.margin delta_n = self.margin - index = paddle.fluid.layers.where(label != -1).reshape([-1]) + m_hot = F.one_hot(label.reshape([-1]), num_classes=logits.shape[1]) + logits_p = alpha_p * (logits - delta_p) logits_n = alpha_n * (logits - delta_n) pre_logits = logits_p * m_hot + logits_n * (1 - m_hot) diff --git a/ppcls/arch/gears/cosmargin.py b/ppcls/arch/gears/cosmargin.py index 51db550868352e2ef20c3accd1fa4dc92d64321d..378e102a215664e33bb8d79016673778ed8a4221 100644 --- a/ppcls/arch/gears/cosmargin.py +++ b/ppcls/arch/gears/cosmargin.py @@ -46,6 +46,9 @@ class CosMargin(paddle.nn.Layer): weight = paddle.divide(weight, weight_norm) cos = paddle.matmul(input, weight) + if not self.training or label is None: + return cos + cos_m = cos - self.margin one_hot = paddle.nn.functional.one_hot(label, self.class_num) diff --git a/ppcls/configs/Products/MV3_Large_1x_Aliproduct_DLBHC.yaml b/ppcls/configs/Products/MV3_Large_1x_Aliproduct_DLBHC.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c9a8b7b4012378f8d04ce6f2b215076ad5850166 --- /dev/null +++ b/ppcls/configs/Products/MV3_Large_1x_Aliproduct_DLBHC.yaml @@ -0,0 +1,149 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output_dlbhc/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 100 + #eval_mode: "retrieval" + print_batch_step: 10 + use_visualdl: False + + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + + #feature postprocess + feature_normalize: False + feature_binarize: "round" + +# model architecture +Arch: + name: "RecModel" + Backbone: + name: "MobileNetV3_large_x1_0" + pretrained: True + class_num: 512 + Head: + name: "FC" + class_num: 50030 + embedding_size: 512 + + infer_output_key: "features" + infer_add_softmax: "false" + +# loss function config for train/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Piecewise + learning_rate: 0.1 + decay_epochs: [50, 150] + values: [0.1, 0.01, 0.001] + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/Aliproduct/ + cls_label_path: ./dataset/Aliproduct/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: 256 + - RandCropImage: + size: 227 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.4914, 0.4822, 0.4465] + std: [0.2023, 0.1994, 0.2010] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/Aliproduct/ + cls_label_path: ./dataset/Aliproduct/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: 227 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.4914, 0.4822, 0.4465] + std: [0.2023, 0.1994, 0.2010] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 256 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/whl/demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 227 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.4914, 0.4822, 0.4465] + std: [0.2023, 0.1994, 0.2010] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5] + +# switch to metric below when eval by retrieval +# - Recallk: +# topk: [1] +# - mAP: +# - Precisionk: +# topk: [1] + diff --git a/ppcls/configs/quick_start/professional/VGG19_CIFAR10_DeepHash.yaml b/ppcls/configs/quick_start/professional/VGG19_CIFAR10_DeepHash.yaml new file mode 100644 index 0000000000000000000000000000000000000000..97228828a6c1381312ff32e313e7e19429dc9392 --- /dev/null +++ b/ppcls/configs/quick_start/professional/VGG19_CIFAR10_DeepHash.yaml @@ -0,0 +1,147 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + eval_mode: "retrieval" + epochs: 128 + print_batch_step: 10 + use_visualdl: False + + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + + #feature postprocess + feature_normalize: False + feature_binarize: "round" + +# model architecture +Arch: + name: "RecModel" + Backbone: + name: "VGG19Sigmoid" + pretrained: True + class_num: 48 + Head: + name: "FC" + class_num: 10 + embedding_size: 48 + + infer_output_key: "features" + infer_add_softmax: "false" + +# loss function config for train/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Piecewise + learning_rate: 0.01 + decay_epochs: [200] + values: [0.01, 0.001] + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/cifar10/ + cls_label_path: ./dataset/cifar10/cifar10-2/train.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: 256 + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.4914, 0.4822, 0.4465] + std: [0.2023, 0.1994, 0.2010] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + Query: + dataset: + name: ImageNetDataset + image_root: ./dataset/cifar10/ + cls_label_path: ./dataset/cifar10/cifar10-2/test.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.4914, 0.4822, 0.4465] + std: [0.2023, 0.1994, 0.2010] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 512 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + + Gallery: + dataset: + name: ImageNetDataset + image_root: ./dataset/cifar10/ + cls_label_path: ./dataset/cifar10/cifar10-2/database.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.4914, 0.4822, 0.4465] + std: [0.2023, 0.1994, 0.2010] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 512 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - mAP: + - Precisionk: + topk: [1, 5] + diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py index c0da3ace25f36725711cdf6a95a966f1ff8107d9..49d9626f639510eb98947eacbe280c2564a3e554 100644 --- a/ppcls/engine/evaluation/retrieval.py +++ b/ppcls/engine/evaluation/retrieval.py @@ -124,6 +124,13 @@ def cal_feature(evaler, name='gallery'): feas_norm = paddle.sqrt( paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True)) batch_feas = paddle.divide(batch_feas, feas_norm) + + # do binarize + if evaler.config["Global"].get("feature_binarize") == "round": + batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0 + + if evaler.config["Global"].get("feature_binarize") == "sign": + batch_feas = paddle.sign(batch_feas).astype("float32") if all_feas is None: all_feas = batch_feas @@ -135,8 +142,10 @@ def cal_feature(evaler, name='gallery'): all_image_id = paddle.concat([all_image_id, batch[1]]) if has_unique_id: all_unique_id = paddle.concat([all_unique_id, batch[2]]) + if evaler.use_dali: dataloader_tmp.reset() + if paddle.distributed.get_world_size() > 1: feat_list = [] img_id_list = [] diff --git a/ppcls/loss/deephashloss.py b/ppcls/loss/deephashloss.py new file mode 100644 index 0000000000000000000000000000000000000000..44c08ef3b4f3062316e2367e489381d70bf32c84 --- /dev/null +++ b/ppcls/loss/deephashloss.py @@ -0,0 +1,90 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.nn as nn + +class DSHSDLoss(nn.Layer): + """ + # DSHSD(IEEE ACCESS 2019) + # paper [Deep Supervised Hashing Based on Stable Distribution](https://ieeexplore.ieee.org/document/8648432/) + # [DSHSD] epoch:70, bit:48, dataset:cifar10-1, MAP:0.809, Best MAP: 0.809 + # [DSHSD] epoch:250, bit:48, dataset:nuswide_21, MAP:0.809, Best MAP: 0.815 + # [DSHSD] epoch:135, bit:48, dataset:imagenet, MAP:0.647, Best MAP: 0.647 + """ + def __init__(self, n_class, bit, alpha, multi_label=False): + super(DSHSDLoss, self).__init__() + self.m = 2 * bit + self.alpha = alpha + self.multi_label = multi_label + self.n_class = n_class + self.fc = paddle.nn.Linear(bit, n_class, bias_attr=False) + + def forward(self, input, label): + feature = input["features"] + feature = feature.tanh().astype("float32") + + dist = paddle.sum( + paddle.square((paddle.unsqueeze(feature, 1) - paddle.unsqueeze(feature, 0))), + axis=2) + + # label to ont-hot + label = paddle.flatten(label) + label = paddle.nn.functional.one_hot(label, self.n_class).astype("float32") + + s = (paddle.matmul(label, label, transpose_y=True) == 0).astype("float32") + Ld = (1 - s) / 2 * dist + s / 2 * (self.m - dist).clip(min=0) + Ld = Ld.mean() + + logits = self.fc(feature) + if self.multi_label: + # multiple labels classification loss + Lc = (logits - label * logits + ((1 + (-logits).exp()).log())).sum(axis=1).mean() + else: + # single labels classification loss + Lc = (-paddle.nn.functional.softmax(logits).log() * label).sum(axis=1).mean() + + return {"dshsdloss": Lc + Ld * self.alpha} + +class LCDSHLoss(nn.Layer): + """ + # paper [Locality-Constrained Deep Supervised Hashing for Image Retrieval](https://www.ijcai.org/Proceedings/2017/0499.pdf) + # [LCDSH] epoch:145, bit:48, dataset:cifar10-1, MAP:0.798, Best MAP: 0.798 + # [LCDSH] epoch:183, bit:48, dataset:nuswide_21, MAP:0.833, Best MAP: 0.834 + """ + def __init__(self, n_class, _lambda): + super(LCDSHLoss, self).__init__() + self._lambda = _lambda + self.n_class = n_class + + def forward(self, input, label): + feature = input["features"] + + # label to ont-hot + label = paddle.flatten(label) + label = paddle.nn.functional.one_hot(label, self.n_class).astype("float32") + + s = 2 * (paddle.matmul(label, label, transpose_y=True) > 0).astype("float32") - 1 + inner_product = paddle.matmul(feature, feature, transpose_y=True) * 0.5 + + inner_product = inner_product.clip(min=-50, max=50) + L1 = paddle.log(1 + paddle.exp(-s * inner_product)).mean() + + b = feature.sign() + inner_product_ = paddle.matmul(b, b, transpose_y=True) * 0.5 + sigmoid = paddle.nn.Sigmoid() + L2 = (sigmoid(inner_product) - sigmoid(inner_product_)).pow(2).mean() + + return {"lcdshloss": L1 + self._lambda * L2} + diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py index db80bab615874a30833567c6f988275f3f75d72e..4c817a115268120456337930fdc09f4bd7a48da0 100644 --- a/ppcls/metric/__init__.py +++ b/ppcls/metric/__init__.py @@ -16,7 +16,7 @@ from paddle import nn import copy from collections import OrderedDict -from .metrics import TopkAcc, mAP, mINP, Recallk +from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk from .metrics import DistillationTopkAcc from .metrics import GoogLeNetTopkAcc diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 9908b3541d612f985e4bd0bbcecbcf488e8ec719..204d2af093c79841a7637cdab0a4023f743e04ef 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -168,6 +168,47 @@ class Recallk(nn.Layer): return metric_dict +class Precisionk(nn.Layer): + def __init__(self, topk=(1, 5)): + super().__init__() + assert isinstance(topk, (int, list, tuple)) + if isinstance(topk, int): + topk = [topk] + self.topk = topk + + def forward(self, similarities_matrix, query_img_id, gallery_img_id, + keep_mask): + metric_dict = dict() + + #get cmc + choosen_indices = paddle.argsort( + similarities_matrix, axis=1, descending=True) + gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0]) + gallery_labels_transpose = paddle.broadcast_to( + gallery_labels_transpose, + shape=[ + choosen_indices.shape[0], gallery_labels_transpose.shape[1] + ]) + choosen_label = paddle.index_sample(gallery_labels_transpose, + choosen_indices) + equal_flag = paddle.equal(choosen_label, query_img_id) + if keep_mask is not None: + keep_mask = paddle.index_sample( + keep_mask.astype('float32'), choosen_indices) + equal_flag = paddle.logical_and(equal_flag, + keep_mask.astype('bool')) + equal_flag = paddle.cast(equal_flag, 'float32') + + Ns = paddle.arange(gallery_img_id.shape[0]) + 1 + equal_flag_cumsum = paddle.cumsum(equal_flag, axis=1) + Precision_at_k = (paddle.mean(equal_flag_cumsum, axis=0) / Ns).numpy() + + for k in self.topk: + metric_dict["precision@{}".format(k)] = Precision_at_k[k - 1] + + return metric_dict + + class DistillationTopkAcc(TopkAcc): def __init__(self, model_key, feature_key=None, topk=(1, 5)): super().__init__(topk=topk)