diff --git a/deploy/configs/cls_demo/person/inference_person_cls.yaml b/deploy/configs/PULC/person/inference_person_cls.yaml similarity index 83% rename from deploy/configs/cls_demo/person/inference_person_cls.yaml rename to deploy/configs/PULC/person/inference_person_cls.yaml index 9c5161c7c4cc81e77e8c562c9b22a8f7848cebd1..a70f663a792fcdcab3b7d45059f2afe0b1efbf07 100644 --- a/deploy/configs/cls_demo/person/inference_person_cls.yaml +++ b/deploy/configs/PULC/person/inference_person_cls.yaml @@ -1,9 +1,9 @@ Global: - infer_imgs: "./images/cls_demo/person/objects365_02035329.jpg" + infer_imgs: "./images/PULC/person/objects365_02035329.jpg" inference_model_dir: "./models/person_cls_infer" batch_size: 1 use_gpu: True - enable_mkldnn: True + enable_mkldnn: False cpu_num_threads: 10 enable_benchmark: True use_fp16: False @@ -30,7 +30,7 @@ PostProcess: main_indicator: ThreshOutput ThreshOutput: threshold: 0.9 - label_0: invalid - label_1: valid + label_0: nobody + label_1: someone SavePreLabel: save_dir: ./pre_label/ diff --git a/deploy/images/cls_demo/person/objects365_01780782.jpg b/deploy/images/PULC/person/objects365_01780782.jpg similarity index 100% rename from deploy/images/cls_demo/person/objects365_01780782.jpg rename to deploy/images/PULC/person/objects365_01780782.jpg diff --git a/deploy/images/cls_demo/person/objects365_02035329.jpg b/deploy/images/PULC/person/objects365_02035329.jpg similarity index 100% rename from deploy/images/cls_demo/person/objects365_02035329.jpg rename to deploy/images/PULC/person/objects365_02035329.jpg diff --git a/docs/zh_CN/cls_demo/quick_start_cls_demo.md b/docs/zh_CN/PULC/PULC_person_cls.md similarity index 56% rename from docs/zh_CN/cls_demo/quick_start_cls_demo.md rename to docs/zh_CN/PULC/PULC_person_cls.md index dc1f8cd34be2ce9808380e75b1f8652d73cb11a6..ab9feb4811a5861b1f39a9845962265d1d816f1e 100644 --- a/docs/zh_CN/cls_demo/quick_start_cls_demo.md +++ b/docs/zh_CN/PULC/PULC_person_cls.md @@ -1,6 +1,6 @@ # 
PaddleClas构建有人/无人分类案例 -此处提供了用户使用 PaddleClas 快速构建轻量级、高精度、可落地的有人/无人的分类模型教程,主要基于有人/无人场景的数据,融合了轻量级骨干网络PPLCNet、SSLD预训练权重、EDA数据增强策略、KL-JS-UGI知识蒸馏策略、SHAS超参数搜索策略,得到精度高、速度快、易于部署的二分类模型。 +此处提供了用户使用 PaddleClas 快速构建轻量级、高精度、可落地的有人/无人的分类模型教程,主要基于有人/无人场景的数据,融合了轻量级骨干网络PPLCNet、SSLD预训练权重、EDA数据增强策略、SKL-UGI知识蒸馏策略、SHAS超参数搜索策略,得到精度高、速度快、易于部署的二分类模型。 ------ @@ -55,7 +55,7 @@ cd deploy mkdir models cd models # 下载inference 模型并解压 -wget https://paddleclas.bj.bcebos.com/models/cls_demo/person_cls_infer.tar && tar -xf person_cls_infer.tar +wget https://paddleclas.bj.bcebos.com/models/PULC/person_cls_infer.tar && tar -xf person_cls_infer.tar ``` 解压完毕后,`models` 文件夹下应有如下文件结构: @@ -75,23 +75,29 @@ wget https://paddleclas.bj.bcebos.com/models/cls_demo/person_cls_infer.tar && ta #### 2.2.1 预测单张图像 -运行下面的命令,对图像 `./images/cls_demo/person/objects365_02035329.jpg` 进行有人/无人分类。 +返回 `deploy` 目录: + +``` +cd ../ +``` + +运行下面的命令,对图像 `./images/PULC/person/objects365_02035329.jpg` 进行有人/无人分类。 ```shell # 使用下面的命令使用 GPU 进行预测 -python3.7 python/predict_cls.py -c configs/cls_demo/person/inference_person_cls.yaml +python3.7 python/predict_cls.py -c configs/PULC/person/inference_person_cls.yaml -o PostProcess.ThreshOutput.threshold=0.9794 # 使用下面的命令使用 CPU 进行预测 -python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.use_gpu=False +python3.7 python/predict_cls.py -c configs/PULC/person/inference_person_cls.yaml -o PostProcess.ThreshOutput.threshold=0.9794 -o Global.use_gpu=False ``` 输出结果如下。 ``` -objects365_02035329.jpg: class id(s): [1, 0], score(s): [1.00, 0.00], label_name(s): ['someone', 'nobody'] +objects365_02035329.jpg: class id(s): [1], score(s): [1.00], label_name(s): ['someone'] ``` -其中,`someone` 表示该图里存在人,`nobody` 表示该图里不存在人。 +**备注:** 真实场景中往往需要在假正类率(Fpr)小于某一个指标下求真正类率(Tpr),该场景中的`val`数据集在千分之一Fpr下得到的最佳Tpr所得到的阈值为`0.9794`,故此处的`threshold`为`0.9794`。该阈值的确定方法可以参考[3.2节](#3.2) @@ -101,16 +107,18 @@ objects365_02035329.jpg: class id(s): [1, 0], score(s): [1.00, 0.00], label_name ```shell # 使用下面的命令使用 GPU 
进行预测,如果希望使用 CPU 预测,可以在命令后面添加 -o Global.use_gpu=False -python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.infer_imgs="./images/cls_demo/person/" +python3.7 python/predict_cls.py -c configs/PULC/person/inference_person_cls.yaml -o Global.infer_imgs="./images/PULC/person/" ``` 终端中会输出该文件夹内所有图像的分类结果,如下所示。 ``` -objects365_01780782.jpg: class id(s): [0, 1], score(s): [1.00, 0.00], label_name(s): ['nobody', 'someone'] -objects365_02035329.jpg: class id(s): [1, 0], score(s): [1.00, 0.00], label_name(s): ['someone', 'nobody'] +objects365_01780782.jpg: class id(s): [0], score(s): [1.00], label_name(s): ['nobody'] +objects365_02035329.jpg: class id(s): [1], score(s): [1.00], label_name(s): ['someone'] ``` +其中,`someone` 表示该图里存在人,`nobody` 表示该图里不存在人。 + ## 3.有人/无人场景训练 @@ -161,7 +169,7 @@ cd ../ * **注意**: -* 本案例中所使用的所有数据集均为开源数据,`train`集合为[MS-COCO数据](https://cocodataset.org/#overview)的训练集的子集,`val`集合为[Object365数据](https://www.objects365.org/overview.html)的训练集的子集,`ImageNet_val`为[ImageNet数据](https://www.image-net.org/)的验证集。 +* 本案例中所使用的所有数据集均为开源数据,`train`集合为[MS-COCO数据](https://cocodataset.org/#overview)的训练集的子集,`val`集合为[Object365数据](https://www.objects365.org/overview.html)的训练集的子集,`ImageNet_val`为[ImageNet数据](https://www.image-net.org/)的验证集。数据集的筛选流程可以参考[有人/无人场景数据集筛选方法]()。 @@ -175,47 +183,53 @@ cd ../ ##### 3.2.1.1 基于默认超参数训练轻量级模型 -在`ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml`中提供了基于该场景中已经搜索好的超参数,可以通过如下脚本启动训练: +在`ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml`中提供了基于该场景的训练配置,可以通过如下脚本启动训练: ```shell export CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m paddle.distributed.launch \ --gpus="0,1,2,3" \ tools/train.py \ - -c ./ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml + -c ./ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml ``` -验证集的最佳 metric 在0.94-0.95之间(数据集较小,容易造成波动)。 +验证集的最佳指标在0.94-0.95之间(数据集较小,容易造成波动)。 + +**备注:** + +* 
此时使用的指标为Tpr,该指标描述了在假正类率(Fpr)小于某一个指标时的真正类率(Tpr),是产业中二分类问题常用的指标之一。在本案例中,Fpr为千分之一。关于Fpr和Tpr的更多介绍,可以参考[这里](https://baike.baidu.com/item/AUC/19282953)。 + +* 在eval时,会打印出来当前最佳的TprAtFpr指标,具体地,其会打印当前的`Fpr`、`Tpr`值,以及当前的`threshold`值,`Tpr`值反映了在当前`Fpr`值下的召回率,该值越高,代表模型越好。`threshold` 表示当前最佳`Fpr`所对应的分类阈值,可用于后续模型部署落地等。 ##### 3.2.1.2 基于默认超参数训练教师模型 -复用`ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml`中的超参数,训练教师模型,训练脚本如下: +复用`ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml`中的超参数,训练教师模型,训练脚本如下: ```shell export CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m paddle.distributed.launch \ --gpus="0,1,2,3" \ tools/train.py \ - -c ./ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml \ + -c ./ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml \ -o Arch.name=ResNet101_vd ``` -验证集的最佳 metric 为0.97-0.98之间,当前教师模型最好的权重保存在`output/ResNet101_vd/best_model.pdparams`。 +验证集的最佳指标为0.96-0.98之间,当前教师模型最好的权重保存在`output/ResNet101_vd/best_model.pdparams`。 ##### 3.2.1.3 基于默认超参数进行蒸馏训练 -配置文件`ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml`提供了`KL-JS-UGI知识蒸馏策略`的配置。该配置将`ResNet101_vd`当作教师模型,`PPLCNet_x1_0`当作学生模型,使用ImageNet数据集的验证集作为新增的无标签数据。训练脚本如下: +配置文件`ppcls/configs/PULC/person/Distillation/PPLCNet_x1_0_distillation.yaml`提供了`SKL-UGI知识蒸馏策略`的配置。该配置将`ResNet101_vd`当作教师模型,`PPLCNet_x1_0`当作学生模型,使用ImageNet数据集的验证集作为新增的无标签数据。训练脚本如下: ```shell export CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m paddle.distributed.launch \ --gpus="0,1,2,3" \ tools/train.py \ - -c .ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml \ + -c ./ppcls/configs/PULC/person/Distillation/PPLCNet_x1_0_distillation.yaml \ -o Arch.models.0.Teacher.pretrained=output/ResNet101_vd/best_model ``` @@ -228,14 +242,19 @@ python3 -m paddle.distributed.launch \ * 搜索运行脚本如下: ```shell -python tools/search_strategy.py -c ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml +python tools/search_strategy.py -c ppcls/configs/StrategySearch/person.yaml ``` +在`ppcls/configs/StrategySearch/person.yaml`中指定了具体的 GPU id
号和搜索配置。 + * **注意**: -* 此过程基于当前数据集在 V100 4 卡上大概需要耗时 6 小时,如果缺少机器资源,希望体验搜索过程,可以将`ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml`中的`train_list.txt`和`val_list.txt`分别替换为`train_list.txt.debug`和`val_list.txt.debug`。替换list只是为了加速跑通整个搜索过程,由于数据量较小,其搜素的结果没有参考性。 +* 3.1小节提供的默认配置已经经过了搜索,所以此过程不是必要的过程,如果自己的训练数据集有变化,可以尝试此过程。 + +* 此过程基于当前数据集在 V100 4 卡上大概需要耗时 10 小时,如果缺少机器资源,希望体验搜索过程,可以将`ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0_search.yaml`中的`train_list.txt`和`val_list.txt`分别替换为`train_list.txt.debug`和`val_list.txt.debug`。替换list只是为了加速跑通整个搜索过程,由于数据量较小,其搜索的结果没有参考性。另外,搜索空间可以根据当前的机器资源来调整,如果机器资源有限,可以尝试缩小搜索空间,如果机器资源较充足,可以尝试扩大搜索空间。 + +* 如果此过程搜索得到的超参数与[3.2.1小节](#3.2.1)提供的超参数不一致,主要是由于训练数据较小造成的波动导致,可以忽略。 -* 如果此过程搜索的得到的超参数与3.2.1小节提供的超参数不一致,主要是由于训练数据较小造成的波动导致,可以忽略。 @@ -246,11 +265,11 @@ python tools/search_strategy.py -c ppcls/configs/cls_demo/person/PPLCNet/PPLCNet ### 4.1 模型评估 -训练好模型之后,可以通过以下命令实现对模型精度的评估。 +训练好模型之后,可以通过以下命令实现对模型指标的评估。 ```bash python3 tools/eval.py \ - -c ./ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml \ + -c ./ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml \ -o Global.pretrained_model="output/PPLCNet_x1_0/best_model" ``` @@ -262,11 +281,20 @@ python3 tools/eval.py \ ```python python3 tools/infer.py \ - -c ./ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml \ + -c ./ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml \ -o Infer.infer_imgs=./dataset/person/val/objects365_01780637.jpg \ - -o Global.pretrained_model=output/PPLCNet_x1_0/best_model + -o Global.pretrained_model=output/PPLCNet_x1_0/best_model \ + -o Infer.PostProcess.threshold=0.9794 +``` + +输出结果如下: + +``` +[{'class_ids': [0], 'scores': [0.9878496769815683], 'label_names': ['nobody'], 'file_name': './dataset/person/val/objects365_01780637.jpg'}] ``` +**备注:** 这里的`Infer.PostProcess.threshold`的值需要根据实际场景来确定,此处的`0.9794`是在该场景中的`val`数据集在千分之一Fpr下得到的最佳Tpr所得到的。 + ### 4.3 使用 inference 模型进行推理 @@ -280,7 +308,7 @@ python3 tools/infer.py \ ```bash python3
tools/export_model.py \ - -c ./ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml \ + -c ./ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml \ -o Global.pretrained_model=output/PPLCNet_x1_0/best_model \ -o Global.save_inference_dir=deploy/models/PPLCNet_x1_0_person ``` @@ -292,8 +320,11 @@ python3 tools/export_model.py \ 推理预测的脚本为: ``` -python3.7 python/predict_cls.py -c configs/cls_demo/person/inference_person_cls.yaml -o Global.inference_model_dir="models/PPLCNet_x1_0_person" +python3.7 python/predict_cls.py -c configs/PULC/person/inference_person_cls.yaml -o Global.inference_model_dir="models/PPLCNet_x1_0_person" -o PostProcess.ThreshOutput.threshold=0.9794 ``` -更多关于推理的细节,可以参考[2.2节](#2.2)。 +**备注:** + +- 此处的`PostProcess.ThreshOutput.threshold`由eval时的最佳`threshold`来确定。 +- 更多关于推理的细节,可以参考[2.2节](#2.2)。 diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index 3f94501fbde0f958d662359889cef074442cf5aa..d3f98885b5e0cd3b4d4db4483f30edf674432f72 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -40,6 +40,7 @@ def build_model(config): arch = getattr(mod, model_type)(**arch_config) if use_sync_bn: arch = nn.SyncBatchNorm.convert_sync_batchnorm(arch) + if isinstance(arch, TheseusLayer): prune_model(config, arch) quantize_model(config, arch) diff --git a/ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml b/ppcls/configs/PULC/person/Distillation/PPLCNet_x1_0_distillation.yaml similarity index 94% rename from ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml rename to ppcls/configs/PULC/person/Distillation/PPLCNet_x1_0_distillation.yaml index a9c3db29682933f19cf93ef000d5b6ec83007aa7..afb9b43a0dfad4153bdc761a13c61a4d0e5fd47d 100644 --- a/ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml +++ b/ppcls/configs/PULC/person/Distillation/PPLCNet_x1_0_distillation.yaml @@ -6,7 +6,7 @@ Global: device: gpu save_interval: 1 eval_during_train: True - start_eval_epoch: 10 +
start_eval_epoch: 1 eval_interval: 1 epochs: 20 print_batch_step: 10 @@ -33,14 +33,11 @@ Arch: - Teacher: name: ResNet101_vd class_num: *class_num - use_sync_bn: True - Student: name: PPLCNet_x1_0 class_num: *class_num pretrained: True use_ssld: True - use_sync_bn: True - lr_mult_list: [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] infer_model_name: "Student" @@ -155,9 +152,10 @@ Infer: order: '' - ToCHWImage: PostProcess: - name: Topk - topk: 5 - class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + name: ThreshOutput + threshold: 0.9 + label_0: nobody + label_1: someone Metric: Train: diff --git a/ppcls/configs/cls_demo/person/OtherModels/MobileNetV3_large_x1_0.yaml b/ppcls/configs/PULC/person/OtherModels/MobileNetV3_large_x1_0.yaml similarity index 97% rename from ppcls/configs/cls_demo/person/OtherModels/MobileNetV3_large_x1_0.yaml rename to ppcls/configs/PULC/person/OtherModels/MobileNetV3_large_x1_0.yaml index d16a1b61761105ba60153355cef3eae75813ac0f..d69bb933fdbf5592d497651cad79995a492cdf28 100644 --- a/ppcls/configs/cls_demo/person/OtherModels/MobileNetV3_large_x1_0.yaml +++ b/ppcls/configs/PULC/person/OtherModels/MobileNetV3_large_x1_0.yaml @@ -130,9 +130,10 @@ Infer: order: '' - ToCHWImage: PostProcess: - name: Topk - topk: 5 - class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + name: ThreshOutput + threshold: 0.9 + label_0: nobody + label_1: someone Metric: Train: diff --git a/ppcls/configs/cls_demo/person/OtherModels/SwinTransformer_tiny_patch4_window7_224.yaml b/ppcls/configs/PULC/person/OtherModels/SwinTransformer_tiny_patch4_window7_224.yaml similarity index 97% rename from ppcls/configs/cls_demo/person/OtherModels/SwinTransformer_tiny_patch4_window7_224.yaml rename to ppcls/configs/PULC/person/OtherModels/SwinTransformer_tiny_patch4_window7_224.yaml index 9999f02830a4a3842326cfe5160c560d6179f937..0e2248e98529b511c7821b49ced6cf0625016553 100644 --- a/ppcls/configs/cls_demo/person/OtherModels/SwinTransformer_tiny_patch4_window7_224.yaml +++ 
b/ppcls/configs/PULC/person/OtherModels/SwinTransformer_tiny_patch4_window7_224.yaml @@ -153,9 +153,10 @@ Infer: order: '' - ToCHWImage: PostProcess: - name: Topk - topk: 5 - class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + name: ThreshOutput + threshold: 0.9 + label_0: nobody + label_1: someone Metric: Train: diff --git a/ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml b/ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml similarity index 95% rename from ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml rename to ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml index 97ae1ba73b0c499a5e2b80f5d32c62964b061a40..e196547923a345a9535f5b63a568817b2784c6d7 100644 --- a/ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml +++ b/ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml @@ -26,7 +26,6 @@ Arch: pretrained: True use_ssld: True use_sync_bn: True - lr_mult_list: [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] # loss function config for traing/eval process Loss: @@ -137,9 +136,10 @@ Infer: order: '' - ToCHWImage: PostProcess: - name: Topk - topk: 5 - class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + name: ThreshOutput + threshold: 0.9 + label_0: nobody + label_1: someone Metric: Train: diff --git a/ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml b/ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0_search.yaml similarity index 97% rename from ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml rename to ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0_search.yaml index 3c9bfb5dbb10fdea9c1209ecaace08c2fb59ac6a..b2126b69f9d773d918df6b1f03361cac06ee44f8 100644 --- a/ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0_search.yaml +++ b/ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0_search.yaml @@ -136,9 +136,10 @@ Infer: order: '' - ToCHWImage: PostProcess: - name: Topk - topk: 5 - class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + name: ThreshOutput + threshold: 0.9 + label_0: nobody + label_1: someone Metric: Train: diff 
--git a/ppcls/configs/StrategySearch/person.yaml b/ppcls/configs/StrategySearch/person.yaml index d9841c9af7e7f78ce62eaa5a1811be4d337d475b..906635595f33417cf564ca54a430c3c648fd738d 100644 --- a/ppcls/configs/StrategySearch/person.yaml +++ b/ppcls/configs/StrategySearch/person.yaml @@ -1,9 +1,9 @@ -base_config_file: ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml -distill_config_file: ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml +base_config_file: ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0_search.yaml +distill_config_file: ppcls/configs/PULC/person/Distillation/PPLCNet_x1_0_distillation.yaml gpus: 0,1,2,3 output_dir: output/search_person -search_times: 3 +search_times: 1 search_dict: - search_key: lrs replace_config: diff --git a/ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation_search.yaml b/ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation_search.yaml deleted file mode 100644 index 231766d846459e5157a48fcde110b391f7f5cd5a..0000000000000000000000000000000000000000 --- a/ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation_search.yaml +++ /dev/null @@ -1,169 +0,0 @@ -# global configs -Global: - checkpoints: null - pretrained_model: null - output_dir: ./output - device: gpu - save_interval: 1 - eval_during_train: True - start_eval_epoch: 10 - eval_interval: 1 - epochs: 20 - print_batch_step: 10 - use_visualdl: False - # used for static mode and model export - image_shape: [3, 224, 224] - save_inference_dir: ./inference - # training model under @to_static - to_static: False - use_dali: False - -# model architecture -Arch: - name: "DistillationModel" - class_num: &class_num 2 - # if not null, its lengths should be same as models - pretrained_list: - # if not null, its lengths should be same as models - freeze_params_list: - - True - - False - use_sync_bn: True - models: - - Teacher: - name: ResNet101_vd - class_num: *class_num - use_sync_bn: True - - Student: - name: 
PPLCNet_x1_0 - class_num: *class_num - pretrained: True - use_ssld: True - use_sync_bn: True - - infer_model_name: "Student" - -# loss function config for traing/eval process -Loss: - Train: - - DistillationDMLLoss: - weight: 1.0 - model_name_pairs: - - ["Student", "Teacher"] - Eval: - - CELoss: - weight: 1.0 - - -Optimizer: - name: Momentum - momentum: 0.9 - lr: - name: Cosine - learning_rate: 0.01 - warmup_epoch: 5 - regularizer: - name: 'L2' - coeff: 0.00004 - - -# data loader for train and eval -DataLoader: - Train: - dataset: - name: ImageNetDataset - image_root: ./dataset/person/ - cls_label_path: ./dataset/person/train_list_for_distill.txt - transform_ops: - - DecodeImage: - to_rgb: True - channel_first: False - - RandCropImage: - size: 224 - - RandFlipImage: - flip_code: 1 - - TimmAutoAugment: - prob: 0.0 - config_str: rand-m9-mstd0.5-inc1 - interpolation: bicubic - img_size: 224 - - NormalizeImage: - scale: 1.0/255.0 - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - order: '' - - RandomErasing: - EPSILON: 0.0 - sl: 0.02 - sh: 1.0/3.0 - r1: 0.3 - attempt: 10 - use_log_aspect: True - mode: pixel - sampler: - name: DistributedBatchSampler - batch_size: 64 - drop_last: False - shuffle: True - loader: - num_workers: 16 - use_shared_memory: True - - Eval: - dataset: - name: ImageNetDataset - image_root: ./dataset/person/ - cls_label_path: ./dataset/person/val_list.txt - transform_ops: - - DecodeImage: - to_rgb: True - channel_first: False - - ResizeImage: - resize_short: 256 - - CropImage: - size: 224 - - NormalizeImage: - scale: 1.0/255.0 - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - order: '' - sampler: - name: DistributedBatchSampler - batch_size: 64 - drop_last: False - shuffle: False - loader: - num_workers: 4 - use_shared_memory: True - -Infer: - infer_imgs: docs/images/inference_deployment/whl_demo.jpg - batch_size: 10 - transforms: - - DecodeImage: - to_rgb: True - channel_first: False - - ResizeImage: - resize_short: 256 - - 
CropImage: - size: 224 - - NormalizeImage: - scale: 1.0/255.0 - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - order: '' - - ToCHWImage: - PostProcess: - name: Topk - topk: 5 - class_id_map_file: ppcls/utils/imagenet1k_label_list.txt - -Metric: - Train: - - DistillationTopkAcc: - model_key: "Student" - topk: [1, 2] - Eval: - - TprAtFpr: - - TopkAcc: - topk: [1, 2] diff --git a/ppcls/data/postprocess/__init__.py b/ppcls/data/postprocess/__init__.py index 831a4da0008ba70824203be3a6f46c9700225457..54678dc443ebab5bf55d54d9284d328bbc4523b3 100644 --- a/ppcls/data/postprocess/__init__.py +++ b/ppcls/data/postprocess/__init__.py @@ -14,9 +14,10 @@ import copy import importlib -from . import topk +from . import topk, threshoutput from .topk import Topk, MultiLabelTopk +from .threshoutput import ThreshOutput def build_postprocess(config): diff --git a/ppcls/data/postprocess/threshoutput.py b/ppcls/data/postprocess/threshoutput.py new file mode 100644 index 0000000000000000000000000000000000000000..607aecbfdeae018a5334f723effd658fb480713a --- /dev/null +++ b/ppcls/data/postprocess/threshoutput.py @@ -0,0 +1,36 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle.nn.functional as F + + +class ThreshOutput(object): + def __init__(self, threshold, label_0="0", label_1="1"): + self.threshold = threshold + self.label_0 = label_0 + self.label_1 = label_1 + + def __call__(self, x, file_names=None): + y = [] + x = F.softmax(x, axis=-1).numpy() + for idx, probs in enumerate(x): + score = probs[1] + if score < self.threshold: + result = {"class_ids": [0], "scores": [1 - score], "label_names": [self.label_0]} + else: + result = {"class_ids": [1], "scores": [score], "label_names": [self.label_1]} + if file_names is not None: + result["file_name"] = file_names[idx] + y.append(result) + return y diff --git a/tools/search_strategy.py b/tools/search_strategy.py index b4325d65ae23f2b1c420b9dfd44c7a79e5f52bbf..15f4aa71be67bbd0f5ec92d240bbc53896684d91 100644 --- a/tools/search_strategy.py +++ b/tools/search_strategy.py @@ -91,7 +91,7 @@ def search_strategy(): res = search_train(teacher_list, teacher_program, base_output_dir, "teacher", replace_config, model_name) all_results["teacher"] = res best = res.get("best") - t_pretrained = "{}/{}_{}/{}/best_model".format(base_output_dir, "teacher", best, best) + t_pretrained = "{}/{}_{}_0/{}/best_model".format(base_output_dir, "teacher", best, best) base_program += ["-o", "Arch.models.0.Teacher.name={}".format(best), "-o", "Arch.models.0.Teacher.pretrained={}".format(t_pretrained)] output_dir = "{}/search_res".format(base_output_dir)
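Reviewer note on `ThreshOutput`: the new post-process reports only one class per image and decides it by comparing the class-1 softmax probability with `threshold`, which the docs set to `0.9794` (the value found on `val` at an Fpr of one in a thousand). The sketch below is a dependency-free illustration of that behaviour, not code from this PR. `thresh_output` mirrors the logic of the added `ThreshOutput.__call__` using plain lists instead of Paddle tensors. `threshold_at_fpr` is a hypothetical helper showing one way such a threshold could be chosen; how the `TprAtFpr` metric in PaddleClas actually computes it is not shown in this diff, so treat that part as an assumption.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def thresh_output(batch_logits, threshold, label_0="nobody", label_1="someone",
                  file_names=None):
    """Plain-Python stand-in for the ThreshOutput post-process added in this diff."""
    results = []
    for idx, logits in enumerate(batch_logits):
        score = softmax(logits)[1]  # probability of class 1
        if score < threshold:
            # Below the threshold: report class 0 with the complementary score.
            result = {"class_ids": [0], "scores": [1 - score],
                      "label_names": [label_0]}
        else:
            result = {"class_ids": [1], "scores": [score],
                      "label_names": [label_1]}
        if file_names is not None:
            result["file_name"] = file_names[idx]
        results.append(result)
    return results


def threshold_at_fpr(scores, labels, max_fpr):
    """Hypothetical helper (not part of PaddleClas).

    Returns (threshold, tpr, fpr) with the highest TPR whose FPR stays at or
    below max_fpr, predicting positive when score >= threshold. Returns None
    if no candidate threshold satisfies the FPR bound.
    """
    positives = sum(1 for y in labels if y == 1)
    negatives = len(labels) - positives
    if positives == 0 or negatives == 0:
        raise ValueError("need at least one positive and one negative sample")
    best = None
    # Lowering the threshold can only raise TPR and FPR, so scan from the
    # highest candidate down and stop once the FPR bound is exceeded.
    for t in sorted(set(scores), reverse=True):
        tp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 0)
        fpr = fp / negatives
        if fpr > max_fpr:
            break
        best = (t, tp / positives, fpr)
    return best
```

One consequence worth keeping in mind when reading the printed results: with a threshold as high as `0.9794`, an image whose class-1 probability is about `0.88` is labelled `nobody`, and its reported score is `1 - 0.88`, roughly `0.12`. A low `nobody` score therefore does not mean low confidence in the usual argmax sense.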