diff --git a/deploy/configs/inference_multilabel_cls.yaml b/deploy/configs/inference_multilabel_cls.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9dc0529792623fe501a62f05250c24d4e14fd93d --- /dev/null +++ b/deploy/configs/inference_multilabel_cls.yaml @@ -0,0 +1,33 @@ +Global: + infer_imgs: "./images/0517_2715693311.jpg" + inference_model_dir: "../inference/" + batch_size: 1 + use_gpu: True + enable_mkldnn: False + cpu_num_threads: 10 + enable_benchmark: True + use_fp16: False + ir_optim: True + use_tensorrt: False + gpu_mem: 8000 + enable_profile: False +PreProcess: + transform_ops: + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + channel_num: 3 + - ToCHWImage: +PostProcess: + main_indicator: MultiLabelTopk + MultiLabelTopk: + topk: 5 + class_id_map_file: None + SavePreLabel: + save_dir: ./pre_label/ diff --git a/deploy/images/0517_2715693311.jpg b/deploy/images/0517_2715693311.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bd9d2f632192b5be14a8f684303bdeb87bedcfaf Binary files /dev/null and b/deploy/images/0517_2715693311.jpg differ diff --git a/deploy/python/postprocess.py b/deploy/python/postprocess.py index 61b5fbcebd6839292cae53f4afcf6d8b6ac40661..d26cbaa9a8558ffb7f96115eef0a0bd9481fe47a 100644 --- a/deploy/python/postprocess.py +++ b/deploy/python/postprocess.py @@ -81,12 +81,14 @@ class Topk(object): class_id_map = None return class_id_map - def __call__(self, x, file_names=None): + def __call__(self, x, file_names=None, multilabel=False): if file_names is not None: assert x.shape[0] == len(file_names) y = [] for idx, probs in enumerate(x): - index = probs.argsort(axis=0)[-self.topk:][::-1].astype("int32") + index = probs.argsort(axis=0)[-self.topk:][::-1].astype( + "int32") if not multilabel else np.where( + probs >= 0.5)[0].astype("int32") clas_id_list = [] score_list = [] label_name_list = [] @@ -108,6 +110,14 @@ class Topk(object): return y +class MultiLabelTopk(Topk): + def __init__(self, topk=1, class_id_map_file=None): + super().__init__() + + def __call__(self, x, file_names=None): + return super().__call__(x, file_names, multilabel=True) + + class SavePreLabel(object): def __init__(self, save_dir): if save_dir is None: @@ -128,23 +138,24 @@ class SavePreLabel(object): os.makedirs(output_dir, exist_ok=True) shutil.copy(image_file, output_dir) + class Binarize(object): - def __init__(self, method = "round"): + def __init__(self, method="round"): self.method = method self.unit = np.array([[128, 64, 32, 16, 8, 4, 2, 1]]).T def __call__(self, x, file_names=None): if self.method == "round": x = np.round(x + 1).astype("uint8") - 1 - + if self.method == "sign": x = ((np.sign(x) + 1) / 2).astype("uint8") embedding_size = x.shape[1] assert embedding_size % 8 == 0, "The Binary index only support vectors with sizes multiple of 8" - + byte = np.zeros([x.shape[0], embedding_size // 8], dtype=np.uint8) for i in range(embedding_size // 8): - byte[:, i:i+1] = np.dot(x[:, i * 8: (i + 1)* 8], self.unit) + byte[:, i:i + 1] = np.dot(x[:, i * 8:(i + 1) * 8], self.unit) return byte diff --git a/deploy/python/predict_cls.py b/deploy/python/predict_cls.py index dc6865404ecfbc517c7b952c52035a27cbc0137f..cdeb32e4881fb5cac4e3ba09adfba8019af579ad 100644 --- a/deploy/python/predict_cls.py +++ b/deploy/python/predict_cls.py @@ -71,7 +71,6 @@ class ClsPredictor(Predictor): output_names = self.paddle_predictor.get_output_names() output_tensor = self.paddle_predictor.get_output_handle(output_names[ 0]) - if self.benchmark: self.auto_logger.times.start() if not isinstance(images, (list, )): @@ -119,7 +118,6 @@ def main(config): ) == len(image_list): if len(batch_imgs) == 0: continue - batch_results = cls_predictor.predict(batch_imgs) for number, result_dict in enumerate(batch_results): filename = batch_names[number] diff --git a/deploy/shell/predict.sh b/deploy/shell/predict.sh index 44be942866e7ec8b89fdfc2c2b4988e18bb3c6a8..f0f59f4ac04ac3e8bd0f0cd89c35a06e4cc5fb2e 100644 --- a/deploy/shell/predict.sh +++ b/deploy/shell/predict.sh @@ -1,6 +1,9 @@ # classification python3.7 python/predict_cls.py -c configs/inference_cls.yaml +# multilabel_classification +#python3.7 python/predict_cls.py -c configs/inference_multilabel_cls.yaml + # feature extractor # python3.7 python/predict_rec.py -c configs/inference_rec.yaml diff --git a/docs/zh_CN/advanced_tutorials/multilabel/multilabel.md b/docs/zh_CN/advanced_tutorials/multilabel/multilabel.md index ef445ca82b7cdb2061d5c48f00cd86fa133b8449..50eec827a91f6b497724c131e55b31e019b567d8 100644 --- a/docs/zh_CN/advanced_tutorials/multilabel/multilabel.md +++ b/docs/zh_CN/advanced_tutorials/multilabel/multilabel.md @@ -25,58 +25,66 @@ tar -xf NUS-SCENE-dataset.tar cd ../../ ``` -## 二、环境准备 +## 二、模型训练 -### 2.1 下载预训练模型 +```shell +export CUDA_VISIBLE_DEVICES=0,1,2,3 +python3 -m paddle.distributed.launch \ + --gpus="0,1,2,3" \ + tools/train.py \ + -c ./ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml +``` + +训练10epoch之后,验证集最好的正确率应该在0.95左右。 -本例展示基于ResNet50_vd模型的多标签分类流程,因此首先下载ResNet50_vd的预训练模型 +## 三、模型评估 ```bash -mkdir pretrained -cd pretrained -wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams -cd ../ +python3 tools/eval.py \ + -c ./ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml \ + -o Arch.pretrained="./output/MobileNetV1/best_model" ``` -## 三、模型训练 +## 四、模型预测 -```shell -export CUDA_VISIBLE_DEVICES=0 -python -m paddle.distributed.launch \ - --gpus="0" \ - tools/train.py \ - -c ./configs/quick_start/ResNet50_vd_multilabel.yaml +```bash +python3 tools/infer.py \ + -c ./ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml \ + -o Arch.pretrained="./output/MobileNetV1/best_model" +``` + +得到类似下面的输出: +``` +[{'class_ids': [6, 13, 17, 23, 26, 30], 'scores': [0.95683, 0.5567, 0.55211, 0.99088, 0.5943, 0.78767], 'file_name': './deploy/images/0517_2715693311.jpg', 'label_names': []}] ``` -训练10epoch之后,验证集最好的正确率应该在0.72左右。 +## 五、基于预测引擎预测 -## 四、模型评估 +### 5.1 导出inference model ```bash -python tools/eval.py \ - -c ./configs/quick_start/ResNet50_vd_multilabel.yaml \ - -o pretrained_model="./output/ResNet50_vd/best_model/ppcls" \ - -o load_static_weights=False +python3 tools/export_model.py \ + -c ./ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml \ + -o Arch.pretrained="./output/MobileNetV1/best_model" ``` +inference model的路径默认在当前路径下`./inference` -评估指标采用mAP,验证集的mAP应该在0.57左右。 +### 5.2 基于预测引擎预测 -## 五、模型预测 +首先进入deploy目录下: ```bash -python tools/infer/infer.py \ - -i "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/0199_434752251.jpg" \ - --model ResNet50_vd \ - --pretrained_model "./output/ResNet50_vd/best_model/ppcls" \ - --use_gpu True \ - --load_static_weights False \ - --multilabel True \ - --class_num 33 +cd ./deploy +``` + +通过预测引擎推理预测: + +``` +python3 python/predict_cls.py \ + -c configs/inference_multilabel_cls.yaml ``` 得到类似下面的输出: -``` - class id: 3, probability: 0.6025 - class id: 23, probability: 0.5491 - class id: 32, probability: 0.7006 -``` \ No newline at end of file +``` +0517_2715693311.jpg: class id(s): [6, 13, 17, 23, 26, 30], score(s): [0.96, 0.56, 0.55, 0.99, 0.59, 0.79], label_name(s): [] +``` diff --git a/ppcls/configs/quick_start/MobileNetV1_multilabel.yaml b/ppcls/configs/quick_start/MobileNetV1_multilabel.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e9c021b65200ca6d3df369c52717f10078f9d9ee --- /dev/null +++ b/ppcls/configs/quick_start/MobileNetV1_multilabel.yaml @@ -0,0 +1,129 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 10 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + use_multilabel: True +# model architecture +Arch: + name: MobileNetV1 + class_num: 33 + pretrained: True + +# loss function config for traing/eval process +Loss: + Train: + - MultiLabelLoss: + weight: 1.0 + Eval: + - MultiLabelLoss: + weight: 1.0 + + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.1 + regularizer: + name: 'L2' + coeff: 0.00004 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: MultiLabelDataset + image_root: ./dataset/NUS-SCENE-dataset/images/ + cls_label_path: ./dataset/NUS-SCENE-dataset/multilabel_train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: MultiLabelDataset + image_root: ./dataset/NUS-SCENE-dataset/images/ + cls_label_path: ./dataset/NUS-SCENE-dataset/multilabel_test_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 256 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: dataset/NUS-SCENE-dataset/images/0001_109549716.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: MutiLabelTopk + topk: 5 + class_id_map_file: None + +Metric: + Train: + - HammingDistance: + - AccuracyScore: + Eval: + - HammingDistance: + - AccuracyScore: diff --git a/ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml b/ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7c9e5a7eb331d8e0ab4694348883a39b70c3cb3a --- /dev/null +++ b/ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml @@ -0,0 +1,129 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 10 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + use_multilabel: True +# model architecture +Arch: + name: MobileNetV1 + class_num: 33 + pretrained: True + +# loss function config for traing/eval process +Loss: + Train: + - MultiLabelLoss: + weight: 1.0 + Eval: + - MultiLabelLoss: + weight: 1.0 + + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.1 + regularizer: + name: 'L2' + coeff: 0.00004 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: MultiLabelDataset + image_root: ./dataset/NUS-SCENE-dataset/images/ + cls_label_path: ./dataset/NUS-SCENE-dataset/multilabel_train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: MultiLabelDataset + image_root: ./dataset/NUS-SCENE-dataset/images/ + cls_label_path: ./dataset/NUS-SCENE-dataset/multilabel_test_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 256 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: ./deploy/images/0517_2715693311.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: MultiLabelTopk + topk: 5 + class_id_map_file: None + +Metric: + Train: + - HammingDistance: + - AccuracyScore: + Eval: + - HammingDistance: + - AccuracyScore: diff --git a/ppcls/data/dataloader/multilabel_dataset.py b/ppcls/data/dataloader/multilabel_dataset.py index fafecc711a89b14f2a2b57b1e205d6f9fb0cf369..08d2ba15b02b2bd261ad210767b4feb0843fb67d 100644 --- a/ppcls/data/dataloader/multilabel_dataset.py +++ b/ppcls/data/dataloader/multilabel_dataset.py @@ -33,7 +33,7 @@ class MultiLabelDataset(CommonDataset): with open(self._cls_path) as fd: lines = fd.readlines() for l in lines: - l = l.strip().split(" ") + l = l.strip().split("\t") self.images.append(os.path.join(self._img_root, l[0])) labels = l[1].split(',') @@ -44,13 +44,14 @@ class MultiLabelDataset(CommonDataset): def __getitem__(self, idx): try: - img = cv2.imread(self.images[idx]) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + with open(self.images[idx], 'rb') as f: + img = f.read() if self._transform_ops: img = transform(img, self._transform_ops) img = img.transpose((2, 0, 1)) label = np.array(self.labels[idx]).astype("float32") return (img, label) + except Exception as ex: logger.error("Exception occured when parse line: {} with msg: {}". format(self.images[idx], ex)) diff --git a/ppcls/data/postprocess/__init__.py b/ppcls/data/postprocess/__init__.py index 801e7f101cec0d2781c232074f1543821d2aa2d1..831a4da0008ba70824203be3a6f46c9700225457 100644 --- a/ppcls/data/postprocess/__init__.py +++ b/ppcls/data/postprocess/__init__.py @@ -16,7 +16,7 @@ import importlib from . import topk -from .topk import Topk +from .topk import Topk, MultiLabelTopk def build_postprocess(config): diff --git a/ppcls/data/postprocess/topk.py b/ppcls/data/postprocess/topk.py index 2410e32918e90f5a8da8e7ed4028b0d8501931c5..9c1371bfd11f4c93f06c82436e88e0ff20a57b35 100644 --- a/ppcls/data/postprocess/topk.py +++ b/ppcls/data/postprocess/topk.py @@ -45,15 +45,17 @@ class Topk(object): class_id_map = None return class_id_map - def __call__(self, x, file_names=None): + def __call__(self, x, file_names=None, multilabel=False): assert isinstance(x, paddle.Tensor) if file_names is not None: assert x.shape[0] == len(file_names) - x = F.softmax(x, axis=-1) + x = F.softmax(x, axis=-1) if not multilabel else F.sigmoid(x) x = x.numpy() y = [] for idx, probs in enumerate(x): - index = probs.argsort(axis=0)[-self.topk:][::-1].astype("int32") + index = probs.argsort(axis=0)[-self.topk:][::-1].astype( + "int32") if not multilabel else np.where( + probs >= 0.5)[0].astype("int32") clas_id_list = [] score_list = [] label_name_list = [] @@ -73,3 +75,11 @@ class Topk(object): result["label_names"] = label_name_list y.append(result) return y + + +class MultiLabelTopk(Topk): + def __init__(self, topk=1, class_id_map_file=None): + super().__init__() + + def __call__(self, x, file_names=None): + return super().__call__(x, file_names, multilabel=True) diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 2d488f8c1c73c2520ec005a19f07c96ff049edd3..d0f2d64721f05a29089cab29db4252973ffe04e2 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -355,7 +355,8 @@ class Engine(object): def export(self): assert self.mode == "export" - model = ExportModel(self.config["Arch"], self.model) + use_multilabel = self.config["Global"].get("use_multilabel", False) + model = ExportModel(self.config["Arch"], self.model, use_multilabel) if self.config["Global"]["pretrained_model"] is not None: load_dygraph_pretrain(model.base_model, self.config["Global"]["pretrained_model"]) @@ -388,10 +389,9 @@ class ExportModel(nn.Layer): ExportModel: add softmax onto the model """ - def __init__(self, config, model): + def __init__(self, config, model, use_multilabel): super().__init__() self.base_model = model - # we should choose a final model to export if isinstance(self.base_model, DistillationModel): self.infer_model_name = config["infer_model_name"] @@ -402,10 +402,13 @@ class ExportModel(nn.Layer): if self.infer_output_key == "features" and isinstance(self.base_model, RecModel): self.base_model.head = IdentityHead() - if config.get("infer_add_softmax", True): - self.softmax = nn.Softmax(axis=-1) + if use_multilabel: + self.out_act = nn.Sigmoid() else: - self.softmax = None + if config.get("infer_add_softmax", True): + self.out_act = nn.Softmax(axis=-1) + else: + self.out_act = None def eval(self): self.training = False @@ -421,6 +424,6 @@ class ExportModel(nn.Layer): x = x[self.infer_model_name] if self.infer_output_key is not None: x = x[self.infer_output_key] - if self.softmax is not None: - x = self.softmax(x) + if self.out_act is not None: + x = self.out_act(x) return x diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index b1ddc41863044e5b8bed01334d374134f3124387..005d740d38da871755c8b507b5ed3412c4f2eb94 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -52,7 +52,8 @@ def classification_eval(evaler, epoch_id=0): time_info["reader_cost"].update(time.time() - tic) batch_size = batch[0].shape[0] batch[0] = paddle.to_tensor(batch[0]).astype("float32") - batch[1] = batch[1].reshape([-1, 1]).astype("int64") + if not evaler.config["Global"].get("use_multilabel", False): + batch[1] = batch[1].reshape([-1, 1]).astype("int64") # image input out = evaler.model(batch[0]) # calc loss diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py index 73f225087fe37b38d8274e43e7b901760101af6e..e158548347630ca52d0ad12b38289c69206ca51b 100644 --- a/ppcls/engine/train/train.py +++ b/ppcls/engine/train/train.py @@ -36,8 +36,8 @@ def train_epoch(trainer, epoch_id, print_batch_step): paddle.to_tensor(batch[0]['label']) ] batch_size = batch[0].shape[0] - batch[1] = batch[1].reshape([-1, 1]).astype("int64") - + if not trainer.config["Global"].get("use_multilabel", False): + batch[1] = batch[1].reshape([-1, 1]).astype("int64") trainer.global_step += 1 # image input if trainer.amp: diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py index 5421f421242d72bd27edcef869b23844c51703c6..7c0374808f9acfff8de1b24126a8ed8031c5d9ba 100644 --- a/ppcls/loss/__init__.py +++ b/ppcls/loss/__init__.py @@ -20,6 +20,7 @@ from .distanceloss import DistanceLoss from .distillationloss import DistillationCELoss from .distillationloss import DistillationGTCELoss from .distillationloss import DistillationDMLLoss +from .multilabelloss import MultiLabelLoss class CombinedLoss(nn.Layer): diff --git a/ppcls/loss/multilabelloss.py b/ppcls/loss/multilabelloss.py new file mode 100644 index 0000000000000000000000000000000000000000..d30d5b8d18083385567d0bcdffaa1fd2da4876f5 --- /dev/null +++ b/ppcls/loss/multilabelloss.py @@ -0,0 +1,43 @@ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class MultiLabelLoss(nn.Layer): + """ + Multi-label loss + """ + + def __init__(self, epsilon=None): + super().__init__() + if epsilon is not None and (epsilon <= 0 or epsilon >= 1): + epsilon = None + self.epsilon = epsilon + + def _labelsmoothing(self, target, class_num): + if target.ndim == 1 or target.shape[-1] != class_num: + one_hot_target = F.one_hot(target, class_num) + else: + one_hot_target = target + soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon) + soft_target = paddle.reshape(soft_target, shape=[-1, class_num]) + return soft_target + + def _binary_crossentropy(self, input, target, class_num): + if self.epsilon is not None: + target = self._labelsmoothing(target, class_num) + cost = F.binary_cross_entropy_with_logits( + logit=input, label=target) + else: + cost = F.binary_cross_entropy_with_logits( + logit=input, label=target) + + return cost + + def forward(self, x, target): + if isinstance(x, dict): + x = x["logits"] + class_num = x.shape[-1] + loss = self._binary_crossentropy(x, target, class_num) + loss = loss.mean() + return {"MultiLabelLoss": loss} diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py index 4c817a115268120456337930fdc09f4bd7a48da0..94721235bca5ab4c27ddba36dd265a01cea003ad 100644 --- a/ppcls/metric/__init__.py +++ b/ppcls/metric/__init__.py @@ -19,6 +19,8 @@ from collections import OrderedDict from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk from .metrics import DistillationTopkAcc from .metrics import GoogLeNetTopkAcc +from .metrics import HammingDistance, AccuracyScore + class CombinedMetrics(nn.Layer): def __init__(self, config_list): @@ -32,7 +34,8 @@ class CombinedMetrics(nn.Layer): metric_name = list(config)[0] metric_params = config[metric_name] if metric_params is not None: - self.metric_func_list.append(eval(metric_name)(**metric_params)) + self.metric_func_list.append( + eval(metric_name)(**metric_params)) else: self.metric_func_list.append(eval(metric_name)()) @@ -42,6 +45,7 @@ class CombinedMetrics(nn.Layer): metric_dict.update(metric_func(*args, **kwargs)) return metric_dict + def build_metrics(config): metrics_list = CombinedMetrics(copy.deepcopy(config)) return metrics_list diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 204d2af093c79841a7637cdab0a4023f743e04ef..37509eb14ea98b96f7a1fc96ee3e63f9fba18e7c 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -15,6 +15,12 @@ import numpy as np import paddle import paddle.nn as nn +import paddle.nn.functional as F + +from sklearn.metrics import hamming_loss +from sklearn.metrics import accuracy_score as accuracy_metric +from sklearn.metrics import multilabel_confusion_matrix +from sklearn.preprocessing import binarize class TopkAcc(nn.Layer): @@ -198,7 +204,7 @@ class Precisionk(nn.Layer): equal_flag = paddle.logical_and(equal_flag, keep_mask.astype('bool')) equal_flag = paddle.cast(equal_flag, 'float32') - + Ns = paddle.arange(gallery_img_id.shape[0]) + 1 equal_flag_cumsum = paddle.cumsum(equal_flag, axis=1) Precision_at_k = (paddle.mean(equal_flag_cumsum, axis=0) / Ns).numpy() @@ -232,3 +238,71 @@ class GoogLeNetTopkAcc(TopkAcc): def forward(self, x, label): return super().forward(x[0], label) + + +class MutiLabelMetric(object): + def __init__(self): + pass + + def _multi_hot_encode(self, logits, threshold=0.5): + return binarize(logits, threshold=threshold) + + def __call__(self, output): + output = F.sigmoid(output) + preds = self._multi_hot_encode(logits=output.numpy(), threshold=0.5) + return preds + + +class HammingDistance(MutiLabelMetric): + """ + Soft metric based label for multilabel classification + Returns: + The smaller the return value is, the better model is. + """ + + def __init__(self): + super().__init__() + + def __call__(self, output, target): + preds = super().__call__(output) + metric_dict = dict() + metric_dict["HammingDistance"] = paddle.to_tensor( + hamming_loss(target, preds)) + return metric_dict + + +class AccuracyScore(MutiLabelMetric): + """ + Hard metric for multilabel classification + Args: + base: ["sample", "label"], default="sample" + if "sample", return metric score based sample, + if "label", return metric score based label. + Returns: + accuracy: + """ + + def __init__(self, base="label"): + super().__init__() + assert base in ["sample", "label" + ], 'must be one of ["sample", "label"]' + self.base = base + + def __call__(self, output, target): + preds = super().__call__(output) + metric_dict = dict() + if self.base == "sample": + accuracy = accuracy_metric(target, preds) + elif self.base == "label": + mcm = multilabel_confusion_matrix(target, preds) + tns = mcm[:, 0, 0] + fns = mcm[:, 1, 0] + tps = mcm[:, 1, 1] + fps = mcm[:, 0, 1] + accuracy = (sum(tps) + sum(tns)) / ( + sum(tps) + sum(tns) + sum(fns) + sum(fps)) + precision = sum(tps) / (sum(tps) + sum(fps)) + recall = sum(tps) / (sum(tps) + sum(fns)) + F1 = 2 * (accuracy * recall) / (accuracy + recall) + metric_dict["AccuracyScore"] = paddle.to_tensor(accuracy) + return metric_dict diff --git a/tools/train.sh b/tools/train.sh index 5fced8636235d533bdadcdbb40769733930a0763..083934a57184725c2ed4f99a27a91cc6beed36c1 100755 --- a/tools/train.sh +++ b/tools/train.sh @@ -4,4 +4,4 @@ # python3.7 tools/train.py -c ./ppcls/configs/ImageNet/ResNet/ResNet50.yaml # for multi-cards train -python3.7 -m paddle.distributed.launch --gpus="0,1,2,3" tools/train.py -c ./ppcls/configs/ImageNet/ResNet/ResNet50.yaml \ No newline at end of file +python3.7 -m paddle.distributed.launch --gpus="0,1,2,3" tools/train.py -c ./ppcls/configs/ImageNet/ResNet/ResNet50.yaml diff --git a/train.sh b/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..47ae2a68ee94fdd27cc43364b3a9f59bf874f439 --- /dev/null +++ b/train.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash + +# for single card train +# python3.7 tools/train.py -c ./ppcls/configs/ImageNet/ResNet/ResNet50.yaml + +# for multi-cards train +python3.7 -m paddle.distributed.launch --gpus="0" tools/train.py -c ./MobileNetV2.yaml