diff --git a/README.md b/README.md index 1ce6db0f6f84b78804e4563c4ef5f990bcd55f78..9be6c8bc66b5cf01f9ec7671b3c4afa4c8523693 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,7 @@ PaddleClas is a toolset for image classification tasks prepared for the industry - Advanced tutorials - [Knowledge distillation](./docs/en/advanced_tutorials/distillation/distillation_en.md) - [Data augmentation](./docs/en/advanced_tutorials/image_augmentation/ImageAugment_en.md) + - [Multilabel classification](./docs/en/advanced_tutorials/multilabel/multilabel_en.md) - Applications - [Transfer learning](./docs/en/application/transfer_learning_en.md) - [Pretrained model with 100,000 categories](./docs/en/application/transfer_learning_en.md) diff --git a/README_cn.md b/README_cn.md index 717abd76c4ceecd4cd6acf2a252f9299605d9974..f1f273bac97076998e40c45fd1e437cf60910fc7 100644 --- a/README_cn.md +++ b/README_cn.md @@ -83,6 +83,7 @@ - 高阶使用 - [知识蒸馏](./docs/zh_CN/advanced_tutorials/distillation/distillation.md) - [数据增广](./docs/zh_CN/advanced_tutorials/image_augmentation/ImageAugment.md) + - [多标签分类](./docs/zh_CN/advanced_tutorials/multilabel/multilabel.md) - 特色拓展应用 - [迁移学习](./docs/zh_CN/application/transfer_learning.md) - [10万类图像分类预训练模型](./docs/zh_CN/application/transfer_learning.md) diff --git a/configs/quick_start/ResNet50_vd_multilabel.yaml b/configs/quick_start/ResNet50_vd_multilabel.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fe785b4b54c12cef8040d48868febb11b99ddb96 --- /dev/null +++ b/configs/quick_start/ResNet50_vd_multilabel.yaml @@ -0,0 +1,79 @@ +mode: 'train' +ARCHITECTURE: + name: 'ResNet50_vd' + +pretrained_model: "./pretrained/ResNet50_vd_pretrained" +model_save_dir: "./output/" +classes_num: 33 +total_images: 17463 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 10 +topk: 1 +image_shape: [3, 224, 224] + +multilabel: True + +use_mix: False +ls_epsilon: 0.1 + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.07 + +OPTIMIZER: + 
function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000070 + +TRAIN: + batch_size: 256 + num_workers: 4 + file_list: "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/multilabel_train_list.txt" + data_dir: "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + mix: + - MixupOperator: + alpha: 0.2 + +VALID: + batch_size: 64 + num_workers: 4 + file_list: "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/multilabel_test_list.txt" + data_dir: "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: \ No newline at end of file diff --git a/docs/en/advanced_tutorials/index.rst b/docs/en/advanced_tutorials/index.rst index 2042d4541aa9a0de6d538eca6bc22f8f573fd1e2..2dba33220700136218d52c414fbbcfb080d132a2 100644 --- a/docs/en/advanced_tutorials/index.rst +++ b/docs/en/advanced_tutorials/index.rst @@ -6,4 +6,4 @@ advanced_tutorials image_augmentation/index distillation/index - + multilabel/index diff --git a/docs/en/advanced_tutorials/multilabel/index.rst b/docs/en/advanced_tutorials/multilabel/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..1e8acfdfb3c0d93101f78fdee05ce364041c1882 --- /dev/null +++ b/docs/en/advanced_tutorials/multilabel/index.rst @@ -0,0 +1,7 @@ +Multilabel Classification +================================ + +.. 
toctree:: + :maxdepth: 3 + + multilabel_en.md \ No newline at end of file diff --git a/docs/en/advanced_tutorials/multilabel/multilabel_en.md b/docs/en/advanced_tutorials/multilabel/multilabel_en.md new file mode 100644 index 0000000000000000000000000000000000000000..29a18c181bb2dc474a7a368c5db81dd198d7a5e2 --- /dev/null +++ b/docs/en/advanced_tutorials/multilabel/multilabel_en.md @@ -0,0 +1,82 @@ +# Multilabel classification quick start + +Based on the [NUS-WIDE-SCENE](https://lms.comp.nus.edu.sg/wp-content/uploads/2019/research/nuswide/NUS-WIDE.html) dataset, which is a subset of the NUS-WIDE dataset, you can try out multilabel classification with PaddleClas, including training, evaluation and prediction. Please refer to [Installation](install.md) to set up the environment first. + +## Preparation + +* Enter PaddleClas directory + +``` +cd path_to_PaddleClas +``` + +* Create and enter `dataset/NUS-WIDE-SCENE` directory, download and decompress NUS-WIDE-SCENE dataset + +```shell +mkdir dataset/NUS-WIDE-SCENE +cd dataset/NUS-WIDE-SCENE +wget https://paddle-imagenet-models-name.bj.bcebos.com/data/NUS-SCENE-dataset.tar +tar -xf NUS-SCENE-dataset.tar +``` + +* Return to the `PaddleClas` root directory + +``` +cd ../../ +``` + +## Environment + +### Download pretrained model + +You can use the following commands to download the pretrained model of ResNet50_vd. + +```bash +mkdir pretrained +cd pretrained +wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams +cd ../ +``` + +## Training + +```shell +export CUDA_VISIBLE_DEVICES=0 +python -m paddle.distributed.launch \ + --gpus="0" \ + tools/train.py \ + -c ./configs/quick_start/ResNet50_vd_multilabel.yaml +``` + +After training for 10 epochs, the best accuracy over the validation set should be around 0.72. 
+ +## Evaluation + +```bash +python tools/eval.py \ + -c ./configs/quick_start/ResNet50_vd_multilabel.yaml \ + -o pretrained_model="./output/ResNet50_vd/best_model/ppcls" \ + -o load_static_weights=False +``` + +The metric of evaluation is based on mAP, which is commonly used in multilabel tasks to show model performance. The mAP over validation set should be around 0.57. + +## Prediction + +```bash +python tools/infer/infer.py \ + -i "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/0199_434752251.jpg" \ + --model ResNet50_vd \ + --pretrained_model "./output/ResNet50_vd/best_model/ppcls" \ + --use_gpu True \ + --load_static_weights False \ + --multilabel True \ + --class_num 33 +``` + +You will get multiple outputs such as the following: +``` + class id: 3, probability: 0.6025 + class id: 23, probability: 0.5491 + class id: 32, probability: 0.7006 +``` \ No newline at end of file diff --git a/docs/zh_CN/advanced_tutorials/index.rst b/docs/zh_CN/advanced_tutorials/index.rst index c48f204739559ec23265a83c0af8c6326e59e00f..e3bda60be938c4a8ac29bfd51f5f682405f4d0f1 100644 --- a/docs/zh_CN/advanced_tutorials/index.rst +++ b/docs/zh_CN/advanced_tutorials/index.rst @@ -6,4 +6,4 @@ image_augmentation/index distillation/index - + multilabel/index diff --git a/docs/zh_CN/advanced_tutorials/multilabel/index.rst b/docs/zh_CN/advanced_tutorials/multilabel/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..28d5b69dac6347c90381e5112e60210429fca730 --- /dev/null +++ b/docs/zh_CN/advanced_tutorials/multilabel/index.rst @@ -0,0 +1,7 @@ +多标签分类 +================================ + +.. 
toctree:: + :maxdepth: 3 + + multilabel.md \ No newline at end of file diff --git a/docs/zh_CN/advanced_tutorials/multilabel/multilabel.md b/docs/zh_CN/advanced_tutorials/multilabel/multilabel.md new file mode 100644 index 0000000000000000000000000000000000000000..ef445ca82b7cdb2061d5c48f00cd86fa133b8449 --- /dev/null +++ b/docs/zh_CN/advanced_tutorials/multilabel/multilabel.md @@ -0,0 +1,82 @@ +# 多标签分类quick start + +基于[NUS-WIDE-SCENE](https://lms.comp.nus.edu.sg/wp-content/uploads/2019/research/nuswide/NUS-WIDE.html)数据集,体验多标签分类的训练、评估、预测的过程,该数据集是NUS-WIDE数据集的一个子集。请事先参考[安装指南](install.md)配置运行环境和克隆PaddleClas代码。 + +## 一、数据和模型准备 + +* 进入PaddleClas目录。 + +``` +cd path_to_PaddleClas +``` + +* 创建并进入`dataset/NUS-WIDE-SCENE`目录,下载并解压NUS-WIDE-SCENE数据集。 + +```shell +mkdir dataset/NUS-WIDE-SCENE +cd dataset/NUS-WIDE-SCENE +wget https://paddle-imagenet-models-name.bj.bcebos.com/data/NUS-SCENE-dataset.tar +tar -xf NUS-SCENE-dataset.tar +``` + +* 返回`PaddleClas`根目录 + +``` +cd ../../ +``` + +## 二、环境准备 + +### 2.1 下载预训练模型 + +本例展示基于ResNet50_vd模型的多标签分类流程,因此首先下载ResNet50_vd的预训练模型 + +```bash +mkdir pretrained +cd pretrained +wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams +cd ../ +``` + +## 三、模型训练 + +```shell +export CUDA_VISIBLE_DEVICES=0 +python -m paddle.distributed.launch \ + --gpus="0" \ + tools/train.py \ + -c ./configs/quick_start/ResNet50_vd_multilabel.yaml +``` + +训练10epoch之后,验证集最好的正确率应该在0.72左右。 + +## 四、模型评估 + +```bash +python tools/eval.py \ + -c ./configs/quick_start/ResNet50_vd_multilabel.yaml \ + -o pretrained_model="./output/ResNet50_vd/best_model/ppcls" \ + -o load_static_weights=False +``` + +评估指标采用mAP,验证集的mAP应该在0.57左右。 + +## 五、模型预测 + +```bash +python tools/infer/infer.py \ + -i "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/0199_434752251.jpg" \ + --model ResNet50_vd \ + --pretrained_model "./output/ResNet50_vd/best_model/ppcls" \ + --use_gpu True \ + --load_static_weights False \ + --multilabel True \ + --class_num 33 +``` + 
+得到类似下面的输出: +``` + class id: 3, probability: 0.6025 + class id: 23, probability: 0.5491 + class id: 32, probability: 0.7006 +``` \ No newline at end of file diff --git a/ppcls/data/reader.py b/ppcls/data/reader.py index b1a17637db9c22ff4d59e12c6bb44e0240cd91e4..90bff3589d88d036da717946de3fe6d6821edb37 100755 --- a/ppcls/data/reader.py +++ b/ppcls/data/reader.py @@ -197,6 +197,40 @@ class CommonDataset(Dataset): def __len__(self): return self.num_samples + + +class MultiLabelDataset(Dataset): + """ + Define dataset class for multilabel image classification + """ + + def __init__(self, params): + self.params = params + self.mode = params.get("mode", "train") + self.full_lines = get_file_list(params) + self.delimiter = params.get("delimiter", "\t") + self.ops = create_operators(params["transforms"]) + self.num_samples = len(self.full_lines) + return + + def __getitem__(self, idx): + try: + line = self.full_lines[idx] + img_path, label_str = line.split(self.delimiter) + img_path = os.path.join(self.params["data_dir"], img_path) + with open(img_path, "rb") as f: + img = f.read() + + labels = label_str.split(',') + labels = [int(i) for i in labels] + + return (transform(img, self.ops), np.array(labels).astype("float32")) + except Exception as e: + logger.error("data read failed: {}, exception info: {}".format(line, e)) + return self.__getitem__(random.randint(0, len(self))) + + def __len__(self): + return self.num_samples class Reader: @@ -229,6 +263,7 @@ class Reader: self.collate_fn = self.mix_collate_fn self.places = places + self.multilabel = config.get("multilabel", False) def mix_collate_fn(self, batch): batch = transform(batch, self.batch_ops) @@ -246,7 +281,10 @@ class Reader: def __call__(self): batch_size = int(self.params['batch_size']) // trainers_num - dataset = CommonDataset(self.params) + if self.multilabel: + dataset = MultiLabelDataset(self.params) + else: + dataset = CommonDataset(self.params) is_train = self.params['mode'] == "train" batch_sampler = 
DistributedBatchSampler( diff --git a/ppcls/modeling/loss.py b/ppcls/modeling/loss.py index 0a8e0c66ffb643582a57bdd89505b7dcab101e1d..5e7abd643fc2c1a3bada3209210d03a9cacfb3f4 100644 --- a/ppcls/modeling/loss.py +++ b/ppcls/modeling/loss.py @@ -15,7 +15,7 @@ import paddle import paddle.nn.functional as F -__all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss', 'JSDivLoss'] +__all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss', 'JSDivLoss', 'MultiLabelLoss'] class Loss(object): @@ -41,6 +41,17 @@ class Loss(object): soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon) soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim]) return soft_target + + def _binary_crossentropy(self, input, target): + if self._label_smoothing: + target = self._labelsmoothing(target) + cost = F.binary_cross_entropy_with_logits(logit=input, label=target) + else: + cost = F.binary_cross_entropy_with_logits(logit=input, label=target) + + avg_cost = paddle.mean(cost) + + return avg_cost def _crossentropy(self, input, target): if self._label_smoothing: @@ -68,6 +79,20 @@ class Loss(object): def __call__(self, input, target): pass + + +class MultiLabelLoss(Loss): + """ + Multilabel loss based binary cross entropy + """ + + def __init__(self, class_dim=1000, epsilon=None): + super(MultiLabelLoss, self).__init__(class_dim, epsilon) + + def __call__(self, input, target): + cost = self._binary_crossentropy(input, target) + + return cost class CELoss(Loss): diff --git a/ppcls/utils/__init__.py b/ppcls/utils/__init__.py index 9bdc581d0ce72154cdda88e014ec4e131bcc2b67..632cc78824d51d5adae9315fda8fccde50eda73a 100644 --- a/ppcls/utils/__init__.py +++ b/ppcls/utils/__init__.py @@ -15,7 +15,13 @@ from . import logger from . import misc from . import model_zoo +from . 
import metrics from .save_load import init_model, save_model from .config import get_config from .misc import AverageMeter +from .metrics import multi_hot_encode +from .metrics import hamming_distance +from .metrics import accuracy_score +from .metrics import precision_recall_fscore +from .metrics import mean_average_precision diff --git a/ppcls/utils/metrics.py b/ppcls/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..724cae2d8b7acc95ccbafa32685856fe2bfccd2d --- /dev/null +++ b/ppcls/utils/metrics.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from sklearn.metrics import hamming_loss +from sklearn.metrics import accuracy_score as accuracy_metric +from sklearn.metrics import multilabel_confusion_matrix +from sklearn.metrics import precision_recall_fscore_support +from sklearn.metrics import average_precision_score +from sklearn.preprocessing import binarize + +import numpy as np + +__all__ = ["multi_hot_encode", "hamming_distance", "accuracy_score", "precision_recall_fscore", "mean_average_precision"] + + +def multi_hot_encode(logits, threshold=0.5): + """ + Encode logits to multi-hot by elementwise for multilabel + """ + + return binarize(logits, threshold) + + +def hamming_distance(output, target): + """ + Soft metric based label for multilabel classification + Returns: + The smaller the return value is, the better model is. + """ + + return hamming_loss(target, output) + + +def accuracy_score(output, target, base="sample"): + """ + Hard metric for multilabel classification + Args: + output: + target: + base: ["sample", "label"], default="sample" + if "sample", return metric score based sample, + if "label", return metric score based label. 
+ Returns: + accuracy: + """ + + assert base in ["sample", "label"], 'must be one of ["sample", "label"]' + + if base == "sample": + accuracy = accuracy_metric(target, output) + elif base == "label": + mcm = multilabel_confusion_matrix(target, output) + tns = mcm[:, 0, 0] + fns = mcm[:, 1, 0] + tps = mcm[:, 1, 1] + fps = mcm[:, 0, 1] + + accuracy = (sum(tps) + sum(tns)) / (sum(tps) + sum(tns) + sum(fns) + sum(fps)) + + return accuracy + + +def precision_recall_fscore(output, target): + """ + Metric based label for multilabel classification + Returns: + precisions: + recalls: + fscores: + """ + + precisions, recalls, fscores, _ = precision_recall_fscore_support(target, output) + + return precisions, recalls, fscores + + +def mean_average_precision(logits, target): + """ + Calculate average precision + Args: + logits: probability from network before sigmoid or softmax + target: ground truth, 0 or 1 + """ + if not (isinstance(logits, np.ndarray) and isinstance(target, np.ndarray)): + raise TypeError("logits and target should be np.ndarray.") + + aps = [] + for i in range(target.shape[1]): + ap = average_precision_score(target[:, i], logits[:, i]) + aps.append(ap) + + return np.mean(aps) diff --git a/requirements.txt b/requirements.txt index 205d780f0191d986d97224977c1e935b9d50e26e..654c270fb301a6138f4083f46b075741d7aa4fbc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ tqdm PyYAML visualdl >= 2.0.0b scipy +scikit-learn==0.23.2 diff --git a/tools/eval.py b/tools/eval.py index b69e71dc0312ea2bbb208bcd05740920c6500c54..472eeb5e284c8c938eba99e3227bd1d5719fdaec 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -13,6 +13,7 @@ # limitations under the License. 
import paddle +import paddle.nn.functional as F import argparse import os @@ -24,9 +25,15 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) from ppcls.utils import logger from ppcls.utils.save_load import init_model from ppcls.utils.config import get_config +from ppcls.utils import multi_hot_encode +from ppcls.utils import accuracy_score +from ppcls.utils import mean_average_precision +from ppcls.utils import precision_recall_fscore from ppcls.data import Reader import program +import numpy as np + def parse_args(): parser = argparse.ArgumentParser("PaddleClas eval script") @@ -52,6 +59,7 @@ def main(args, return_dict={}): # assign place use_gpu = config.get("use_gpu", True) place = paddle.set_device('gpu' if use_gpu else 'cpu') + multilabel = config.get("multilabel", False) trainer_num = paddle.distributed.get_world_size() use_data_parallel = trainer_num != 1 @@ -68,12 +76,38 @@ def main(args, return_dict={}): valid_dataloader = Reader(config, 'valid', places=place)() net.eval() with paddle.no_grad(): - top1_acc = program.run(valid_dataloader, config, net, None, None, 0, - 'valid') - return_dict["top1_acc"] = top1_acc - return top1_acc + if not multilabel: + top1_acc = program.run(valid_dataloader, config, net, None, None, 0, + 'valid') + return_dict["top1_acc"] = top1_acc + + return top1_acc + else: + all_outs = [] + targets = [] + for idx, batch in enumerate(valid_dataloader()): + feeds = program.create_feeds(batch, False, config.classes_num, multilabel) + out = net(feeds["image"]) + out = F.sigmoid(out) + + use_distillation = config.get("use_distillation", False) + if use_distillation: + out = out[1] + + all_outs.extend(list(out.numpy())) + targets.extend(list(feeds["label"].numpy())) + all_outs = np.array(all_outs) + targets = np.array(targets) + + mAP = mean_average_precision(all_outs, targets) + + return_dict["mean average precision"] = mAP + + return mAP if __name__ == '__main__': args = parse_args() - main(args) + return_dict = {} + 
main(args, return_dict) + print(return_dict) diff --git a/tools/infer/infer.py b/tools/infer/infer.py index b3391c09a75dda495f8fe3a639508cedc218f3b4..87fe9f32035017c7c142e1e8c97e2ba56fec9348 100644 --- a/tools/infer/infer.py +++ b/tools/infer/infer.py @@ -34,6 +34,7 @@ def main(): args = parse_args() # assign the place place = paddle.set_device('gpu' if args.use_gpu else 'cpu') + multilabel = True if args.multilabel else False net = architectures.__dict__[args.model](class_dim=args.class_num) load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights) @@ -61,17 +62,25 @@ def main(): batch_outputs = net(batch_tensor) if args.model == "GoogLeNet": batch_outputs = batch_outputs[0] - batch_outputs = F.softmax(batch_outputs) + if multilabel: + batch_outputs = F.sigmoid(batch_outputs) + else: + batch_outputs = F.softmax(batch_outputs) batch_outputs = batch_outputs.numpy() - batch_result_list = postprocess(batch_outputs, args.top_k) + batch_result_list = postprocess(batch_outputs, args.top_k, multilabel=multilabel) for number, result_dict in enumerate(batch_result_list): filename = img_path_list[number].split("/")[-1] clas_ids = result_dict["clas_ids"] - scores_str = "[{}]".format(", ".join("{:.2f}".format( - r) for r in result_dict["scores"])) - print("File:{}, Top-{} result: class id(s): {}, score(s): {}". - format(filename, args.top_k, clas_ids, scores_str)) + if multilabel: + print("File:{}, multilabel result: ".format(filename)) + for id, score in zip(clas_ids, result_dict["scores"]): + print("\tclass id: {}, probability: {:.2f}".format(id, score)) + else: + scores_str = "[{}]".format(", ".join("{:.2f}".format( + r) for r in result_dict["scores"])) + print("File:{}, Top-{} result: class id(s): {}, score(s): {}". 
+ format(filename, args.top_k, clas_ids, scores_str)) if args.pre_label_image: save_prelabel_results(clas_ids[0], img_path_list[number], diff --git a/tools/infer/utils.py b/tools/infer/utils.py index 639f599d20e917bb61ee493f25dfd62b755b2727..8862e5f5fb0d6bfa606d2bc7a0a9952edf2a6233 100644 --- a/tools/infer/utils.py +++ b/tools/infer/utils.py @@ -31,6 +31,7 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("-i", "--image_file", type=str) parser.add_argument("--use_gpu", type=str2bool, default=True) + parser.add_argument("--multilabel", type=str2bool, default=False) # params for preprocess parser.add_argument("--resize_short", type=int, default=256) @@ -124,11 +125,14 @@ def preprocess(img, args): return img -def postprocess(batch_outputs, topk=5): +def postprocess(batch_outputs, topk=5, multilabel=False): batch_results = [] for probs in batch_outputs: results = [] - index = probs.argsort(axis=0)[-topk:][::-1].astype("int32") + if multilabel: + index = np.where(probs >= 0.5)[0].astype('int32') + else: + index = probs.argsort(axis=0)[-topk:][::-1].astype("int32") clas_id_list = [] score_list = [] for i in index: diff --git a/tools/program.py b/tools/program.py index e09789a337305aae47350a82918f7aa932d90028..7006c9c3cfa46394edf7f1a66b055c078df0b248 100644 --- a/tools/program.py +++ b/tools/program.py @@ -29,12 +29,16 @@ import paddle.nn.functional as F from ppcls.optimizer import LearningRateBuilder from ppcls.optimizer import OptimizerBuilder from ppcls.modeling import architectures +from ppcls.modeling.loss import MultiLabelLoss from ppcls.modeling.loss import CELoss from ppcls.modeling.loss import MixCELoss from ppcls.modeling.loss import JSDivLoss from ppcls.modeling.loss import GoogLeNetLoss from ppcls.utils.misc import AverageMeter from ppcls.utils import logger +from ppcls.utils import multi_hot_encode +from ppcls.utils import hamming_distance +from ppcls.utils import accuracy_score def create_model(architecture, classes_num): @@ -61,7 
+65,8 @@ def create_loss(feeds, classes_num=1000, epsilon=None, use_mix=False, - use_distillation=False): + use_distillation=False, + multilabel=False): """ Create a loss for optimization, such as: 1. CrossEnotry loss @@ -100,7 +105,10 @@ def create_loss(feeds, feed_lam = feeds['lam'] return loss(out, feed_y_a, feed_y_b, feed_lam) else: - loss = CELoss(class_dim=classes_num, epsilon=epsilon) + if not multilabel: + loss = CELoss(class_dim=classes_num, epsilon=epsilon) + else: + loss = MultiLabelLoss(class_dim=classes_num, epsilon=epsilon) return loss(out, feeds["label"]) @@ -110,6 +118,7 @@ def create_metric(out, topk=5, classes_num=1000, use_distillation=False, + multilabel=False, mode="train"): """ Create measures of model accuracy, such as top1 and top5 @@ -135,24 +144,43 @@ def create_metric(out, softmax_out = F.softmax(out) fetchs = OrderedDict() - # set top1 to fetchs - top1 = paddle.metric.accuracy(softmax_out, label=label, k=1) - # set topk to fetchs - k = min(topk, classes_num) - topk = paddle.metric.accuracy(softmax_out, label=label, k=k) + metric_names = set() + if not multilabel: + softmax_out = F.softmax(out) + + # set top1 to fetchs + top1 = paddle.metric.accuracy(softmax_out, label=label, k=1) + # set topk to fetchs + k = min(topk, classes_num) + topk = paddle.metric.accuracy(softmax_out, label=label, k=k) + + metric_names.add("top1") + metric_names.add("top{}".format(k)) + + fetchs['top1'] = top1 + topk_name = "top{}".format(k) + fetchs[topk_name] = topk + else: + out = F.sigmoid(out) + preds = multi_hot_encode(out.numpy()) + targets = label.numpy() + ham_dist = to_tensor(hamming_distance(preds, targets)) + accuracy = to_tensor(accuracy_score(preds, targets, base="label")) + + ham_dist_name = "hamming_distance" + accuracy_name = "multilabel_accuracy" + metric_names.add(ham_dist_name) + metric_names.add(accuracy_name) + + fetchs[accuracy_name] = accuracy + fetchs[ham_dist_name] = ham_dist # multi cards' eval if mode != "train" and 
paddle.distributed.get_world_size() > 1: - top1 = paddle.distributed.all_reduce( - top1, op=paddle.distributed.ReduceOp. - SUM) / paddle.distributed.get_world_size() - topk = paddle.distributed.all_reduce( - topk, op=paddle.distributed.ReduceOp. - SUM) / paddle.distributed.get_world_size() - - fetchs['top1'] = top1 - topk_name = 'top{}'.format(k) - fetchs[topk_name] = topk + for metric_name in metric_names: + fetchs[metric_name] = paddle.distributed.all_reduce( + fetchs[metric_name], op=paddle.distributed.ReduceOp. + SUM) / paddle.distributed.get_world_size() return fetchs @@ -182,12 +210,14 @@ def create_fetchs(feeds, net, config, mode="train"): epsilon = config.get('ls_epsilon') use_mix = config.get('use_mix') and mode == 'train' use_distillation = config.get('use_distillation') + multilabel = config.get('multilabel', False) out = net(feeds["image"]) fetchs = OrderedDict() fetchs['loss'] = create_loss(feeds, out, architecture, classes_num, - epsilon, use_mix, use_distillation) + epsilon, use_mix, use_distillation, + multilabel) if not use_mix: metric = create_metric( out, @@ -196,6 +226,7 @@ def create_fetchs(feeds, net, config, mode="train"): topk, classes_num, use_distillation, + multilabel=multilabel, mode=mode) fetchs.update(metric) @@ -240,7 +271,7 @@ def create_optimizer(config, parameter_list=None): return opt(lr, parameter_list), lr -def create_feeds(batch, use_mix): +def create_feeds(batch, use_mix, num_classes, multilabel=False): image = batch[0] if use_mix: y_a = to_tensor(batch[1].numpy().astype("int64").reshape(-1, 1)) @@ -248,7 +279,10 @@ def create_feeds(batch, use_mix): lam = to_tensor(batch[3].numpy().astype("float32").reshape(-1, 1)) feeds = {"image": image, "y_a": y_a, "y_b": y_b, "lam": lam} else: - label = to_tensor(batch[1].numpy().astype('int64').reshape(-1, 1)) + if not multilabel: + label = to_tensor(batch[1].numpy().astype("int64").reshape(-1, 1)) + else: + label = to_tensor(batch[1].numpy().astype('float32').reshape(-1, num_classes)) 
feeds = {"image": image, "label": label} return feeds @@ -279,6 +313,8 @@ def run(dataloader, """ print_interval = config.get("print_interval", 10) use_mix = config.get("use_mix", False) and mode == "train" + multilabel = config.get("multilabel", False) + classes_num = config.get("classes_num") metric_list = [ ("loss", AverageMeter( @@ -291,13 +327,19 @@ def run(dataloader, 'reader_cost', '.5f', postfix=" s,")), ] if not use_mix: - topk_name = 'top{}'.format(config.topk) - metric_list.insert( - 0, (topk_name, AverageMeter( - topk_name, '.5f', postfix=","))) - metric_list.insert( - 0, ("top1", AverageMeter( - "top1", '.5f', postfix=","))) + if not multilabel: + topk_name = 'top{}'.format(config.topk) + metric_list.insert( + 0, (topk_name, AverageMeter( + topk_name, '.5f', postfix=","))) + metric_list.insert( + 0, ("top1", AverageMeter( + "top1", '.5f', postfix=","))) + else: + metric_list.insert(0, ("multilabel_accuracy", AverageMeter( + "multilabel_accuracy", '.5f', postfix=","))) + metric_list.insert(0, ("hamming_distance", AverageMeter( + "hamming_distance", '.5f', postfix=","))) metric_list = OrderedDict(metric_list) @@ -310,7 +352,7 @@ def run(dataloader, metric_list['reader_time'].update(time.time() - tic) batch_size = len(batch[0]) - feeds = create_feeds(batch, use_mix) + feeds = create_feeds(batch, use_mix, classes_num, multilabel) fetchs = create_fetchs(feeds, net, config, mode) if mode == 'train': avg_loss = fetchs['loss'] @@ -387,4 +429,7 @@ def run(dataloader, # return top1_acc in order to save the best model if mode == 'valid': - return metric_list['top1'].avg + if multilabel: + return metric_list['multilabel_accuracy'].avg + else: + return metric_list['top1'].avg