Unverified commit 51c1bdb2 · Author: W Wei Shengyu · Committer: GitHub

Merge pull request #1946 from PaddlePaddle/develop

merge Develop into release/2.4
......@@ -11,6 +11,7 @@
<div align="center">
<img src="https://user-images.githubusercontent.com/80816848/170166458-767a01ca-1429-437f-a628-dd184732ef53.png" width = "150" />
</div>
- 2022.5.23 Added the [personnel entrance/exit management example](https://aistudio.baidu.com/aistudio/projectdetail/4094475), which can be tried out on AI Studio.
- 2022.5.20 Released [PP-HGNet](./docs/zh_CN/models/PP-HGNet.md) and [PP-LCNet v2](./docs/zh_CN/models/PP-LCNetV2.md)
- 2022.4.21 Added the [code](https://github.com/PaddlePaddle/PaddleClas/pull/1820/files) for the CVPR 2022 oral paper [MixFormer](https://arxiv.org/pdf/2204.02557.pdf)
- 2022.1.27 Fully upgraded the documentation; added the [PaddleServing C++ pipeline deployment](./deploy/paddleserving) and the [18M image recognition Android deployment demo](./deploy/lite_shitu)
......
Global:
infer_imgs: "./images/PULC/person/objects365_02035329.jpg"
inference_model_dir: "./models/person_cls_infer"
batch_size: 1
use_gpu: True
enable_mkldnn: False
cpu_num_threads: 10
enable_benchmark: True
use_fp16: False
ir_optim: True
use_tensorrt: False
gpu_mem: 8000
enable_profile: False
PreProcess:
transform_ops:
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
channel_num: 3
- ToCHWImage:
PostProcess:
main_indicator: ThreshOutput
ThreshOutput:
threshold: 0.9
label_0: nobody
label_1: someone
SavePreLabel:
save_dir: ./pre_label/
......@@ -53,6 +53,26 @@ class PostProcesser(object):
return rtn
class ThreshOutput(object):
def __init__(self, threshold, label_0="0", label_1="1"):
self.threshold = threshold
self.label_0 = label_0
self.label_1 = label_1
def __call__(self, x, file_names=None):
y = []
for idx, probs in enumerate(x):
score = probs[1]
if score < self.threshold:
result = {"class_ids": [0], "scores": [1 - score], "label_names": [self.label_0]}
else:
result = {"class_ids": [1], "scores": [score], "label_names": [self.label_1]}
if file_names is not None:
result["file_name"] = file_names[idx]
y.append(result)
return y
class Topk(object):
def __init__(self, topk=1, class_id_map_file=None):
assert isinstance(topk, (int, ))
......
......@@ -49,10 +49,15 @@ class ClsPredictor(Predictor):
pid = os.getpid()
size = config["PreProcess"]["transform_ops"][1]["CropImage"][
"size"]
if config["Global"].get("use_int8", False):
precision = "int8"
elif config["Global"].get("use_fp16", False):
precision = "fp16"
else:
precision = "fp32"
self.auto_logger = auto_log.AutoLogger(
model_name=config["Global"].get("model_name", "cls"),
model_precision='fp16'
if config["Global"]["use_fp16"] else 'fp32',
model_precision=precision,
batch_size=config["Global"].get("batch_size", 1),
data_shape=[3, size, size],
save_path=config["Global"].get("save_log_path",
......
......@@ -42,8 +42,22 @@ class Predictor(object):
def create_paddle_predictor(self, args, inference_model_dir=None):
if inference_model_dir is None:
inference_model_dir = args.inference_model_dir
params_file = os.path.join(inference_model_dir, "inference.pdiparams")
model_file = os.path.join(inference_model_dir, "inference.pdmodel")
if "inference_int8.pdiparams" in os.listdir(inference_model_dir):
params_file = os.path.join(inference_model_dir,
"inference_int8.pdiparams")
model_file = os.path.join(inference_model_dir,
"inference_int8.pdmodel")
assert args.get(
"use_fp16", False
) is False, "fp16 mode is not supported for int8 model inference, please set use_fp16 as False during inference."
else:
params_file = os.path.join(inference_model_dir,
"inference.pdiparams")
model_file = os.path.join(inference_model_dir, "inference.pdmodel")
assert args.get(
"use_int8", False
) is False, "int8 mode is not supported for fp32 model inference, please set use_int8 as False during inference."
config = Config(model_file, params_file)
if args.use_gpu:
......@@ -63,12 +77,18 @@ class Predictor(object):
config.disable_glog_info()
config.switch_ir_optim(args.ir_optim) # default true
if args.use_tensorrt:
precision = Config.Precision.Float32
if args.get("use_int8", False):
precision = Config.Precision.Int8
elif args.get("use_fp16", False):
precision = Config.Precision.Half
config.enable_tensorrt_engine(
precision_mode=Config.Precision.Half
if args.use_fp16 else Config.Precision.Float32,
precision_mode=precision,
max_batch_size=args.batch_size,
workspace_size=1 << 30,
min_subgraph_size=30)
min_subgraph_size=30,
use_calib_mode=False)
config.enable_memory_optim()
# use zero copy
......
# Building a Person/No-person Classification Case with PaddleClas
This tutorial shows how to use PaddleClas to quickly build a lightweight, high-accuracy, production-ready person/no-person classification model. Based on data from person/no-person scenarios, it combines the lightweight backbone PPLCNet, SSLD pretrained weights, the EDA data augmentation strategy, the SKL-UGI knowledge distillation strategy, and the SHAS hyperparameter search strategy to obtain a binary classification model with high accuracy, fast speed, and easy deployment.
------
## Contents
- [1. Environment Setup](#1)
- [2. Inference for the Person/No-person Scenario](#2)
  - [2.1 Download the Model](#2.1)
  - [2.2 Model Inference](#2.2)
    - [2.2.1 Predicting a Single Image](#2.2.1)
    - [2.2.2 Batch Prediction on a Folder](#2.2.2)
- [3. Training for the Person/No-person Scenario](#3)
  - [3.1 Data Preparation](#3.1)
  - [3.2 Model Training](#3.2)
    - [3.2.1 Training with Default Hyperparameters](#3.2.1)
      - [3.2.1.1 Training the Lightweight Model with Default Hyperparameters](#3.2.1.1)
      - [3.2.1.2 Training the Teacher Model with Default Hyperparameters](#3.2.1.2)
      - [3.2.1.3 Distillation Training with Default Hyperparameters](#3.2.1.3)
    - [3.2.2 Hyperparameter Search Training](#3.2.2)
- [4. Model Evaluation and Inference](#4)
  - [4.1 Model Evaluation](#4.1)
  - [4.2 Model Prediction](#4.2)
  - [4.3 Inference with the inference Model](#4.3)
    - [4.3.1 Export the inference Model](#4.3.1)
    - [4.3.2 Model Inference](#4.3.2)
<a name="1"></a>
## 1. Environment Setup
* Installation: please refer to the [Paddle installation guide](../installation/install_paddle.md) and the [PaddleClas installation guide](../installation/install_paddleclas.md) to set up the PaddleClas environment.
<a name="2"></a>
## 2. Inference for the Person/No-person Scenario
<a name="2.1"></a>
### 2.1 Download the Model
* Enter the `deploy` working directory.
```
cd deploy
```
Download the person/no-person classification model.
```
mkdir models
cd models
# Download the inference model and extract it
wget https://paddleclas.bj.bcebos.com/models/PULC/person_cls_infer.tar && tar -xf person_cls_infer.tar
```
After extraction, the `models` folder should contain the following structure:
```
├── person_cls_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
```
<a name="2.2"></a>
### 2.2 Model Inference
<a name="2.2.1"></a>
#### 2.2.1 Predicting a Single Image
Return to the `deploy` directory:
```
cd ../
```
Run the following command to perform person/no-person classification on the image `./images/PULC/person/objects365_02035329.jpg`.
```shell
# Use the following command to predict on GPU
python3.7 python/predict_cls.py -c configs/PULC/person/inference_person_cls.yaml -o PostProcess.ThreshOutput.threshold=0.9794
# Use the following command to predict on CPU
python3.7 python/predict_cls.py -c configs/PULC/person/inference_person_cls.yaml -o PostProcess.ThreshOutput.threshold=0.9794 -o Global.use_gpu=False
```
The output is as follows.
```
objects365_02035329.jpg: class id(s): [1], score(s): [1.00], label_name(s): ['someone']
```
**Note:** Real-world applications usually require the best true positive rate (Tpr) under a bound on the false positive rate (Fpr). On this scenario's `val` set, the threshold that achieves the best Tpr at an Fpr of 1/1000 is `0.9794`, so `threshold` is set to `0.9794` here. How this threshold is determined is described in [Section 3.2](#3.2); a minimal sketch of the selection logic follows.
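For reference only, a NumPy sketch of how such a threshold can be derived from validation scores. Here `pos_scores` and `neg_scores` are hypothetical arrays holding the predicted "someone" probability for positive and negative validation images; this mirrors, but is not, the repo's `TprAtFpr` metric:
```python
import numpy as np

def best_threshold(pos_scores, neg_scores, max_fpr=1e-3, steps=10000):
    """Return (threshold, tpr) maximizing Tpr subject to Fpr <= max_fpr."""
    best_t, best_tpr = 0.0, 0.0
    for i in range(steps):
        t = i / steps
        fpr = np.mean(neg_scores > t)  # fraction of negatives above t
        tpr = np.mean(pos_scores > t)  # fraction of positives above t
        if fpr <= max_fpr and tpr > best_tpr:
            best_t, best_tpr = t, tpr
    return best_t, best_tpr
```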
<a name="2.2.2"></a>
#### 2.2.2 Batch Prediction on a Folder
To predict all images in a folder, you can modify the `Global.infer_imgs` field in the config file directly, or override it with the `-o` option as shown below.
```shell
# Use the following command to predict on GPU; to predict on CPU, append -o Global.use_gpu=False
python3.7 python/predict_cls.py -c configs/PULC/person/inference_person_cls.yaml -o Global.infer_imgs="./images/PULC/person/"
```
The classification results for all images in the folder are printed to the terminal, as shown below.
```
objects365_01780782.jpg: class id(s): [0], score(s): [1.00], label_name(s): ['nobody']
objects365_02035329.jpg: class id(s): [1], score(s): [1.00], label_name(s): ['someone']
```
Here `someone` means the image contains a person, and `nobody` means it does not.
<a name="3"></a>
## 3. Training for the Person/No-person Scenario
<a name="3.1"></a>
### 3.1 Data Preparation
Enter the PaddleClas directory.
```
cd path_to_PaddleClas
```
Enter the `dataset/` directory, then download and extract the person/no-person data.
```shell
cd dataset
wget https://paddleclas.bj.bcebos.com/data/cls_demo/person.tar
tar -xf person.tar
cd ../
```
After running the commands above, `dataset/` contains a `person` directory with the following data:
```
├── train
│   ├── 000000000009.jpg
│   ├── 000000000025.jpg
...
├── val
│   ├── objects365_01780637.jpg
│   ├── objects365_01780640.jpg
...
├── ImageNet_val
│   ├── ILSVRC2012_val_00000001.JPEG
│   ├── ILSVRC2012_val_00000002.JPEG
...
├── train_list.txt
├── train_list.txt.debug
├── train_list_for_distill.txt
├── val_list.txt
└── val_list.txt.debug
```
Here `train/` and `val/` are the training and validation sets, and `train_list.txt` and `val_list.txt` are their label files. `train_list.txt.debug` and `val_list.txt.debug` are debug label files for the training and validation sets; each is a subset of the corresponding full list and can be used to quickly walk through this case's pipeline. `ImageNet_val/` is the ImageNet validation set; mixed with the `train` set, it serves as the unlabeled data for this case's SKL-UGI knowledge distillation strategy, with the corresponding training label file `train_list_for_distill.txt`.
* **Note**:
  * All datasets used in this case are open source: the `train` set is a subset of the [MS-COCO](https://cocodataset.org/#overview) training set, the `val` set is a subset of the [Objects365](https://www.objects365.org/overview.html) training set, and `ImageNet_val` is the validation set of [ImageNet](https://www.image-net.org/). The dataset filtering procedure is described in [the person/no-person dataset filtering method](). A hypothetical excerpt of the label-file format is shown below.
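For orientation only, the label files use the common PaddleClas classification format of one `relative/image/path label` pair per line; the file names below come from the listing above, but the labels are hypothetical:
```
train/000000000009.jpg 1
train/000000000025.jpg 0
```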
<a name="3.2"></a>
### 3.2 Model Training
<a name="3.2.1"></a>
#### 3.2.1 Training with Default Hyperparameters
<a name="3.2.1.1"></a>
##### 3.2.1.1 Training the Lightweight Model with Default Hyperparameters
`ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml` provides the training configuration for this scenario; launch training with the following script:
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
tools/train.py \
-c ./ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml
```
The best validation metric lies between 0.94 and 0.95 (the dataset is small, which easily causes fluctuation).
**Notes:**
* The metric used here is Tpr, the true positive rate when the false positive rate (Fpr) is below a given bound; it is a common industrial metric for binary classification. In this case, the Fpr bound is 1/1000. For more on Fpr and Tpr, see [here](https://baike.baidu.com/item/AUC/19282953).
* During eval, the current best TprAtFpr metric is printed: the current `Fpr` and `Tpr` values along with the current `threshold`. `Tpr` reflects the recall at the current `Fpr`; the higher it is, the better the model. `threshold` is the classification threshold corresponding to the current best `Fpr` and can be reused for later deployment. A usage sketch of this metric follows.
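For illustration, a minimal driving loop for the `TprAtFpr` metric added by this PR in `ppcls/metric/metrics.py`, using a hypothetical two-sample batch:
```python
import paddle
from ppcls.metric.metrics import TprAtFpr  # added in this PR

# hypothetical mini-batch: one positive ("someone") and one negative sample
logits = paddle.to_tensor([[0.2, 2.3], [1.8, -0.7]])
labels = paddle.to_tensor([[1], [0]])

metric = TprAtFpr(max_fpr=1 / 1000.)
metric.reset()
metric(logits, labels)  # accumulates positive-class scores by ground truth
print(metric.avg_info)  # e.g. "threshold: ..., fpr: ..., tpr: ..."
```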
<a name="3.2.1.2"></a>
##### 3.2.1.2 Training the Teacher Model with Default Hyperparameters
Reuse the hyperparameters from `ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml` and train the teacher model with the following script:
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
tools/train.py \
-c ./ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml \
-o Arch.name=ResNet101_vd
```
The best validation metric lies between 0.96 and 0.98; the best teacher weights so far are saved in `output/ResNet101_vd/best_model.pdparams`.
<a name="3.2.1.3"></a>
##### 3.2.1.3 Distillation Training with Default Hyperparameters
The config file `ppcls/configs/PULC/person/Distillation/PPLCNet_x1_0_distillation.yaml` provides the configuration for the SKL-UGI knowledge distillation strategy. It uses `ResNet101_vd` as the teacher model and `PPLCNet_x1_0` as the student model, with the ImageNet validation set as additional unlabeled data. Training script:
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
tools/train.py \
-c ./ppcls/configs/PULC/person/Distillation/PPLCNet_x1_0_distillation.yaml \
-o Arch.models.0.Teacher.pretrained=output/ResNet101_vd/best_model
```
The best validation metric lies between 0.95 and 0.97; the best student weights so far are saved in `output/DistillationModel/best_model_student.pdparams`. A rough sketch of the DML objective used by this config follows.
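For intuition only, a minimal paddle sketch of the symmetric-KL (Deep Mutual Learning) objective behind `DistillationDMLLoss`; this is a simplified stand-in, not the repo's implementation:
```python
import paddle
import paddle.nn.functional as F

def dml_loss(student_logits, teacher_logits):
    """DML loss: symmetric KL divergence between the two predictions."""
    log_s = F.log_softmax(student_logits, axis=-1)
    log_t = F.log_softmax(teacher_logits, axis=-1)
    kl_st = F.kl_div(log_s, paddle.exp(log_t), reduction="mean")
    kl_ts = F.kl_div(log_t, paddle.exp(log_s), reduction="mean")
    return (kl_st + kl_ts) / 2
```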
<a name="3.2.2"></a>
#### 3.2.2 Hyperparameter Search Training
[Section 3.2](#3.2) trains with hyperparameters that were already obtained by search. This section describes the search process itself, whose goal is better training hyperparameters.
* Run the search with the following script:
```shell
python tools/search_strategy.py -c ppcls/configs/StrategySearch/person.yaml
```
`ppcls/configs/StrategySearch/person.yaml` specifies the GPU ids and the search configuration. By default, the search's training logs and models are stored in `output/search_person`, and the final distilled model is stored in `output/search_person/search_res/DistillationModel/best_model_student.pdparams`.
* **Note**:
  * The default configuration provided above has already been through this search, so the step is not required; try it if your own training data changes.
  * On this dataset, the search takes roughly 10 hours on 4 V100 GPUs. If machine resources are limited but you want to try the search, you can replace `train_list.txt` and `val_list.txt` in `ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0_search.yaml` with `train_list.txt.debug` and `val_list.txt.debug`. Replacing the lists only speeds up a full pass through the search; with so little data, the search result is not meaningful. The search space can also be adjusted to the available machine resources: shrink it if resources are limited, enlarge it if they are ample.
  * If the hyperparameters found by the search differ from those given in [Section 3.2.1](#3.2.1), this is mainly fluctuation caused by the small training set and can be ignored. A simplified sketch of the greedy search loop follows.
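For intuition only, a simplified Python sketch of the greedy search that `tools/search_strategy.py` performs over the `search_dict` entries; the helpers `train_and_eval` and `override` are hypothetical:
```python
def greedy_search(base_config, search_dict):
    """Fix the best value for each search_key in turn (greedy, not exhaustive)."""
    best_overrides = {}
    for item in search_dict:  # e.g. lrs, resolutions, ra_probs, ...
        best_val, best_metric = None, -1.0
        for value in item["search_values"]:
            trial = {key: value for key in item["replace_config"]}
            metric = train_and_eval(
                override(base_config, {**best_overrides, **trial}))
            if metric > best_metric:
                best_val, best_metric = value, metric
        for key in item["replace_config"]:  # freeze the winning value
            best_overrides[key] = best_val
    return best_overrides
```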
<a name="4"></a>
## 4. Model Evaluation and Inference
<a name="4.1"></a>
### 4.1 Model Evaluation
After training the model, evaluate its metrics with the following command.
```bash
python3 tools/eval.py \
-c ./ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml \
-o Global.pretrained_model="output/DistillationModel/best_model_student"
```
<a name="4.2"></a>
### 4.2 Model Prediction
After training completes, the trained weights can be loaded for prediction. `tools/infer.py` in the model library provides a complete example; run the following command to predict:
```bash
python3 tools/infer.py \
-c ./ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml \
-o Infer.infer_imgs=./dataset/person/val/objects365_01780637.jpg \
-o Global.pretrained_model=output/DistillationModel/best_model_student \
-o Infer.PostProcess.threshold=0.9794
```
The output is as follows:
```
[{'class_ids': [0], 'scores': [0.9878496769815683], 'label_names': ['nobody'], 'file_name': './dataset/person/val/objects365_01780637.jpg'}]
```
**Note:** The value of `Infer.PostProcess.threshold` must be chosen for the actual scenario; `0.9794` here is the threshold that achieves the best Tpr at an Fpr of 1/1000 on this scenario's `val` set.
<a name="4.3"></a>
### 4.3 Inference with the inference Model
<a name="4.3.1"></a>
#### 4.3.1 Export the inference Model
Exporting an inference model lets PaddlePaddle run prediction with its inference engine. The following describes how to use the inference engine:
First, convert the trained model:
```bash
python3 tools/export_model.py \
-c ./ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0.yaml \
-o Global.pretrained_model=output/DistillationModel/best_model_student \
-o Global.save_inference_dir=deploy/models/PPLCNet_x1_0_person
```
After running this script, a `PPLCNet_x1_0_person` folder is created under `deploy/models/`; the model inside has the same format as the inference model downloaded in Section 2.2.
<a name="4.3.2"></a>
#### 4.3.2 Model Inference
The inference script is:
```
python3.7 python/predict_cls.py -c configs/PULC/person/inference_person_cls.yaml -o Global.inference_model_dir="models/PPLCNet_x1_0_person" -o PostProcess.ThreshOutput.threshold=0.9794
```
**Notes:**
- `PostProcess.ThreshOutput.threshold` here is determined by the best `threshold` found during eval.
- For more details on inference, see [Section 2.2](#2.2).
## Personnel Entrance/Exit Management
In recent years, AI vision technology has played a pivotal role in the intelligent upgrade of industries such as security and manufacturing. Access control is a key scenario across these industries, and demand for it is urgent: home burglary prevention, server-room access control, and danger alerts in scenic areas all require timely detection of unauthorized targets (people, vehicles, or other objects) entering restricted areas. With deep-learning vision, intrusions can be recognized promptly and accurately and alarms raised, effectively protecting lives and property. Compared with traditional manual monitoring, this provides uninterrupted, all-round 7x24 protection while greatly reducing management cost and freeing up labor.
In real industrial settings, however, high-accuracy recognition of people entering and leaving is not easy; practical scenarios raise all kinds of problems:
**Images captured by cameras are affected by occlusion from buildings, machinery, vehicles, and so on.**
**Weather varies widely: the system must cope with daytime, night, fog, rain, and more.**
To address these scenarios, this release of the PaddlePaddle industrial practice examples provides a reusable, end-to-end example for controlling personnel entry into key areas, covering data preparation, the technical solution, model training and optimization, and model deployment. It effectively solves image classification in complex outdoor environments with varying illumination and weather, greatly reduces annotation and compute costs, and suits industrial applications such as plant inspection, home security, and scenic-area management.
![result](./imgs/someone.gif)
**Note**: to run the code online on AI Studio, see [Personnel Entrance/Exit Management](https://aistudio.baidu.com/aistudio/projectdetail/4094475)
......@@ -32,14 +32,18 @@ from ppcls.arch.distill.afd_attention import LinearTransformStudent, LinearTrans
__all__ = ["build_model", "RecModel", "DistillationModel", "AttentionModel"]
def build_model(config):
def build_model(config, mode="train"):
arch_config = copy.deepcopy(config["Arch"])
model_type = arch_config.pop("name")
use_sync_bn = arch_config.pop("use_sync_bn", False)
mod = importlib.import_module(__name__)
arch = getattr(mod, model_type)(**arch_config)
if use_sync_bn:
arch = nn.SyncBatchNorm.convert_sync_batchnorm(arch)
if isinstance(arch, TheseusLayer):
prune_model(config, arch)
quantize_model(config, arch)
quantize_model(config, arch, mode)
return arch
......@@ -51,6 +55,7 @@ def apply_to_static(config, model):
specs = None
if 'image_shape' in config['Global']:
specs = [InputSpec([None] + config['Global']['image_shape'])]
specs[0].stop_gradient = True
model = to_static(model, input_spec=specs)
logger.info("Successfully to apply @to_static with specs: {}".format(
specs))
......
......@@ -52,7 +52,7 @@ from ppcls.arch.backbone.model_zoo.darknet import DarkNet53
from ppcls.arch.backbone.model_zoo.regnet import RegNetX_200MF, RegNetX_4GF, RegNetX_32GF, RegNetY_200MF, RegNetY_4GF, RegNetY_32GF
from ppcls.arch.backbone.model_zoo.vision_transformer import ViT_small_patch16_224, ViT_base_patch16_224, ViT_base_patch16_384, ViT_base_patch32_384, ViT_large_patch16_224, ViT_large_patch16_384, ViT_large_patch32_384
from ppcls.arch.backbone.model_zoo.distilled_vision_transformer import DeiT_tiny_patch16_224, DeiT_small_patch16_224, DeiT_base_patch16_224, DeiT_tiny_distilled_patch16_224, DeiT_small_distilled_patch16_224, DeiT_base_distilled_patch16_224, DeiT_base_patch16_384, DeiT_base_distilled_patch16_384
from ppcls.arch.backbone.model_zoo.swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384
from ppcls.arch.backbone.legendary_models.swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384
from ppcls.arch.backbone.model_zoo.cswin_transformer import CSWinTransformer_tiny_224, CSWinTransformer_small_224, CSWinTransformer_base_224, CSWinTransformer_large_224, CSWinTransformer_base_384, CSWinTransformer_large_384
from ppcls.arch.backbone.model_zoo.mixnet import MixNet_S, MixNet_M, MixNet_L
from ppcls.arch.backbone.model_zoo.rexnet import ReXNet_1_0, ReXNet_1_3, ReXNet_1_5, ReXNet_2_0, ReXNet_3_0
......
......@@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function
import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear
from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Dropout, Linear
from paddle.regularizer import L2Decay
from paddle.nn.initializer import KaimingNormal
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
......@@ -83,7 +83,8 @@ class ConvBNLayer(TheseusLayer):
filter_size,
num_filters,
stride,
num_groups=1):
num_groups=1,
lr_mult=1.0):
super().__init__()
self.conv = Conv2D(
......@@ -93,13 +94,13 @@ class ConvBNLayer(TheseusLayer):
stride=stride,
padding=(filter_size - 1) // 2,
groups=num_groups,
weight_attr=ParamAttr(initializer=KaimingNormal()),
weight_attr=ParamAttr(initializer=KaimingNormal(), learning_rate=lr_mult),
bias_attr=False)
self.bn = BatchNorm(
self.bn = BatchNorm2D(
num_filters,
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
weight_attr=ParamAttr(regularizer=L2Decay(0.0), learning_rate=lr_mult),
bias_attr=ParamAttr(regularizer=L2Decay(0.0), learning_rate=lr_mult))
self.hardswish = nn.Hardswish()
def forward(self, x):
......@@ -115,7 +116,8 @@ class DepthwiseSeparable(TheseusLayer):
num_filters,
stride,
dw_size=3,
use_se=False):
use_se=False,
lr_mult=1.0):
super().__init__()
self.use_se = use_se
self.dw_conv = ConvBNLayer(
......@@ -123,14 +125,17 @@ class DepthwiseSeparable(TheseusLayer):
num_filters=num_channels,
filter_size=dw_size,
stride=stride,
num_groups=num_channels)
num_groups=num_channels,
lr_mult=lr_mult)
if use_se:
self.se = SEModule(num_channels)
self.se = SEModule(num_channels,
lr_mult=lr_mult)
self.pw_conv = ConvBNLayer(
num_channels=num_channels,
filter_size=1,
num_filters=num_filters,
stride=1)
stride=1,
lr_mult=lr_mult)
def forward(self, x):
x = self.dw_conv(x)
......@@ -141,7 +146,7 @@ class DepthwiseSeparable(TheseusLayer):
class SEModule(TheseusLayer):
def __init__(self, channel, reduction=4):
def __init__(self, channel, reduction=4, lr_mult=1.0):
super().__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv1 = Conv2D(
......@@ -149,14 +154,18 @@ class SEModule(TheseusLayer):
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0)
padding=0,
weight_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=ParamAttr(learning_rate=lr_mult))
self.relu = nn.ReLU()
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0)
padding=0,
weight_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=ParamAttr(learning_rate=lr_mult))
self.hardsigmoid = nn.Hardsigmoid()
def forward(self, x):
......@@ -177,17 +186,32 @@ class PPLCNet(TheseusLayer):
class_num=1000,
dropout_prob=0.2,
class_expand=1280,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
use_last_conv=True,
return_patterns=None,
return_stages=None):
super().__init__()
self.scale = scale
self.class_expand = class_expand
self.lr_mult_list = lr_mult_list
self.use_last_conv = use_last_conv
if isinstance(self.lr_mult_list, str):
self.lr_mult_list = eval(self.lr_mult_list)
assert isinstance(self.lr_mult_list, (
list, tuple
)), "lr_mult_list should be in (list, tuple) but got {}".format(
type(self.lr_mult_list))
assert len(self.lr_mult_list
) == 6, "lr_mult_list length should be 6 but got {}".format(
len(self.lr_mult_list))
self.conv1 = ConvBNLayer(
num_channels=3,
filter_size=3,
num_filters=make_divisible(16 * scale),
stride=2)
stride=2,
lr_mult=self.lr_mult_list[0])
self.blocks2 = nn.Sequential(* [
DepthwiseSeparable(
......@@ -195,7 +219,8 @@ class PPLCNet(TheseusLayer):
num_filters=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se)
use_se=se,
lr_mult=self.lr_mult_list[1])
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks2"])
])
......@@ -205,7 +230,8 @@ class PPLCNet(TheseusLayer):
num_filters=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se)
use_se=se,
lr_mult=self.lr_mult_list[2])
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks3"])
])
......@@ -215,7 +241,8 @@ class PPLCNet(TheseusLayer):
num_filters=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se)
use_se=se,
lr_mult=self.lr_mult_list[3])
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks4"])
])
......@@ -225,7 +252,8 @@ class PPLCNet(TheseusLayer):
num_filters=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se)
use_se=se,
lr_mult=self.lr_mult_list[4])
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks5"])
])
......@@ -235,25 +263,26 @@ class PPLCNet(TheseusLayer):
num_filters=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se)
use_se=se,
lr_mult=self.lr_mult_list[5])
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks6"])
])
self.avg_pool = AdaptiveAvgPool2D(1)
self.last_conv = Conv2D(
in_channels=make_divisible(NET_CONFIG["blocks6"][-1][2] * scale),
out_channels=self.class_expand,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
self.hardswish = nn.Hardswish()
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
if self.use_last_conv:
self.last_conv = Conv2D(
in_channels=make_divisible(NET_CONFIG["blocks6"][-1][2] * scale),
out_channels=self.class_expand,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
self.hardswish = nn.Hardswish()
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
else:
self.last_conv = None
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
self.fc = Linear(self.class_expand, class_num)
self.fc = Linear(self.class_expand if self.use_last_conv else NET_CONFIG["blocks6"][-1][2], class_num)
super().init_res(
stages_pattern,
......@@ -270,9 +299,10 @@ class PPLCNet(TheseusLayer):
x = self.blocks6(x)
x = self.avg_pool(x)
x = self.last_conv(x)
x = self.hardswish(x)
x = self.dropout(x)
if self.last_conv is not None:
x = self.last_conv(x)
x = self.hardswish(x)
x = self.dropout(x)
x = self.flatten(x)
x = self.fc(x)
return x
......
......@@ -20,9 +20,10 @@ import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
from paddle.nn import Conv2D, BatchNorm, Linear
from paddle.nn import Conv2D, BatchNorm, Linear, BatchNorm2D
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
from paddle.regularizer import L2Decay
import math
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
......@@ -132,11 +133,12 @@ class ConvBNLayer(TheseusLayer):
weight_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=False,
data_format=data_format)
self.bn = BatchNorm(
num_filters,
param_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=ParamAttr(learning_rate=lr_mult),
data_layout=data_format)
weight_attr = ParamAttr(learning_rate=lr_mult, trainable=True)
bias_attr = ParamAttr(learning_rate=lr_mult, trainable=True)
self.bn = BatchNorm2D(
num_filters, weight_attr=weight_attr, bias_attr=bias_attr)
self.relu = nn.ReLU()
def forward(self, x):
......@@ -192,6 +194,7 @@ class BottleneckBlock(TheseusLayer):
is_vd_mode=False if if_first else True,
lr_mult=lr_mult,
data_format=data_format)
self.relu = nn.ReLU()
self.shortcut = shortcut
......@@ -312,7 +315,7 @@ class ResNet(TheseusLayer):
[[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
}
self.stem = nn.Sequential(*[
self.stem = nn.Sequential(* [
ConvBNLayer(
num_channels=in_c,
num_filters=out_c,
......
......@@ -21,8 +21,8 @@ import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import TruncatedNormal, Constant
from .vision_transformer import trunc_normal_, zeros_, ones_, to_2tuple, DropPath, Identity
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.arch.backbone.model_zoo.vision_transformer import trunc_normal_, zeros_, ones_, to_2tuple, DropPath, Identity
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
......@@ -589,7 +589,7 @@ class PatchEmbed(nn.Layer):
return flops
class SwinTransformer(nn.Layer):
class SwinTransformer(TheseusLayer):
""" Swin Transformer
A PaddlePaddle impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
......
......@@ -40,12 +40,14 @@ QUANT_CONFIG = {
}
def quantize_model(config, model):
def quantize_model(config, model, mode="train"):
if config.get("Slim", False) and config["Slim"].get("quant", False):
from paddleslim.dygraph.quant import QAT
assert config["Slim"]["quant"]["name"].lower(
) == 'pact', 'Only PACT quantization method is supported now'
QUANT_CONFIG["activation_preprocess_type"] = "PACT"
if mode in ["infer", "export"]:
QUANT_CONFIG['activation_preprocess_type'] = None
model.quanter = QAT(config=QUANT_CONFIG)
model.quanter.quantize(model)
logger.info("QAT model summary:")
......
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
save_interval: 5
eval_during_train: True
eval_interval: 1
epochs: 30
print_batch_step: 20
use_visualdl: False
# used for static mode and model export
image_shape: [3, 256, 192]
save_inference_dir: "./inference"
use_multilabel: True
# model architecture
Arch:
name: "ResNet50"
pretrained: True
class_num: 26
# loss function config for training/eval process
Loss:
Train:
- MultiLabelLoss:
weight: 1.0
weight_ratio: True
size_sum: True
Eval:
- MultiLabelLoss:
weight: 1.0
weight_ratio: True
size_sum: True
Optimizer:
name: Adam
lr:
name: Piecewise
decay_epochs: [12, 18, 24, 28]
values: [0.0001, 0.00001, 0.000001, 0.0000001]
regularizer:
name: 'L2'
coeff: 0.0005
clip_norm: 10
# data loader for train and eval
DataLoader:
Train:
dataset:
name: MultiLabelDataset
image_root: "dataset/attribute/data/"
cls_label_path: "dataset/attribute/trainval.txt"
label_ratio: True
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: [192, 256]
- Padv2:
size: [212, 276]
pad_mode: 1
fill_value: 0
- RandomCropImage:
size: [192, 256]
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: True
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: MultiLabelDataset
image_root: "dataset/attribute/data/"
cls_label_path: "dataset/attribute/test.txt"
label_ratio: True
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: [192, 256]
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Metric:
Eval:
- ATTRMetric:
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output
device: gpu
save_interval: 1
eval_during_train: True
start_eval_epoch: 1
eval_interval: 1
epochs: 20
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
use_dali: False
# model architecture
Arch:
name: "DistillationModel"
class_num: &class_num 2
# if not null, its length should be the same as models
pretrained_list:
# if not null, its length should be the same as models
freeze_params_list:
- True
- False
use_sync_bn: True
models:
- Teacher:
name: ResNet101_vd
class_num: *class_num
- Student:
name: PPLCNet_x1_0
class_num: *class_num
pretrained: True
use_ssld: True
infer_model_name: "Student"
# loss function config for training/eval process
Loss:
Train:
- DistillationDMLLoss:
weight: 1.0
model_name_pairs:
- ["Student", "Teacher"]
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.01
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00004
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/person/
cls_label_path: ./dataset/person/train_list_for_distill.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 192
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
prob: 0.0
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
img_size: 192
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.1
sl: 0.02
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: True
loader:
num_workers: 16
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/person/
cls_label_path: ./dataset/person/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: ThreshOutput
threshold: 0.9
label_0: nobody
label_1: someone
Metric:
Train:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 2]
Eval:
- TprAtFpr:
- TopkAcc:
topk: [1, 2]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
start_eval_epoch: 10
epochs: 20
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
use_dali: False
# mixed precision training
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
level: O1
# model architecture
Arch:
name: MobileNetV3_large_x1_0
class_num: 2
pretrained: True
use_sync_bn: True
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.13
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00002
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/person/
cls_label_path: ./dataset/person/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 512
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/person/
cls_label_path: ./dataset/person/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: ThreshOutput
threshold: 0.9
label_0: nobody
label_1: someone
Metric:
Train:
- TopkAcc:
topk: [1, 2]
Eval:
- TprAtFpr:
- TopkAcc:
topk: [1, 2]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
start_eval_epoch: 10
epochs: 20
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
use_dali: False
# mixed precision training
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
level: O1
# model architecture
Arch:
name: SwinTransformer_tiny_patch4_window7_224
class_num: 2
pretrained: True
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
weight_decay: 0.05
no_weight_decay_name: absolute_pos_embed relative_position_bias_table .bias norm
one_dim_param_no_weight_decay: True
lr:
name: Cosine
learning_rate: 1e-4
eta_min: 2e-6
warmup_epoch: 5
warmup_start_lr: 2e-7
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/person/
cls_label_path: ./dataset/person/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
interpolation: bicubic
backend: pil
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
img_size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.25
sl: 0.02
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
batch_transform_ops:
- OpSampler:
MixupOperator:
alpha: 0.8
prob: 0.5
CutmixOperator:
alpha: 1.0
prob: 0.5
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/person/
cls_label_path: ./dataset/person/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 8
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: ThreshOutput
threshold: 0.9
label_0: nobody
label_1: someone
Metric:
Train:
- TopkAcc:
topk: [1, 2]
Eval:
- TprAtFpr:
- TopkAcc:
topk: [1, 2]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
start_eval_epoch: 10
epochs: 20
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
use_dali: False
# model architecture
Arch:
name: PPLCNet_x1_0
class_num: 2
pretrained: True
use_ssld: True
use_sync_bn: True
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.01
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00004
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/person/
cls_label_path: ./dataset/person/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 192
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
prob: 0.0
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
img_size: 192
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.1
sl: 0.02
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/person/
cls_label_path: ./dataset/person/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: ThreshOutput
threshold: 0.9
label_0: nobody
label_1: someone
Metric:
Train:
- TopkAcc:
topk: [1, 2]
Eval:
- TprAtFpr:
- TopkAcc:
topk: [1, 2]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
start_eval_epoch: 10
epochs: 20
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
use_dali: False
# model architecture
Arch:
name: PPLCNet_x1_0
class_num: 2
pretrained: True
use_ssld: True
use_sync_bn: True
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.01
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00004
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/person/
cls_label_path: ./dataset/person/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
prob: 0.0
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
img_size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.0
sl: 0.02
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/person/
cls_label_path: ./dataset/person/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: ThreshOutput
threshold: 0.9
label_0: nobody
label_1: someone
Metric:
Train:
- TopkAcc:
topk: [1, 2]
Eval:
- TprAtFpr:
- TopkAcc:
topk: [1, 2]
base_config_file: ppcls/configs/PULC/person/PPLCNet/PPLCNet_x1_0_search.yaml
distill_config_file: ppcls/configs/PULC/person/Distillation/PPLCNet_x1_0_distillation.yaml
gpus: 0,1,2,3
output_dir: output/search_person
search_times: 1
search_dict:
- search_key: lrs
replace_config:
- Optimizer.lr.learning_rate
search_values: [0.0075, 0.01, 0.0125]
- search_key: resolutions
replace_config:
- DataLoader.Train.dataset.transform_ops.1.RandCropImage.size
- DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.img_size
search_values: [176, 192, 224]
- search_key: ra_probs
replace_config:
- DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.prob
search_values: [0.0, 0.1, 0.5]
- search_key: re_probs
replace_config:
- DataLoader.Train.dataset.transform_ops.5.RandomErasing.EPSILON
search_values: [0.0, 0.1, 0.5]
- search_key: lr_mult_list
replace_config:
- Arch.lr_mult_list
search_values:
- [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
- [0.0, 0.4, 0.4, 0.8, 0.8, 1.0]
- [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
teacher:
rm_keys:
- Arch.lr_mult_list
search_values:
- ResNet101_vd
- ResNet50_vd
final_replace:
Arch.lr_mult_list: Arch.models.1.Student.lr_mult_list
......@@ -44,11 +44,11 @@ def create_operators(params):
class CommonDataset(Dataset):
def __init__(
self,
image_root,
cls_label_path,
transform_ops=None, ):
def __init__(self,
image_root,
cls_label_path,
transform_ops=None,
label_ratio=False):
self._img_root = image_root
self._cls_path = cls_label_path
if transform_ops:
......@@ -56,7 +56,10 @@ class CommonDataset(Dataset):
self.images = []
self.labels = []
self._load_anno()
if label_ratio:
self.label_ratio = self._load_anno(label_ratio=label_ratio)
else:
self._load_anno()
def _load_anno(self):
pass
......
......@@ -25,7 +25,7 @@ from .common_dataset import CommonDataset
class MultiLabelDataset(CommonDataset):
def _load_anno(self):
def _load_anno(self, label_ratio=False):
assert os.path.exists(self._cls_path)
assert os.path.exists(self._img_root)
self.images = []
......@@ -41,6 +41,8 @@ class MultiLabelDataset(CommonDataset):
self.labels.append(labels)
assert os.path.exists(self.images[-1])
if label_ratio:
return np.array(self.labels).mean(0).astype("float32")
def __getitem__(self, idx):
try:
......@@ -50,7 +52,10 @@ class MultiLabelDataset(CommonDataset):
img = transform(img, self._transform_ops)
img = img.transpose((2, 0, 1))
label = np.array(self.labels[idx]).astype("float32")
return (img, label)
if getattr(self, "label_ratio", None) is not None:
return (img, np.array([label, self.label_ratio]))
else:
return (img, label)
except Exception as ex:
logger.error("Exception occured when parse line: {} with msg: {}".
......
......@@ -14,9 +14,10 @@
import copy
import importlib
from . import topk
from . import topk, threshoutput
from .topk import Topk, MultiLabelTopk
from .threshoutput import ThreshOutput
def build_postprocess(config):
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.nn.functional as F
class ThreshOutput(object):
def __init__(self, threshold, label_0="0", label_1="1"):
self.threshold = threshold
self.label_0 = label_0
self.label_1 = label_1
def __call__(self, x, file_names=None):
y = []
x = F.softmax(x, axis=-1).numpy()
for idx, probs in enumerate(x):
score = probs[1]
if score < self.threshold:
result = {"class_ids": [0], "scores": [1 - score], "label_names": [self.label_0]}
else:
result = {"class_ids": [1], "scores": [score], "label_names": [self.label_1]}
if file_names is not None:
result["file_name"] = file_names[idx]
y.append(result)
return y
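# Illustration only (not part of this file): a minimal usage sketch of
# ThreshOutput above, with a hypothetical batch of two 2-class logit rows.
import paddle
post = ThreshOutput(threshold=0.9, label_0="nobody", label_1="someone")
logits = paddle.to_tensor([[3.0, -1.0], [-2.0, 4.0]])
for r in post(logits, file_names=["a.jpg", "b.jpg"]):
    print(r["file_name"], r["label_names"], r["scores"])  # nobody, then someone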
......@@ -33,11 +33,14 @@ from ppcls.data.preprocess.ops.operators import AugMix
from ppcls.data.preprocess.ops.operators import Pad
from ppcls.data.preprocess.ops.operators import ToTensor
from ppcls.data.preprocess.ops.operators import Normalize
from ppcls.data.preprocess.ops.operators import RandomCropImage
from ppcls.data.preprocess.ops.operators import Padv2
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
import numpy as np
from PIL import Image
import random
def transform(data, ops=[]):
......@@ -88,16 +91,16 @@ class RandAugment(RawRandAugment):
class TimmAutoAugment(RawTimmAutoAugment):
""" TimmAutoAugment wrapper to auto fit different img tyeps. """
def __init__(self, *args, **kwargs):
def __init__(self, prob=1.0, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prob = prob
def __call__(self, img):
if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
img = super().__call__(img)
if random.random() < self.prob:
img = super().__call__(img)
if isinstance(img, Image.Image):
img = np.asarray(img)
......
......@@ -190,6 +190,105 @@ class CropImage(object):
return img[h_start:h_end, w_start:w_end, :]
class Padv2(object):
def __init__(self,
size=None,
size_divisor=32,
pad_mode=0,
offsets=None,
fill_value=(127.5, 127.5, 127.5)):
"""
Pad image to a specified size or multiple of size_divisor.
Args:
size (int, list): image target size, if None, pad to multiple of size_divisor, default None
size_divisor (int): size divisor, default 32
pad_mode (int): pad mode, currently only supports four modes [-1, 0, 1, 2]. if -1, use specified offsets
if 0, only pad to right and bottom. if 1, pad according to center. if 2, only pad left and top
offsets (list): [offset_x, offset_y], specify offset while padding, only supported pad_mode=-1
fill_value (bool): rgb value of pad area, default (127.5, 127.5, 127.5)
"""
if not isinstance(size, (int, list)):
raise TypeError(
"Type of target_size is invalid when random_size is True. \
Must be List, now is {}".format(type(size)))
if isinstance(size, int):
size = [size, size]
assert pad_mode in [
-1, 0, 1, 2
], 'currently only supports four modes [-1, 0, 1, 2]'
if pad_mode == -1:
assert offsets, 'if pad_mode is -1, offsets should not be None'
self.size = size
self.size_divisor = size_divisor
self.pad_mode = pad_mode
self.fill_value = fill_value
self.offsets = offsets
def apply_image(self, image, offsets, im_size, size):
x, y = offsets
im_h, im_w = im_size
h, w = size
canvas = np.ones((h, w, 3), dtype=np.float32)
canvas *= np.array(self.fill_value, dtype=np.float32)
canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
return canvas
def __call__(self, img):
im_h, im_w = img.shape[:2]
if self.size:
w, h = self.size
assert (
im_h <= h and im_w <= w
), '(h, w) of target size should be greater than (im_h, im_w)'
else:
h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor)
w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor)
if h == im_h and w == im_w:
return img.astype(np.float32)
if self.pad_mode == -1:
offset_x, offset_y = self.offsets
elif self.pad_mode == 0:
offset_y, offset_x = 0, 0
elif self.pad_mode == 1:
offset_y, offset_x = (h - im_h) // 2, (w - im_w) // 2
else:
offset_y, offset_x = h - im_h, w - im_w
offsets, im_size, size = [offset_x, offset_y], [im_h, im_w], [h, w]
return self.apply_image(img, offsets, im_size, size)
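# Illustration only (not part of this file): a hypothetical sanity check for
# Padv2 as used in the attribute config (size is [w, h]; pad_mode=1 centers
# a 256x192 image on a 276x212 canvas filled with 0).
import numpy as np
img = np.random.randint(0, 256, (256, 192, 3)).astype(np.float32)
out = Padv2(size=[212, 276], pad_mode=1, fill_value=0)(img)
assert out.shape == (276, 212, 3)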
class RandomCropImage(object):
"""Random crop image only
"""
def __init__(self, size):
super(RandomCropImage, self).__init__()
if isinstance(size, int):
size = [size, size]
self.size = size
def __call__(self, img):
h, w = img.shape[:2]
tw, th = self.size
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
img = img[i:i + th, j:j + tw, :]
if img.shape[0] != th or img.shape[1] != tw:
raise ValueError('sample: ', h, w, i, j, th, tw, img.shape)
return img
class RandCropImage(object):
""" random crop image """
......@@ -463,8 +562,8 @@ class Pad(object):
# Process fill color for affine transforms
major_found, minor_found = (int(v)
for v in PILLOW_VERSION.split('.')[:2])
major_required, minor_required = (
int(v) for v in min_pil_version.split('.')[:2])
major_required, minor_required = (int(v) for v in
min_pil_version.split('.')[:2])
if major_found < major_required or (major_found == major_required and
minor_found < minor_required):
if fill is None:
......
......@@ -189,7 +189,7 @@ class Engine(object):
self.eval_metric_func = None
# build model
self.model = build_model(self.config)
self.model = build_model(self.config, self.mode)
# set @to_static for benchmark, skip this by default.
apply_to_static(self.config, self.model)
......@@ -313,7 +313,7 @@ class Engine(object):
print_batch_step = self.config['Global']['print_batch_step']
save_interval = self.config["Global"]["save_interval"]
best_metric = {
"metric": 0.0,
"metric": -1.0,
"epoch": 0,
}
# key:
......@@ -345,18 +345,18 @@ class Engine(object):
if self.use_dali:
self.train_dataloader.reset()
metric_msg = ", ".join([
"{}: {:.5f}".format(key, self.output_info[key].avg)
for key in self.output_info
])
metric_msg = ", ".join(
[self.output_info[key].avg_info for key in self.output_info])
logger.info("[Train][Epoch {}/{}][Avg]{}".format(
epoch_id, self.config["Global"]["epochs"], metric_msg))
self.output_info.clear()
# eval model and save model if possible
start_eval_epoch = self.config["Global"].get("start_eval_epoch",
0) - 1
if self.config["Global"][
"eval_during_train"] and epoch_id % self.config["Global"][
"eval_interval"] == 0:
"eval_interval"] == 0 and epoch_id > start_eval_epoch:
acc = self.eval(epoch_id)
if acc > best_metric["metric"]:
best_metric["metric"] = acc
......@@ -368,7 +368,8 @@ class Engine(object):
self.output_dir,
model_name=self.config["Arch"]["name"],
prefix="best_model",
loss=self.train_loss_func)
loss=self.train_loss_func,
save_student_model=True)
logger.info("[Eval][Epoch {}][best metric: {}]".format(
epoch_id, best_metric["metric"]))
logger.scaler(
......@@ -471,23 +472,19 @@ class Engine(object):
save_path = os.path.join(self.config["Global"]["save_inference_dir"],
"inference")
if model.quanter:
model.quanter.save_quantized_model(
model.base_model,
save_path,
input_spec=[
paddle.static.InputSpec(
shape=[None] + self.config["Global"]["image_shape"],
dtype='float32')
])
model = paddle.jit.to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None] + self.config["Global"]["image_shape"],
dtype='float32')
])
if hasattr(model.base_model,
"quanter") and model.base_model.quanter is not None:
model.base_model.quanter.save_quantized_model(model,
save_path + "_int8")
else:
model = paddle.jit.to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None] + self.config["Global"]["image_shape"],
dtype='float32')
])
paddle.jit.save(model, save_path)
logger.info(
f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
......
......@@ -23,6 +23,8 @@ from ppcls.utils import logger
def classification_eval(engine, epoch_id=0):
if hasattr(engine.eval_metric_func, "reset"):
engine.eval_metric_func.reset()
output_info = dict()
time_info = {
"batch_cost": AverageMeter(
......@@ -80,6 +82,7 @@ def classification_eval(engine, epoch_id=0):
# gather Tensor when distributed
if paddle.distributed.get_world_size() > 1:
label_list = []
paddle.distributed.all_gather(label_list, batch[1])
labels = paddle.concat(label_list, 0)
......@@ -121,18 +124,10 @@ def classification_eval(engine, epoch_id=0):
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0],
current_samples)
# calc metric
if engine.eval_metric_func is not None:
metric_dict = engine.eval_metric_func(preds, labels)
for key in metric_dict:
if metric_key is None:
metric_key = key
if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(metric_dict[key].numpy()[0],
current_samples)
engine.eval_metric_func(preds, labels)
time_info["batch_cost"].update(time.time() - tic)
if iter_id % print_batch_step == 0:
......@@ -144,10 +139,14 @@ def classification_eval(engine, epoch_id=0):
ips_msg = "ips: {:.5f} images/sec".format(
batch_size / time_info["batch_cost"].avg)
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].val)
for key in output_info
])
if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
metric_msg = ""
else:
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].val)
for key in output_info
])
metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
epoch_id, iter_id,
len(engine.eval_dataloader), metric_msg, time_msg, ips_msg))
......@@ -155,13 +154,29 @@ def classification_eval(engine, epoch_id=0):
tic = time.time()
if engine.use_dali:
engine.eval_dataloader.reset()
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg) for key in output_info
])
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
# do not try to save best eval.model
if engine.eval_metric_func is None:
return -1
# return 1st metric in the dict
return output_info[metric_key].avg
if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
metric_msg = ", ".join([
"evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
format(*engine.eval_metric_func.attr_res())
])
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
# do not try to save best eval.model
if engine.eval_metric_func is None:
return -1
# return 1st metric in the dict
return engine.eval_metric_func.attr_res()[0]
else:
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg)
for key in output_info
])
metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
# do not try to save best eval.model
if engine.eval_metric_func is None:
return -1
# return 1st metric in the dict
return engine.eval_metric_func.avg
......@@ -3,16 +3,29 @@ import paddle.nn as nn
import paddle.nn.functional as F
def ratio2weight(targets, ratio):
pos_weights = targets * (1. - ratio)
neg_weights = (1. - targets) * ratio
weights = paddle.exp(neg_weights + pos_weights)
# for the RAP dataloader, target elements may be 2; with or without smoothing, some elements may be greater than 1
weights = weights - weights * (targets > 1)
return weights
class MultiLabelLoss(nn.Layer):
"""
Multi-label loss
"""
def __init__(self, epsilon=None):
def __init__(self, epsilon=None, size_sum=False, weight_ratio=False):
super().__init__()
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
epsilon = None
self.epsilon = epsilon
self.weight_ratio = weight_ratio
self.size_sum = size_sum
def _labelsmoothing(self, target, class_num):
if target.ndim == 1 or target.shape[-1] != class_num:
......@@ -24,13 +37,21 @@ class MultiLabelLoss(nn.Layer):
return soft_target
def _binary_crossentropy(self, input, target, class_num):
if self.weight_ratio:
target, label_ratio = target[:, 0, :], target[:, 1, :]
if self.epsilon is not None:
target = self._labelsmoothing(target, class_num)
cost = F.binary_cross_entropy_with_logits(
logit=input, label=target)
else:
cost = F.binary_cross_entropy_with_logits(
logit=input, label=target)
cost = F.binary_cross_entropy_with_logits(
logit=input, label=target, reduction='none')
if self.weight_ratio:
targets_mask = paddle.cast(target > 0.5, 'float32')
weight = ratio2weight(targets_mask, paddle.to_tensor(label_ratio))
weight = weight * (target > -1)
cost = cost * weight
cost = cost.sum(1).mean() if self.size_sum else cost.mean()
return cost
......
......@@ -12,17 +12,19 @@
#See the License for the specific language governing permissions and
#limitations under the License.
from paddle import nn
import copy
from collections import OrderedDict
from .avg_metrics import AvgMetrics
from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk
from .metrics import DistillationTopkAcc
from .metrics import GoogLeNetTopkAcc
from .metrics import HammingDistance, AccuracyScore
from .metrics import ATTRMetric
from .metrics import TprAtFpr
class CombinedMetrics(nn.Layer):
class CombinedMetrics(AvgMetrics):
def __init__(self, config_list):
super().__init__()
self.metric_func_list = []
......@@ -38,13 +40,30 @@ class CombinedMetrics(nn.Layer):
eval(metric_name)(**metric_params))
else:
self.metric_func_list.append(eval(metric_name)())
self.reset()
def __call__(self, *args, **kwargs):
def forward(self, *args, **kwargs):
metric_dict = OrderedDict()
for idx, metric_func in enumerate(self.metric_func_list):
metric_dict.update(metric_func(*args, **kwargs))
return metric_dict
@property
def avg_info(self):
return ", ".join([metric.avg_info for metric in self.metric_func_list])
@property
def avg(self):
return self.metric_func_list[0].avg
def attr_res(self):
return self.metric_func_list[0].attrmeter.res()
def reset(self):
for metric in self.metric_func_list:
if hasattr(metric, "reset"):
metric.reset()
def build_metrics(config):
metrics_list = CombinedMetrics(copy.deepcopy(config))
......
from paddle import nn
class AvgMetrics(nn.Layer):
def __init__(self):
super().__init__()
self.avg_meters = {}
def reset(self):
self.avg_meters = {}
@property
def avg(self):
if self.avg_meters:
for metric_key in self.avg_meters:
return self.avg_meters[metric_key].avg
@property
def avg_info(self):
return ", ".join([self.avg_meters[key].avg_info for key in self.avg_meters])
......@@ -22,14 +22,26 @@ from sklearn.metrics import accuracy_score as accuracy_metric
from sklearn.metrics import multilabel_confusion_matrix
from sklearn.preprocessing import binarize
from easydict import EasyDict
class TopkAcc(nn.Layer):
from ppcls.metric.avg_metrics import AvgMetrics
from ppcls.utils.misc import AverageMeter, AttrMeter
class TopkAcc(AvgMetrics):
def __init__(self, topk=(1, 5)):
super().__init__()
assert isinstance(topk, (int, list, tuple))
if isinstance(topk, int):
topk = [topk]
self.topk = topk
self.reset()
def reset(self):
self.avg_meters = {
"top{}".format(k): AverageMeter("top{}".format(k))
for k in self.topk
}
def forward(self, x, label):
if isinstance(x, dict):
......@@ -39,6 +51,7 @@ class TopkAcc(nn.Layer):
for k in self.topk:
metric_dict["top{}".format(k)] = paddle.metric.accuracy(
x, label, k=k)
self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)], x.shape[0])
return metric_dict
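With the AverageMeter wiring above, TopkAcc now keeps running averages across batches in addition to per-batch values; a minimal sketch:
# minimal sketch, assuming the TopkAcc defined above
import paddle
metric = TopkAcc(topk=(1, 5))
logits = paddle.rand([8, 10])
labels = paddle.randint(0, 10, [8, 1])
print(metric(logits, labels))  # per-batch {"top1": Tensor, "top5": Tensor}
print(metric.avg_info)         # e.g. "top1: 0.12500, top5: 0.50000" (running)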
......@@ -108,7 +121,7 @@ class mINP(nn.Layer):
choosen_indices)
equal_flag = paddle.equal(choosen_label, query_img_id)
if keep_mask is not None:
                keep_mask = paddle.index_sample(
keep_mask.astype('float32'), choosen_indices)
equal_flag = paddle.logical_and(equal_flag,
keep_mask.astype('bool'))
......@@ -131,6 +144,61 @@ class mINP(nn.Layer):
return metric_dict
class TprAtFpr(nn.Layer):
def __init__(self, max_fpr=1 / 1000.):
super().__init__()
self.gt_pos_score_list = []
self.gt_neg_score_list = []
self.softmax = nn.Softmax(axis=-1)
self.max_fpr = max_fpr
self.max_tpr = 0.
def forward(self, x, label):
if isinstance(x, dict):
x = x["logits"]
x = self.softmax(x)
for i, label_i in enumerate(label):
if label_i[0] == 0:
self.gt_neg_score_list.append(x[i][1].numpy())
else:
self.gt_pos_score_list.append(x[i][1].numpy())
return {}
def reset(self):
self.gt_pos_score_list = []
self.gt_neg_score_list = []
self.max_tpr = 0.
@property
def avg(self):
return self.max_tpr
@property
def avg_info(self):
max_tpr = 0.
result = ""
gt_pos_score_list = np.array(self.gt_pos_score_list)
gt_neg_score_list = np.array(self.gt_neg_score_list)
for i in range(0, 10000):
threshold = i / 10000.
if len(gt_pos_score_list) == 0:
continue
tpr = np.sum(
gt_pos_score_list > threshold) / len(gt_pos_score_list)
            if len(gt_neg_score_list) == 0:
                if tpr > max_tpr:
                    max_tpr = tpr
                    result = "threshold: {}, fpr: 0.0, tpr: {:.5f}".format(
                        threshold, tpr)
                continue
            fpr = np.sum(
                gt_neg_score_list > threshold) / len(gt_neg_score_list)
if fpr <= self.max_fpr and tpr > max_tpr:
max_tpr = tpr
result = "threshold: {}, fpr: {}, tpr: {:.5f}".format(
threshold, fpr, tpr)
self.max_tpr = max_tpr
return result
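avg_info sweeps 10000 candidate thresholds and reports the highest TPR whose FPR stays at or below max_fpr; a usage sketch:
# usage sketch, assuming binary logits with class 1 = positive
metric = TprAtFpr(max_fpr=1 / 1000.)
# call metric(logits_batch, labels_batch) per eval batch, then:
# print(metric.avg_info)  # e.g. "threshold: 0.9123, fpr: 0.0008, tpr: 0.97500"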
class Recallk(nn.Layer):
def __init__(self, topk=(1, 5), descending=True):
super().__init__()
......@@ -245,20 +313,17 @@ class GoogLeNetTopkAcc(TopkAcc):
return super().forward(x[0], label)
class MutiLabelMetric(object):
def __init__(self):
pass
def _multi_hot_encode(self, logits, threshold=0.5):
return binarize(logits, threshold=threshold)
class MultiLabelMetric(AvgMetrics):
def __init__(self, bi_threshold=0.5):
super().__init__()
self.bi_threshold = bi_threshold
def __call__(self, output):
output = F.sigmoid(output)
preds = self._multi_hot_encode(logits=output.numpy(), threshold=0.5)
return preds
def _multi_hot_encode(self, output):
logits = F.sigmoid(output).numpy()
return binarize(logits, threshold=self.bi_threshold)
class HammingDistance(MutiLabelMetric):
class HammingDistance(MultiLabelMetric):
"""
Soft metric based label for multilabel classification
Returns:
......@@ -267,16 +332,22 @@ class HammingDistance(MutiLabelMetric):
def __init__(self):
super().__init__()
self.reset()
def reset(self):
self.avg_meters = {"HammingDistance": AverageMeter("HammingDistance")}
def __call__(self, output, target):
preds = super().__call__(output)
def forward(self, output, target):
preds = super()._multi_hot_encode(output)
metric_dict = dict()
metric_dict["HammingDistance"] = paddle.to_tensor(
hamming_loss(target, preds))
self.avg_meters["HammingDistance"].update(
metric_dict["HammingDistance"].numpy()[0], output.shape[0])
return metric_dict
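A worked toy example (illustrative): with the logits below, sigmoid + binarize at 0.5 yields predictions [[1, 0], [0, 1]] against targets [[1, 0], [1, 1]], so one of four label slots is wrong and the Hamming distance is 0.25.
# illustrative only; labels are passed as tensors, as the engine does
import paddle
hd = HammingDistance()
out = paddle.to_tensor([[2.0, -1.0], [-0.5, 0.7]])
tgt = paddle.to_tensor([[1, 0], [1, 1]])
print(hd(out, tgt))  # {"HammingDistance": Tensor(0.25)}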
class AccuracyScore(MutiLabelMetric):
class AccuracyScore(MultiLabelMetric):
"""
Hard metric for multilabel classification
Args:
......@@ -292,9 +363,13 @@ class AccuracyScore(MutiLabelMetric):
assert base in ["sample", "label"
], 'must be one of ["sample", "label"]'
self.base = base
self.reset()
def __call__(self, output, target):
preds = super().__call__(output)
def reset(self):
self.avg_meters = {"AccuracyScore": AverageMeter("AccuracyScore")}
def forward(self, output, target):
preds = super()._multi_hot_encode(output)
metric_dict = dict()
if self.base == "sample":
accuracy = accuracy_metric(target, preds)
......@@ -307,4 +382,66 @@ class AccuracyScore(MutiLabelMetric):
accuracy = (sum(tps) + sum(tns)) / (
sum(tps) + sum(tns) + sum(fns) + sum(fps))
metric_dict["AccuracyScore"] = paddle.to_tensor(accuracy)
self.avg_meters["AccuracyScore"].update(
metric_dict["AccuracyScore"].numpy()[0], output.shape[0])
return metric_dict
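Note the two bases: base="sample" scores each sample's whole label vector via sklearn's accuracy_score, while base="label" pools TP/TN over the per-label confusion matrices. A sketch on the same toy batch as above (assuming the constructor accepts base, as the assert suggests):
# sketch: label-based accuracy; preds [[1,0],[0,1]] vs targets [[1,0],[1,1]]
acc = AccuracyScore(base="label")
print(acc(out, tgt))  # {"AccuracyScore": Tensor(0.75)} -> 3 of 4 label slots correct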
def get_attr_metrics(gt_label, preds_probs, threshold):
"""
index: evaluated label index
"""
pred_label = (preds_probs > threshold).astype(int)
eps = 1e-20
result = EasyDict()
    # attributes labeled -1 in the ground truth are ignored; mask the
    # predictions the same way so they never count as hits or misses
    has_fuyi = gt_label == -1
    pred_label[has_fuyi] = -1
###############################
# label metrics
# TP + FN
result.gt_pos = np.sum((gt_label == 1), axis=0).astype(float)
# TN + FP
result.gt_neg = np.sum((gt_label == 0), axis=0).astype(float)
# TP
result.true_pos = np.sum((gt_label == 1) * (pred_label == 1),
axis=0).astype(float)
# TN
result.true_neg = np.sum((gt_label == 0) * (pred_label == 0),
axis=0).astype(float)
# FP
result.false_pos = np.sum(((gt_label == 0) * (pred_label == 1)),
axis=0).astype(float)
# FN
result.false_neg = np.sum(((gt_label == 1) * (pred_label == 0)),
axis=0).astype(float)
################
# instance metrics
result.gt_pos_ins = np.sum((gt_label == 1), axis=1).astype(float)
result.true_pos_ins = np.sum((pred_label == 1), axis=1).astype(float)
# true positive
result.intersect_pos = np.sum((gt_label == 1) * (pred_label == 1),
axis=1).astype(float)
# IOU
result.union_pos = np.sum(((gt_label == 1) + (pred_label == 1)),
axis=1).astype(float)
return result
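A tiny worked example (illustrative): two samples, two attributes.
# illustrative only: assumes get_attr_metrics as defined above
import numpy as np
gt = np.array([[1, 0], [1, 1]])
probs = np.array([[0.9, 0.2], [0.3, 0.8]])   # thresholded to [[1,0],[0,1]]
res = get_attr_metrics(gt, probs, 0.5)
print(res.true_pos)    # [1. 1.]  per-label true positives
print(res.gt_pos_ins)  # [1. 2.]  ground-truth positives per instance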
class ATTRMetric(nn.Layer):
def __init__(self, threshold=0.5):
super().__init__()
self.threshold = threshold
def reset(self):
        self.attrmeter = AttrMeter(threshold=self.threshold)
def forward(self, output, target):
metric_dict = get_attr_metrics(target[:, 0, :].numpy(),
output.numpy(), self.threshold)
self.attrmeter.update(metric_dict)
return metric_dict
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
__all__ = ['AverageMeter']
......@@ -42,6 +44,12 @@ class AverageMeter(object):
self.count += n
self.avg = self.sum / self.count
@property
def avg_info(self):
if isinstance(self.avg, paddle.Tensor):
self.avg = self.avg.numpy()[0]
return "{}: {:.5f}".format(self.name, self.avg)
@property
def total(self):
return '{self.name}_sum: {self.sum:{self.fmt}}{self.postfix}'.format(
......@@ -61,3 +69,87 @@ class AverageMeter(object):
def value(self):
return '{self.name}: {self.val:{self.fmt}}{self.postfix}'.format(
self=self)
class AttrMeter(object):
"""
Computes and stores the average and current value
Code was based on https://github.com/pytorch/examples/blob/master/imagenet/main.py
"""
def __init__(self, threshold=0.5):
self.threshold = threshold
self.reset()
def reset(self):
self.gt_pos = 0
self.gt_neg = 0
self.true_pos = 0
self.true_neg = 0
self.false_pos = 0
self.false_neg = 0
self.gt_pos_ins = []
self.true_pos_ins = []
self.intersect_pos = []
self.union_pos = []
def update(self, metric_dict):
self.gt_pos += metric_dict['gt_pos']
self.gt_neg += metric_dict['gt_neg']
self.true_pos += metric_dict['true_pos']
self.true_neg += metric_dict['true_neg']
self.false_pos += metric_dict['false_pos']
self.false_neg += metric_dict['false_neg']
self.gt_pos_ins += metric_dict['gt_pos_ins'].tolist()
self.true_pos_ins += metric_dict['true_pos_ins'].tolist()
self.intersect_pos += metric_dict['intersect_pos'].tolist()
self.union_pos += metric_dict['union_pos'].tolist()
def res(self):
import numpy as np
eps = 1e-20
label_pos_recall = 1.0 * self.true_pos / (
self.gt_pos + eps) # true positive
label_neg_recall = 1.0 * self.true_neg / (
self.gt_neg + eps) # true negative
# mean accuracy
label_ma = (label_pos_recall + label_neg_recall) / 2
label_pos_recall = np.mean(label_pos_recall)
label_neg_recall = np.mean(label_neg_recall)
label_prec = (self.true_pos / (self.true_pos + self.false_pos + eps))
label_acc = (self.true_pos /
(self.true_pos + self.false_pos + self.false_neg + eps))
label_f1 = np.mean(2 * label_prec * label_pos_recall /
(label_prec + label_pos_recall + eps))
ma = (np.mean(label_ma))
self.gt_pos_ins = np.array(self.gt_pos_ins)
self.true_pos_ins = np.array(self.true_pos_ins)
self.intersect_pos = np.array(self.intersect_pos)
self.union_pos = np.array(self.union_pos)
instance_acc = self.intersect_pos / (self.union_pos + eps)
instance_prec = self.intersect_pos / (self.true_pos_ins + eps)
instance_recall = self.intersect_pos / (self.gt_pos_ins + eps)
        instance_f1 = 2 * instance_prec * instance_recall / (
            instance_prec + instance_recall + eps)

        instance_acc = np.mean(instance_acc)
        instance_prec = np.mean(instance_prec)
        instance_recall = np.mean(instance_recall)
        instance_f1 = np.mean(instance_f1)
res = [
ma, label_f1, label_pos_recall, label_neg_recall, instance_f1,
instance_acc, instance_prec, instance_recall
]
return res
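The eight entries of res() come back positionally; a sketch of unpacking them, with names taken from the locals above:
# sketch: positional unpacking of AttrMeter.res()
meter = AttrMeter(threshold=0.5)
# ... meter.update(get_attr_metrics(...)) per eval batch, then:
ma, label_f1, label_pos_recall, label_neg_recall, \
    instance_f1, instance_acc, instance_prec, instance_recall = meter.res()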
......@@ -42,6 +42,14 @@ def _mkdir_if_not_exist(path):
raise OSError('Failed to mkdir {}'.format(path))
def _extract_student_weights(all_params, student_prefix="Student."):
s_params = {
key[len(student_prefix):]: all_params[key]
for key in all_params if student_prefix in key
}
return s_params
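_extract_student_weights strips the "Student." prefix so a distilled student can later be loaded standalone; for example:
# illustrative: only keys containing the prefix survive, minus the prefix
all_params = {
    "Student.backbone.conv1.weight": "w0",
    "Teacher.backbone.conv1.weight": "w1",
}
print(_extract_student_weights(all_params))
# {"backbone.conv1.weight": "w0"}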
def load_dygraph_pretrain(model, path=None):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {}.pdparams does not "
......@@ -127,7 +135,8 @@ def save_model(net,
model_path,
model_name="",
prefix='ppcls',
loss: paddle.nn.Layer=None):
loss: paddle.nn.Layer=None,
save_student_model=False):
"""
save model to the target path
"""
......@@ -138,11 +147,18 @@ def save_model(net,
model_path = os.path.join(model_path, prefix)
params_state_dict = net.state_dict()
loss_state_dict = loss.state_dict()
keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys())
assert len(keys_inter) == 0, \
f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
params_state_dict.update(loss_state_dict)
if loss is not None:
loss_state_dict = loss.state_dict()
keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys(
))
assert len(keys_inter) == 0, \
f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
params_state_dict.update(loss_state_dict)
if save_student_model:
s_params = _extract_student_weights(params_state_dict)
if len(s_params) > 0:
paddle.save(s_params, model_path + "_student.pdparams")
paddle.save(params_state_dict, model_path + ".pdparams")
paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
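With the new flag, a distillation checkpoint can also emit a standalone student snapshot; a hedged sketch follows (the positional arguments before model_path are assumptions, since the full signature is only partially shown in this hunk):
# sketch: save the full distillation model plus the student weights
save_model(net, [optimizer], metric_info, "./output",
           model_name="DistillationModel", prefix="best_model",
           loss=train_loss, save_student_model=True)
# -> ./output/DistillationModel/best_model.pdparams
#    ./output/DistillationModel/best_model_student.pdparams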
......
......@@ -9,3 +9,4 @@ scipy
scikit-learn==0.23.2
gast==0.3.3
faiss-cpu==1.7.1.post2
easydict
#!/usr/bin/env bash
GPU_IDS="0,1,2,3"
# Basic Config
CONFIG="ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml"
EPOCHS=1
OUTPUT="output_debug4"
STATUS_LOG="${OUTPUT}/status_result.log"
RESULT="${OUTPUT}/result.log"
# Search Options
LR_LIST=( 0.0075 0.01 0.0125 )
RESOLUTION_LIST=( 176 192 224 )
RA_PROB_LIST=( 0.0 0.1 0.5 )
RE_PROB_LIST=( 0.0 0.1 0.5 )
LR_MULT_LIST=( [0.0,0.2,0.4,0.6,0.8,1.0] [0.0,0.4,0.4,0.8,0.8,1.0] )
TEACHER_LIST=( "ResNet101_vd" "ResNet50_vd" )
# Train Mode
declare -A MODE_MAP
MODE_MAP=(["search_lr"]=1 ["search_resolution"]=1 ["search_ra_prob"]=1 ["search_re_prob"]=1 ["search_lr_mult_list"]=1 ["search_teacher"]=1 ["train_distillation_model"]=1)
export CUDA_VISIBLE_DEVICES=${GPU_IDS}
function status_check(){
last_status=$1 # the exit code
run_command=$2
run_log=$3
if [ $last_status -eq 0 ]; then
echo -e "\033[33m Run successfully with command - ${run_command}! \033[0m" | tee -a ${run_log}
else
echo -e "\033[33m Run failed with command - ${run_command}! \033[0m" | tee -a ${run_log}
fi
}
function get_max_value(){
array=($*)
max=${array[0]}
index=0
    for (( i=0; i<${#array[*]}; i++ )); do
if [[ $(echo "${array[$i]} > $max"|bc) -eq 1 ]]; then
max=${array[$i]}
index=${i}
else
continue
fi
done
echo ${max}
echo ${index}
}
function get_best_info(){
_parameter=$1
params_index=2
if [[ ${_parameter} == "TEACHER" ]]; then
params_index=3
fi
parameters_list=$(find ${OUTPUT}/${_parameter}* -name train.log | awk -v params_index=${params_index} -F "/" '{print $params_index}')
metric_list=$(find ${OUTPUT}/${_parameter}* -name train.log | xargs cat | grep "best" | grep "Epoch ${EPOCHS}" | awk -F " " '{print substr($NF,0,7)}')
best_info=$(get_max_value ${metric_list[*]})
best_metric=$(echo $best_info | awk -F " " '{print $1}')
best_index=$(echo $best_info | awk -F " " '{print $2}')
best_parameter=$(echo $parameters_list | awk -v best=$(($best_index+1)) '{print $best}' | awk -F "_" '{print $2}')
echo ${best_metric}
echo ${best_parameter}
}
function search_lr(){
for lr in ${LR_LIST[*]}; do
cmd_train="python3.7 -m paddle.distributed.launch --gpus=${GPU_IDS} tools/train.py \
-c ${CONFIG} \
-o Global.output_dir=${OUTPUT}/LR_${lr} \
-o Optimizer.lr.learning_rate=${lr} \
-o Global.epochs=${EPOCHS}"
eval ${cmd_train}
status_check $? "${cmd_train}" "${STATUS_LOG}"
cmd="find ${OUTPUT} -name epoch* | xargs rm -rf"
eval ${cmd}
done
}
function search_resolution(){
_lr=$1
for resolution in ${RESOLUTION_LIST[*]}; do
cmd_train="python3.7 -m paddle.distributed.launch --gpus=${GPU_IDS} tools/train.py \
-c ${CONFIG} \
-o Global.output_dir=${OUTPUT}/RESOLUTION_${resolution} \
-o Optimizer.lr.learning_rate=${_lr} \
-o Global.epochs=${EPOCHS} \
-o DataLoader.Train.dataset.transform_ops.1.RandCropImage.size=${resolution}"
eval ${cmd_train}
status_check $? "${cmd_train}" "${STATUS_LOG}"
cmd="find ${OUTPUT} -name epoch* | xargs rm -rf"
eval ${cmd}
done
}
function search_ra_prob(){
_lr=$1
_resolution=$2
for ra_prob in ${RA_PROB_LIST[*]}; do
cmd_train="python3.7 -m paddle.distributed.launch --gpus=${GPU_IDS} tools/train.py \
-c ${CONFIG} \
-o Global.output_dir=${OUTPUT}/RA_${ra_prob} \
-o Optimizer.lr.learning_rate=${_lr} \
-o Global.epochs=${EPOCHS} \
-o DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.prob=${ra_prob} \
-o DataLoader.Train.dataset.transform_ops.1.RandCropImage.size=${_resolution} \
-o DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.img_size=${_resolution}"
eval ${cmd_train}
status_check $? "${cmd_train}" "${STATUS_LOG}"
cmd="find ${OUTPUT} -name epoch* | xargs rm -rf"
eval ${cmd}
done
}
function search_re_prob(){
_lr=$1
_resolution=$2
_ra_prob=$3
for re_prob in ${RE_PROB_LIST[*]}; do
cmd_train="python3.7 -m paddle.distributed.launch --gpus=${GPU_IDS} tools/train.py \
-c ${CONFIG} \
-o Global.output_dir=${OUTPUT}/RE_${re_prob} \
-o Optimizer.lr.learning_rate=${_lr} \
-o Global.epochs=${EPOCHS} \
-o DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.prob=${_ra_prob} \
-o DataLoader.Train.dataset.transform_ops.5.RandomErasing.EPSILON=${re_prob} \
-o DataLoader.Train.dataset.transform_ops.1.RandCropImage.size=${_resolution} \
-o DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.img_size=${_resolution}"
eval ${cmd_train}
status_check $? "${cmd_train}" "${STATUS_LOG}"
cmd="find ${OUTPUT} -name epoch* | xargs rm -rf"
eval ${cmd}
done
}
function search_lr_mult_list(){
_lr=$1
_resolution=$2
_ra_prob=$3
_re_prob=$4
for lr_mult in ${LR_MULT_LIST[*]}; do
cmd_train="python3.7 -m paddle.distributed.launch --gpus=${GPU_IDS} tools/train.py \
-c ${CONFIG} \
-o Global.output_dir=${OUTPUT}/LR_MULT_${lr_mult} \
-o Optimizer.lr.learning_rate=${_lr} \
-o Global.epochs=${EPOCHS} \
-o DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.prob=${_ra_prob} \
-o DataLoader.Train.dataset.transform_ops.5.RandomErasing.EPSILON=${_re_prob} \
-o DataLoader.Train.dataset.transform_ops.1.RandCropImage.size=${_resolution} \
-o DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.img_size=${_resolution} \
-o Arch.lr_mult_list=${lr_mult}"
eval ${cmd_train}
status_check $? "${cmd_train}" "${STATUS_LOG}"
cmd="find ${OUTPUT} -name epoch* | xargs rm -rf"
eval ${cmd}
done
}
function search_teacher(){
_lr=$1
_resolution=$2
_ra_prob=$3
_re_prob=$4
for teacher in ${TEACHER_LIST[*]}; do
cmd_train="python3.7 -m paddle.distributed.launch --gpus=${GPU_IDS} tools/train.py \
-c ${CONFIG} \
-o Global.output_dir=${OUTPUT}/TEACHER_${teacher} \
-o Optimizer.lr.learning_rate=${_lr} \
-o Global.epochs=${EPOCHS} \
-o DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.prob=${_ra_prob} \
-o DataLoader.Train.dataset.transform_ops.5.RandomErasing.EPSILON=${_re_prob} \
-o DataLoader.Train.dataset.transform_ops.1.RandCropImage.size=${_resolution} \
-o DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.img_size=${_resolution} \
-o Arch.name=${teacher}"
eval ${cmd_train}
status_check $? "${cmd_train}" "${STATUS_LOG}"
cmd="find ${OUTPUT}/* -name epoch* | xargs rm -rf"
eval ${cmd}
done
}
# train the model for knowledge distillation
function train_distillation_model(){
_lr=$1
_resolution=$2
_ra_prob=$3
_re_prob=$4
_lr_mult=$5
teacher=$6
t_pretrained_model="${OUTPUT}/TEACHER_${teacher}/${teacher}/best_model"
config="ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml"
combined_label_list="./dataset/person/train_list_for_distill.txt"
cmd_train="python3.7 -m paddle.distributed.launch \
--gpus=${GPU_IDS} \
tools/train.py -c ${config} \
-o Global.output_dir=${OUTPUT}/kd_teacher \
-o Optimizer.lr.learning_rate=${_lr} \
-o Global.epochs=${EPOCHS} \
-o DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.prob=${_ra_prob} \
-o DataLoader.Train.dataset.transform_ops.5.RandomErasing.EPSILON=${_re_prob} \
-o DataLoader.Train.dataset.transform_ops.1.RandCropImage.size=${_resolution} \
-o DataLoader.Train.dataset.transform_ops.3.TimmAutoAugment.img_size=${_resolution} \
-o DataLoader.Train.dataset.cls_label_path=${combined_label_list} \
-o Arch.models.0.Teacher.name="${teacher}" \
-o Arch.models.0.Teacher.pretrained="${t_pretrained_model}" \
-o Arch.models.1.Student.lr_mult_list=${_lr_mult}"
eval ${cmd_train}
status_check $? "${cmd_train}" "${STATUS_LOG}"
cmd="find ${OUTPUT} -name epoch* | xargs rm -rf"
eval ${cmd}
}
######## Train PaddleClas ########
rm -rf ${OUTPUT}
# Train and get best lr
best_lr=0.01
if [[ ${MODE_MAP["search_lr"]} -eq 1 ]]; then
search_lr
best_info=$(get_best_info "LR_[0-9]")
best_metric=$(echo $best_info | awk -F " " '{print $1}')
best_lr=$(echo $best_info | awk -F " " '{print $2}')
echo "The best lr is ${best_lr}, and the best metric is ${best_metric}" >> ${RESULT}
fi
# Train and get best resolution
best_resolution=192
if [[ ${MODE_MAP["search_resolution"]} -eq 1 ]]; then
search_resolution "${best_lr}"
best_info=$(get_best_info "RESOLUTION")
best_metric=$(echo $best_info | awk -F " " '{print $1}')
best_resolution=$(echo $best_info | awk -F " " '{print $2}')
echo "The best resolution is ${best_resolution}, and the best metric is ${best_metric}" >> ${RESULT}
fi
# Train and get best ra_prob
best_ra_prob=0.0
if [[ ${MODE_MAP["search_ra_prob"]} -eq 1 ]]; then
search_ra_prob "${best_lr}" "${best_resolution}"
best_info=$(get_best_info "RA")
best_metric=$(echo $best_info | awk -F " " '{print $1}')
best_ra_prob=$(echo $best_info | awk -F " " '{print $2}')
echo "The best ra_prob is ${best_ra_prob}, and the best metric is ${best_metric}" >> ${RESULT}
fi
# Train and get best re_prob
best_re_prob=0.1
if [[ ${MODE_MAP["search_re_prob"]} -eq 1 ]]; then
search_re_prob "${best_lr}" "${best_resolution}" "${best_ra_prob}"
best_info=$(get_best_info "RE")
best_metric=$(echo $best_info | awk -F " " '{print $1}')
best_re_prob=$(echo $best_info | awk -F " " '{print $2}')
echo "The best re_prob is ${best_re_prob}, and the best metric is ${best_metric}" >> ${RESULT}
fi
# Train and get best lr_mult_list
best_lr_mult_list=[1.0,1.0,1.0,1.0,1.0,1.0]
if [[ ${MODE_MAP["search_lr_mult_list"]} -eq 1 ]]; then
search_lr_mult_list "${best_lr}" "${best_resolution}" "${best_ra_prob}" "${best_re_prob}"
best_info=$(get_best_info "LR_MULT")
best_metric=$(echo $best_info | awk -F " " '{print $1}')
best_lr_mult_list=$(echo $best_info | awk -F " " '{print $2}')
echo "The best lr_mult_list is ${best_lr_mult_list}, and the best metric is ${best_metric}" >> ${RESULT}
fi
# train and get best teacher
best_teacher="ResNet101_vd"
if [[ ${MODE_MAP["search_teacher"]} -eq 1 ]]; then
search_teacher "${best_lr}" "${best_resolution}" "${best_ra_prob}" "${best_re_prob}"
best_info=$(get_best_info "TEACHER")
best_metric=$(echo $best_info | awk -F " " '{print $1}')
best_teacher=$(echo $best_info | awk -F " " '{print $2}')
echo "The best teacher is ${best_teacher}, and the best metric is ${best_metric}" >> ${RESULT}
fi
# train the distillation model
if [[ ${MODE_MAP["train_distillation_model"]} -eq 1 ]]; then
train_distillation_model "${best_lr}" "${best_resolution}" "${best_ra_prob}" "${best_re_prob}" "${best_lr_mult_list}" ${best_teacher}
best_info=$(get_best_info "kd_teacher/DistillationModel")
best_metric=$(echo $best_info | awk -F " " '{print $1}')
echo "the distillation best metric is ${best_metric}, it is global best metric!" >> ${RESULT}
fi
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
import subprocess
import numpy as np
from ppcls.utils import config
def get_result(log_dir):
log_file = "{}/train.log".format(log_dir)
with open(log_file, "r") as f:
raw = f.read()
res = float(raw.split("best metric: ")[-1].split("]")[0])
return res
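get_result takes the text after the last "best metric: " occurrence in train.log and parses it up to the closing bracket; the exact log line format below is an assumption about the trainer's logger output:
# e.g. a matching train.log fragment (assumed format):
# [Eval][Epoch 20][best metric: 0.95123]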
def search_train(search_list, base_program, base_output_dir, search_key,
config_replace_value, model_name, search_times=1):
best_res = 0.
best = search_list[0]
all_result = {}
for search_i in search_list:
program = base_program.copy()
for v in config_replace_value:
program += ["-o", "{}={}".format(v, search_i)]
if v == "Arch.name":
model_name = search_i
res_list = []
for j in range(search_times):
output_dir = "{}/{}_{}_{}".format(base_output_dir, search_key, search_i, j).replace(".", "_")
program += ["-o", "Global.output_dir={}".format(output_dir)]
process = subprocess.Popen(program)
process.communicate()
res = get_result("{}/{}".format(output_dir, model_name))
res_list.append(res)
all_result[str(search_i)] = res_list
if np.mean(res_list) > best_res:
best = search_i
best_res = np.mean(res_list)
all_result["best"] = best
return all_result
def search_strategy():
args = config.parse_args()
configs = config.get_config(args.config, overrides=args.override, show=False)
base_config_file = configs["base_config_file"]
distill_config_file = configs["distill_config_file"]
model_name = config.get_config(base_config_file)["Arch"]["name"]
gpus = configs["gpus"]
gpus = ",".join([str(i) for i in gpus])
base_program = ["python3.7", "-m", "paddle.distributed.launch", "--gpus={}".format(gpus),
"tools/train.py", "-c", base_config_file]
base_output_dir = configs["output_dir"]
search_times = configs["search_times"]
search_dict = configs.get("search_dict")
all_results = {}
for search_i in search_dict:
search_key = search_i["search_key"]
search_values = search_i["search_values"]
replace_config = search_i["replace_config"]
res = search_train(search_values, base_program, base_output_dir,
search_key, replace_config, model_name, search_times)
all_results[search_key] = res
best = res.get("best")
for v in replace_config:
base_program += ["-o", "{}={}".format(v, best)]
teacher_configs = configs.get("teacher", None)
if teacher_configs is not None:
teacher_program = base_program.copy()
# remove incompatible keys
teacher_rm_keys = teacher_configs["rm_keys"]
rm_indices = []
for rm_k in teacher_rm_keys:
for ind, ki in enumerate(base_program):
if rm_k in ki:
rm_indices.append(ind)
for rm_index in rm_indices[::-1]:
teacher_program.pop(rm_index)
teacher_program.pop(rm_index-1)
replace_config = ["Arch.name"]
teacher_list = teacher_configs["search_values"]
res = search_train(teacher_list, teacher_program, base_output_dir, "teacher", replace_config, model_name)
all_results["teacher"] = res
best = res.get("best")
t_pretrained = "{}/{}_{}_0/{}/best_model".format(base_output_dir, "teacher", best, best)
base_program += ["-o", "Arch.models.0.Teacher.name={}".format(best),
"-o", "Arch.models.0.Teacher.pretrained={}".format(t_pretrained)]
output_dir = "{}/search_res".format(base_output_dir)
base_program += ["-o", "Global.output_dir={}".format(output_dir)]
    final_replace = configs.get('final_replace', {})
for i in range(len(base_program)):
base_program[i] = base_program[i].replace(base_config_file, distill_config_file)
for k in final_replace:
v = final_replace[k]
base_program[i] = base_program[i].replace(k, v)
process = subprocess.Popen(base_program)
process.communicate()
print(all_results, base_program)
if __name__ == '__main__':
search_strategy()
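For reference, a hypothetical search config covering the keys this script reads; every value here is an assumption that mirrors the shell pipeline earlier in this PR, not a shipped config file.
# hypothetical config for search_strategy; keys inferred from the reads above
example_config = {
    "base_config_file": "ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml",
    "distill_config_file":
        "ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml",
    "gpus": [0, 1, 2, 3],
    "output_dir": "output/search",
    "search_times": 1,
    "search_dict": [
        {"search_key": "lrs",
         "search_values": [0.0075, 0.01, 0.0125],
         "replace_config": ["Optimizer.lr.learning_rate"]},
        {"search_key": "resolutions",
         "search_values": [176, 192, 224],
         "replace_config":
             ["DataLoader.Train.dataset.transform_ops.1.RandCropImage.size"]},
    ],
    "teacher": {"rm_keys": ["Arch.lr_mult_list"],
                "search_values": ["ResNet101_vd", "ResNet50_vd"]},
    "final_replace": {"Arch.name": "Arch.models.1.Student.name"},
}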