Commit e8f8acd4 authored by J Jiancheng Li, committed by whs

Add hardware-latency-based search and a model hardware-latency estimation tool (#3039)

* minor fix

* add getting ops from program

* add obtaining op params from paddle program

* add code reusage of light_nas_space

* add comments to get_ops_from_program

* add comments to get_ops_from_program

* fixed bugs

* add get_all_ops

* add get_op_latency

* add platform parameter for get_op_latency

* add get_all_ops

* add get ops from shortcut and se

* bug fixed

* fixed bugs

* add refined code

* bugs fixed

* useless test codes and op files removed

* modifying 7 to length of searched bottlenecks

* bugs fixed

* add softmax op in get_ops_from_program

* remove resize ops in get_ops_from_program

* add get_op_latency from paddle-mobile and modify getting fc op params

* add fc op params

* add batch norm in light_nas_space

* bugs fixed

* update light_nas_space: add get_model_latency

* bug fixed

* bug fixed

* bug fixed

* update get_latency_lookup_table

* add bias of fc layer in get_all_ops

* add docs for light-nas with latency

* update doc for light-nas

* add first part of method descriptions for light-nas based on android and ios

* add descriptions for lightnas basd on android and ios in demo and usage and modify get_latency_lookup_table

* fix errors in usage.md

* fix errors among the contents of android system in usage.md

* fix errors among the contents of ios system in usage.md

* add descriptions in demo.md about lightnas based on android and ios

* add auxillary methods for generating bin files and .a files for op latency test

* add links for latency estimation tools

* minor fix in usage.md

* minor errors fixed in usage.md

* update LightNAS usage

* update docs

* update get_latency_lookup_table.

* add information for get latency lookup table

* add current searched models and their corresponding accuracies for android and ios systems

* bugs in fc_op_prams are fixed in get_ops_from_program.py test=develop

* update docs

* update docs

* update docs

* update PaddleSlim results

* add more lr_strategy

* update docs

* update PaddleSlim results

* update docs

* minor update

* minor update

* modify get_latency_lookup_table.py to make it compatible for python2 and python3; modify usage.md to update the information of PaddleMobile to Paddle-Lite

* minor update to be compatible for py3

* modify get_latency_lookup_table.py to make it compatible for python2 and python3

* convert float number in get_all_ops to int number in order to make it compatible for both python2 and python3

* remove print in get_latency_lookup_table.py

* update docs

* fixed bugs in get_all_ops

* fix merge conflicts

* fix merge conflicts
Parent 10cf8081
Subproject commit 546ce5e967027f37dac10a472d9800025201ffe7
Subproject commit 31e0c630508aea49432c597809028436ed35e573
Subproject commit a884635519c529c69c34e1134ca6c9d99f2c0007
Subproject commit 89c3366bb03af74760796802adb8a1efb464f0c7
......@@ -106,6 +106,9 @@ Paddle-Slim工具库有以下特点:
### Automatic lightweight network architecture search (Light-NAS)
- Supports evolutionary-algorithm-based automatic search of lightweight network architectures (Light-NAS)
- Supports distributed search
- Supports FLOPS / hardware latency constraints
- Supports model latency estimation on multiple platforms
### Other features
......@@ -186,11 +189,21 @@ Paddle-Slim工具库有以下特点:
Data: ImageNet, 1000 classes

| - | FLOPS | Top1/Top5 accuracy | GPU cost |
|------------------|-------|--------------------|-----------------------|
| MobileNetV2 | 0% | 71.90% / 90.55% | - |
| Light-NAS-model0 | -3% | 72.45% / 90.70% | 1.2K GPU hours (V100) |
| Light-NAS-model1 | -17% | 71.84% / 90.45% | 1.2K GPU hours (V100) |
Results of hardware-latency-based architecture search:

| - | Latency | Top1/Top5 accuracy | GPU cost |
|------------------|---------|--------------------|-----------------------|
| MobileNetV2 | 0% | 71.90% / 90.55% | - |
| RK3288 dev board | -23% | 71.97% / 90.35% | 1.2K GPU hours (V100) |
| Android phone | -20% | 72.06% / 90.36% | 1.2K GPU hours (V100) |
| iPhone | -17% | 72.22% / 90.47% | 1.2K GPU hours (V100) |
## Model export formats
......
......@@ -347,29 +347,45 @@ python compress.py \
step1: Enter the directory `PaddlePaddle/models/PaddleSlim/light_nas/`

step2: (Optional) Generate the latency lookup table `latency_lookup_table.txt` as described in the [usage doc](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#245-延时评估器生成方式) and place it in the current directory

step3: In the current directory, create a symlink to the data directory one level up: `ln -s ../data data`

step4: Edit `compress.yaml`, setting `server_ip` to the IP of the current machine

step5: (Optional) Edit `compress.yaml`, setting `target_latency` to your target latency

step6: Run `sh run.sh`; adjust `CUDA_VISIBLE_DEVICES` in `run.sh` as needed

step7: Edit `LightNASSpace::init_tokens` in `light_nas_space.py` so that it returns the best tokens found in step6

step8: Edit `compress.yaml`, removing the `strategies` section under `compressor`.

step9: Run `sh run.sh` to start the training task.

Results of this example under the FLOPS constraint:

| - | FLOPS | Top1/Top5 accuracy | GPU cost | tokens |
|------------------|-------|--------------------|-----------------------|---------|
| MobileNetV2 | 0% | 71.90% / 90.55% | - | - |
| Light-NAS-model0 | -3% | 72.45% / 90.70% | 1.2K GPU hours (V100) | tokens0 |
| Light-NAS-model1 | -17% | 71.84% / 90.45% | 1.2K GPU hours (V100) | tokens1 |

Results of hardware-latency-based architecture search:

| - | Latency | Top1/Top5 accuracy | GPU cost | tokens |
|------------------|---------|--------------------|-----------------------|---------|
| MobileNetV2 | 0% | 71.90% / 90.55% | - | - |
| RK3288 dev board | -23% | 71.97% / 90.35% | 1.2K GPU hours (V100) | tokens2 |
| Android phone | -20% | 72.06% / 90.36% | 1.2K GPU hours (V100) | tokens3 |
| iPhone | -17% | 72.22% / 90.47% | 1.2K GPU hours (V100) | tokens4 |

| token name | tokens |
|------------|--------|
| tokens0 | [3, 1, 1, 0, 1, 0, 3, 2, 1, 0, 1, 0, 3, 1, 1, 0, 1, 0, 2, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0] |
| tokens1 | [3, 1, 1, 0, 1, 0, 3, 2, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 2, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1] |
| tokens2 | [0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 2, 2, 1, 0, 1, 1, 2, 1, 0, 0, 0, 0, 3, 2, 1, 0, 1, 0] |
| tokens3 | [3, 0, 0, 0, 1, 0, 1, 2, 0, 0, 1, 0, 0, 2, 0, 1, 1, 0, 3, 1, 0, 1, 1, 0, 0, 2, 1, 1, 1, 0] |
| tokens4 | [3, 1, 1, 0, 0, 0, 3, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 3, 0, 1, 0, 1, 1, 2, 1, 1, 0, 1, 0] |
......@@ -271,6 +271,23 @@ e^{\frac{(r_k-r)}{T_k}} & r_k < r\\
</p>
### 4.3 Model latency estimation

The search supports both FLOPS constraints and model latency constraints. Repeatedly measuring model latency on hardware platforms such as Android/iOS devices or development boards during the iterative search is both time-consuming and inconvenient, so we developed a model latency estimator to evaluate the latency of searched models. The deviation between the estimated latency and the actually measured latency is below 10%.

The estimator works in two stages: building the hardware latency estimator, which only needs to be done once, and estimating model latency, which is performed repeatedly on searched models during the search.

- Building the hardware latency estimator
  1. Enumerate all unique ops (with their parameters) in the search space
  2. Measure the latency of each op/parameter combination
- Estimating model latency
  1. Extract all ops and their parameters from the given model
  2. Use the latency estimator to estimate the model's latency from those ops and parameters
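The two-stage flow above can be sketched in plain Python. This is a minimal illustration, not the PaddleSlim implementation: the op tuples and latency values below are hypothetical, and the table key format simply mirrors the one-op-per-line lookup table described in the usage doc.

```python
# Stage 1 (once per hardware platform): measure each unique op and
# record it in a lookup table. Stage 2 (during search): sum the table
# entries for every op of a candidate model.

def build_lookup_table(measured_ops):
    """measured_ops: iterable of (op_params_tuple, latency_ms) pairs."""
    return {tuple(map(str, op)): latency for op, latency in measured_ops}

def estimate_model_latency(table, model_ops):
    """Sum per-op latencies; raises KeyError for ops missing from the table."""
    return sum(table[tuple(map(str, op))] for op in model_ops)

# Hypothetical measurements in the documented line format (name, params...):
measured = [
    (('conv', 0, 1, 1, 3, 224, 224, 32, 1, 3, 1, 2, 1), 4.2),
    (('activation', 'relu6', 1, 32, 112, 112), 0.6),
]
table = build_lookup_table(measured)
latency = estimate_model_latency(table, [op for op, _ in measured])
```

The dict lookup makes each estimate effectively free compared to an on-device measurement, which is what makes latency-constrained search practical.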
## 5. References
1. [High-Performance Hardware for Machine Learning](https://media.nips.cc/Conferences/2015/tutorialslides/Dally-NIPS-Tutorial-2015.pdf)
......
......@@ -29,7 +29,7 @@
- [Quantization usage](#21-量化训练)
- [Pruning usage](#22-模型通道剪裁)
- [Distillation usage](#23-蒸馏)
- [Hardware-aware lightweight architecture search usage](#24-基于硬件的轻量级模型结构搜索)
## 1. Introduction to PaddleSlim's general features
......@@ -518,28 +518,29 @@ distillers:
- **distillation_loss_weight:** weight of the currently defined loss. Defaults to 1.0.

### 2.4 Hardware-aware lightweight model architecture search

Based on simulated annealing, this feature provides fast search of lightweight model architectures for different hardware targets, referred to as LightNAS (Light Network Architecture Search).

Using this feature involves three tasks:

- Define the search space
- (Optional) Configure a latency estimator for the target hardware, e.g. an Android/iOS device or an Android development board
- Configure LightNASStrategy and launch the search task
#### 2.4.1 Define the search space

Define a search space by subclassing the [SearchSpace class](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/contrib/slim/nas/search_space.py#L19) and overriding its methods:

- `init_tokens`: `tokens` is an array that encodes a network architecture; one `tokens` corresponds to one architecture. `init_tokens` returns the initial `tokens` of the search.

- `range_table`: an array giving the value range of each position in `tokens`, with the same length as `tokens`. `tokens[i]` takes values in `[0, range_table[i])`.

- `create_net`: builds the startup `Program`, train `Program`, and test `Program` from the given `tokens`.
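To make the `tokens`/`range_table` contract concrete, here is a hedged, framework-free sketch; the class name, token values, and ranges are illustrative, not the actual LightNAS search space.

```python
# Illustrative sketch of the tokens / range_table contract described above.
# Not the real LightNASSpace; it only demonstrates the invariants.

class ToySearchSpace:
    def init_tokens(self):
        # One concrete architecture encoding to start the search from.
        return [3, 1, 1, 0, 1, 0]

    def range_table(self):
        # tokens[i] must lie in [0, range_table()[i]).
        return [4, 2, 2, 1, 2, 1]

space = ToySearchSpace()
tokens, ranges = space.init_tokens(), space.range_table()
assert len(tokens) == len(ranges)
assert all(0 <= t < r for t, r in zip(tokens, ranges))
```

The search controller mutates `tokens` within these per-position ranges, so the two arrays together fully describe the discrete search space.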
A validated search space is defined under [PaddlePaddle/models/light_nas](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/light_nas/light_nas_space.py); most users are advised to use it directly.

When constructing the `Compressor`, pass the `SearchSpace` instance in as follows:
```
...
......@@ -560,8 +561,17 @@ com_pass = Compressor(
search_space=space)
```
#### 2.4.2 (Optional) Configure a latency estimator for the target hardware

Similar to the `get_all_ops` function in the [LightNASSpace class](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/light_nas/light_nas_space.py), implement a method that enumerates all possible ops of your search space.

From these ops, generate a latency lookup table. The table is usually stored at the path given by `LATENCY_LOOKUP_TABLE_PATH=latency_lookup_table.txt` in the [LightNASSpace class](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/light_nas/light_nas_space.py). How to generate the table is described in detail below.

You also need to subclass the [SearchSpace class](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/contrib/slim/nas/search_space.py#L19) and override:

- `get_model_latency`: `program` is one searched network architecture. With per-hardware lookup tables generated ahead of time, this method looks up the latency of each searched network.

#### 2.4.3 Configure LightNASStrategy

Configure the search strategy in the config file as follows:
```
......@@ -570,6 +580,7 @@ strategies:
class: 'LightNASStrategy'
controller: 'sa_controller'
target_flops: 592948064
target_latency: 0
end_epoch: 500
retrain_epoch: 5
metric_name: 'acc_top1'
......@@ -578,11 +589,12 @@ strategies:
is_server: True
search_steps: 100
```
Register strategy instances under the `strategies` key; the configurable parameters are:

- **class:** name of the strategy class; for lightweight architecture search, set it to LightNASStrategy.
- **controller:** the controller used for the search; it must be registered earlier in this config file (registration is described below).
- **target_flops:** FLOPS constraint; the FLOPS of searched architectures must not exceed this value.
- **target_latency:** estimated-latency constraint; the estimated latency of searched architectures must not exceed this value. 0 means no constraint, and hardware-aware search is not started.
- **end_epoch:** the epoch at which the current client stops the search strategy.
- **retrain_epoch:** number of epochs to train before evaluating an architecture's performance. (end_epoch-0)/retrain_epoch is the number of architectures the current client searches.
- **metric_name:** metric used to evaluate model performance.
......@@ -606,20 +618,104 @@ controllers:
- **init_temperature:** float; the initial temperature.
- **max_iter_number:** int; maximum number of attempts to obtain tokens satisfying the FLOPS constraint.
#### 2.4.4 Distributed search

Single-machine multi-task:

Single-machine multi-task means starting one controller server and several clients on one machine. Each client fetches tokens from the controller, builds and trains/evaluates a network from the tokens, and returns the reward to the controller server.
When `Compressor::run()` executes, it first checks whether `is_server` in the config file is `True`, then:

- True: if `slim_LightNASStrategy_controller_server.socket` exists in the current directory, only a client is started; otherwise a controller server and a client are started.

- False: only a client is started.
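The startup decision above can be sketched as follows; the socket filename is the one named in this doc, while the function itself is an illustrative paraphrase, not PaddleSlim code.

```python
import os
import tempfile

SOCKET_FILE = 'slim_LightNASStrategy_controller_server.socket'

def roles_to_start(is_server, workdir='.'):
    """Mirror the server/client startup decision described above."""
    if not is_server:
        return ['client']
    if os.path.exists(os.path.join(workdir, SOCKET_FILE)):
        # Socket file present: a controller server already runs here.
        return ['client']
    return ['controller_server', 'client']

workdir = tempfile.mkdtemp()
fresh = roles_to_start(True, workdir)            # no socket file yet
open(os.path.join(workdir, SOCKET_FILE), 'w').close()
rejoined = roles_to_start(True, workdir)         # socket file now exists
```

This also explains the note below about restarting the controller server: a stale socket file makes the `is_server=True` machine start only a client.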
Multi-machine search:

Multi-machine search starts one controller server on one machine and several clients on other machines. In the config file on the machine running the controller server, set `is_server` to True; on the other machines, set `is_server` to False and set `server_ip` and `server_port` to the controller server's `ip` and `port`.

> Note: when the controller server is restarted, the `slim_LightNASStrategy_controller_server.socket` file may not be cleaned up in time, so it must be deleted manually. This will be fixed in a later release.
#### 2.4.5 Generating the latency estimator

1. Standard format of the latency lookup table

The latency lookup table is stored in a .txt file. For each hardware platform, a table is generated from all possible ops in the search space. Each line of the table corresponds to one op and takes one of the following forms:
- `conv flag_bias flag_relu n_in c_in h_in w_in c_out groups kernel padding stride dilation latency`
- `activation active_type n_in c_in h_in w_in latency`
- `batch_norm active_type n_in c_in h_in w_in latency`
- `eltwise eltwise_type n_in c_in h_in w_in latency`
- `pooling flag_global_pooling n_in c_in h_in w_in kernel padding stride ceil_mode pool_type latency`
- `softmax axis n_in c_in h_in w_in latency`
Here `conv`, `activation`, `batch_norm`, `eltwise`, `pooling`, and `softmax` denote convolution, activation, batch normalization, elementwise, pooling, and softmax ops respectively; these are the ops currently supported. The parameters mean:

- active_type (string) - activation type, one of: relu, prelu, sigmoid, relu6, tanh.
- eltwise_type (int) - elementwise op type: 1 for elementwise_mul, 2 for elementwise_add, 3 for elementwise_max.
- pool_type (int) - pooling type: 1 for pooling_max, 2 for pooling_average_include_padding, 3 for pooling_average_exclude_padding.
- flag_bias (int) - whether there is a bias (0: no, 1: yes).
- flag_global_pooling (int) - whether the pooling is global (0: no, 1: yes).
- flag_relu (int) - whether a relu follows (0: no, 1: yes).
- n_in (int) - batch size of the input tensor.
- c_in (int) - number of channels of the input tensor.
- h_in (int) - feature height of the input tensor.
- w_in (int) - feature width of the input tensor.
- c_out (int) - number of channels of the output tensor.
- groups (int) - number of groups of the Conv2D layer.
- kernel (int) - kernel size.
- padding (int) - padding size.
- stride (int) - stride size.
- dilation (int) - dilation size.
- axis (int) - index of the dimension along which softmax is computed, in [-1, rank-1], where rank is the rank of the input variable.
- ceil_mode (int) - whether to use ceil to compute output height and width: 0 uses floor, 1 uses ceil.
- latency (float) - latency of the op.
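A lookup-table line in the format above is parsed by splitting on whitespace and keeping everything but the trailing latency as the key, which mirrors how `LightNASSpace.init_latency_lookup_table` reads the file; the conv line below is a hypothetical example following the documented field order.

```python
def parse_lookup_table_line(line):
    """Split one lookup-table line into (op_key_tuple, latency)."""
    fields = line.split()
    return tuple(fields[:-1]), float(fields[-1])

# Hypothetical conv line in the documented order:
# conv flag_bias flag_relu n_in c_in h_in w_in c_out groups kernel padding stride dilation latency
key, latency = parse_lookup_table_line('conv 1 1 1 32 112 112 64 1 3 1 1 1 5.25')
```

Keeping the key as a tuple of strings sidesteps type mismatches when the same table is later queried with op parameters converted via `str()`.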
2. Generating the estimator for different hardware platforms

Android:

- Download the latency estimator toolkit for Android from [here](https://paddle-slim-models.bj.bcebos.com/Android_demo.zip).
- Install ADB. On macOS, for example, it can be installed with brew in one step: `brew cask install android-platform-tools`.
- Connect the device. Run `adb devices` to check that it is connected correctly.
- Enter the tool directory Android_demo and run `sh push2android.sh` to push the required files to the device.
- In the `models/PaddleSlim/light_nas/` directory, run `python get_latency_lookup_table.py` to obtain the latency lookup table `latency_lookup_table.txt` for the current search space.
- Alternatively, write all ops returned by the `get_all_ops` function in `models/PaddleSlim/light_nas/light_nas_space.py` into a file, e.g. `lightnas_ops.txt`, then call the `get_latency_lookup_table.py` script in the `Android_demo` toolkit directory to produce the table.
Note 1: based on the [Paddle Mobile](https://github.com/PaddlePaddle/paddle-mobile) inference library, we wrote and compiled binaries that measure per-op latency and whole-model latency. The per-op binaries are named `get_{op}_latency`, with `{op}` replaced by the op name. Each binary prints a float giving the average latency. They are invoked as follows:
- `./get_activation_latency "threads test_iter active_type n_in c_in h_in w_in"`
- `./get_batch_norm_latency "threads test_iter active_type n_in c_in h_in w_in"`
- `./get_conv_latency "threads test_iter flag_bias flag_relu n_in c_in h_in w_in c_out group kernel padding stride dilation"`
- `./get_eltwise_latency "threads test_iter eltwise_type n_in c_in h_in w_in"`
- `./get_pooling_latency "threads test_iter flag_global_pooling n_in c_in h_in w_in kernel padding stride ceil_mode pool_type"`
- `./get_softmax_latency "threads test_iter axis n_in c_in h_in w_in"`
Each binary takes a single quoted string argument. Apart from the leading `threads` and `test_iter`, the fields are the same op parameters as in the latency lookup table, where:

- threads (int) - number of threads (at most the number the phone supports).
- test_iter (int) - number of times the op is run during measurement.
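Composing that argument string from an op record plus `threads` and `test_iter` looks like this; the adb command template is the one used by `get_latency_lookup_table.py` in this PR, while the helper function itself is illustrative.

```python
def op_latency_command(op, threads=1, test_iter=100):
    """Build the on-device command for one op record.

    op: list like ['activation', 'relu6', 1, 32, 112, 112],
    i.e. op name first, then its lookup-table parameters.
    """
    args = [str(threads), str(test_iter)] + [str(x) for x in op[1:]]
    return ('adb shell "cd /data/local/tmp/bin && LD_LIBRARY_PATH=. '
            "./get_{}_latency '{}'\"".format(op[0], ' '.join(args)))

cmd = op_latency_command(['activation', 'relu6', 1, 32, 112, 112])
```

The single-quoted argument keeps the whole parameter list as one shell word on the device side, which is why the fields are joined with spaces.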
We also provide a binary, `get_net_latency`, that measures the latency of a whole model. It is invoked as:

- `./get_net_latency model_path threads test_iter`

Here `model_path` is the path of a saved PaddlePaddle model; the parameters must be saved as separate files with [paddle.fluid.io.save_inference_model](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/api_cn/io_cn.html#save-inference-model). For how to use these binaries standalone, see [here](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/mobile/doc/development_android.md) or the similar logic in `get_latency_lookup_table.py`.

Note 2: to support additional ops, you can develop against Paddle Mobile's [op unit tests](https://github.com/PaddlePaddle/Paddle-Lite/tree/develop/mobile/test/operators); the Android build process is described [here](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/mobile/doc/development_android.md). Contributions are welcome.

Our sample op unit-test code can be downloaded [here](https://paddle-slim-models.bj.bcebos.com/android_op_test.zip). Run `git clone https://github.com/PaddlePaddle/Paddle-Lite.git`, put the unit tests from android_op_test into the `Paddle-Lite/mobile/test/operators` directory, then edit `Paddle-Lite/mobile/test/CMakeList.txt` and build to produce the required binaries.
iOS:

- Download the iOS latency estimator toolkit OpLatency from [here](https://paddle-slim-models.bj.bcebos.com/OpLatency.zip).
- Unlike on Android, before using the toolkit you must write all op parameters of the search space, as returned by the `get_all_ops` function in `models/PaddleSlim/light_nas/light_nas_space.py`, into a .txt file. The file looks like a latency lookup table with the latency column missing; in LightNAS we name it `lightnas_ops.txt`.
- Install Xcode and connect an iOS device; simulators are currently not supported. Select the OpLatency project and update the developer information under General-->Signing.
- Drag the prepared `lightnas_ops.txt` into the project, checking `Add to targets` when prompted.
- In ViewController we call the `get_latency_lookup_table` method of the OCWrapper class; set its input and output arguments to `lightnas_ops.txt` and `latency_lookup_table.txt`.
- Run OpLatency. While building the phone app, the program writes the latency lookup table `latency_lookup_table.txt` for the current search space into the app sandbox.
- Click Windows-->Devices and Simulators-->OpLatency-->Download Container to download the sandbox to the PC, right-click to show the package contents, and find the table under AppData-->Documents.

Note 1: we also provide a way to measure whole-model latency: call the `get_net_latency` method of the OCWrapper class in ViewController. Its arguments are the model and params paths; save all parameters into a single file with [paddle.fluid.io.save_inference_model](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/api_cn/io_cn.html#save-inference-model).

Note 2: to support additional ops, develop from the sample code provided [here](https://paddle-slim-models.bj.bcebos.com/ios_op_test.zip). Usage: unzip and run `sh run.sh` to produce the static library `libpaddle-mobile.a` and header `ios_op_test.h` needed by OpLatency.
......@@ -10,6 +10,7 @@ strategies:
class: 'LightNASStrategy'
controller: 'sa_controller'
target_flops: 592948064
target_latency: 0
end_epoch: 500
retrain_epoch: 5
metric_name: 'acc_top1'
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Get latency lookup table."""
from __future__ import print_function
import re
import argparse
import subprocess
from light_nas_space import get_all_ops
def get_args():
"""Get arguments.
Returns:
Namespace, arguments.
"""
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--latency_lookup_table_path',
default='latency_lookup_table.txt',
help='Output latency lookup table path.')
parser.add_argument(
'--platform', default='android', help='Platform: android/ios/custom.')
parser.add_argument('--threads', type=int, default=1, help='Threads.')
parser.add_argument(
'--test_iter',
type=int,
default=100,
help='Running times of op when estimating latency.')
args = parser.parse_args()
return args
def get_op_latency(op, platform):
"""Get op latency.
Args:
op: list, a list of str represents the op and its parameters.
platform: str, platform name.
Returns:
float, op latency.
"""
if platform == 'android':
commands = 'adb shell "cd /data/local/tmp/bin && LD_LIBRARY_PATH=. ./get_{}_latency \'{}\'"'.format(
op[0], ' '.join(op[1:]))
proc = subprocess.Popen(
commands,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True)
out = proc.communicate()[0]
out = [_ for _ in out.decode().split('\n') if 'Latency' in _][-1]
out = re.findall(r'\d+\.?\d*', out)[0]
out = float(out)
elif platform == 'ios':
print('Please refer the usage doc to get iOS latency lookup table')
out = 0
else:
print('Please define `get_op_latency` for {} platform'.format(platform))
out = 0
return out
def main():
"""main."""
args = get_args()
ops = get_all_ops()
fid = open(args.latency_lookup_table_path, 'w')
for op in ops:
op = [str(item) for item in op]
latency = get_op_latency(
op[:1] + [str(args.threads), str(args.test_iter)] + op[1:],
args.platform)
fid.write('{} {}\n'.format(' '.join(op), latency))
fid.close()
if __name__ == '__main__':
main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Get ops from program."""
def conv_op_params(blocks, current_op):
"""Getting params of conv op
Args:
blocks: BlockDesc, current block
current_op: OpDesc, current op
Returns:
(list): op name and hyperparameters
"""
tmp, res = [], []
# op_name
tmp.append('conv')
# flag_bias
if not current_op.input('Bias'):
tmp.append(0)
else:
tmp.append(1)
# flag_relu
tmp.append(int(current_op.attr('fuse_relu')))
# batch size
tmp.append(1)
# channels, height, width
in_shapes = blocks.vars[current_op.input('Input')[0]].shape
tmp = tmp + [int(in_shapes[1]), int(in_shapes[2]), int(in_shapes[3])]
# output channels
w_shapes = blocks.vars[current_op.input('Filter')[0]].shape
tmp.append(int(w_shapes[0]))
# group
tmp.append(int(current_op.attr('groups')))
# kernel size
tmp.append(int(w_shapes[2]))
if w_shapes[2] != w_shapes[3]:
res.append(int(w_shapes[3]))
# padding
paddings = current_op.attr('paddings')
tmp.append(int(paddings[0]))
if paddings[0] != paddings[1]:
res.append(int(paddings[1]))
# strides
strides = current_op.attr('strides')
tmp.append(int(strides[0]))
if strides[0] != strides[1]:
res.append(int(strides[1]))
# dilations
dilations = current_op.attr('dilations')
tmp.append(int(dilations[0]))
if dilations[0] != dilations[1]:
res.append(int(dilations[1]))
tmp = tmp + res
return tmp
def batch_norm_op_params(blocks, current_op):
"""Getting params of batch_norm op
Args:
blocks: BlockDesc, current block
current_op: OpDesc, current op
Returns:
(list): op name and hyperparameters
"""
tmp = []
# op name
tmp.append('batch_norm')
# activation type
if not current_op.attr('fuse_with_relu'):
tmp.append('None')
else:
tmp.append('relu')
# batch size
tmp.append(1)
# input channels, height, width
in_shapes = blocks.vars[current_op.input("X")[0]].shape
tmp = tmp + [int(in_shapes[1]), int(in_shapes[2]), int(in_shapes[3])]
return tmp
def eltwise_op_params(blocks, current_op):
"""Getting params of eltwise op
Args:
blocks: BlockDesc, current block
current_op: OpDesc, current op
Returns:
(list): op name and hyperparameters
"""
# op name
tmp = ['eltwise']
# elementwise type, TODO: add more ops
if current_op.type == 'elementwise_mul':
tmp.append(1)
elif current_op.type == 'elementwise_add':
tmp.append(2)
else:
tmp.append(3)
# batch size
tmp.append(1)
# input channels, height, width
in_shapes = blocks.vars[current_op.input('X')[0]].shape
while len(in_shapes) < 4:
in_shapes = in_shapes + (1, )
for i in range(1, len(in_shapes)):
tmp.append(int(in_shapes[i]))
return tmp
def activation_op_params(blocks, current_op):
"""Getting params of activation op
Args:
blocks: BlockDesc, current block
current_op: OpDesc, current op
Returns:
(list): op name and hyperparameters
"""
tmp = []
# op name
tmp.append('activation')
# activation type
tmp.append(current_op.type)
# batch size
tmp.append(1)
# input channels, height, width
in_shapes = blocks.vars[current_op.input('X')[0]].shape
while len(in_shapes) < 4:
in_shapes = in_shapes + (1, )
for i in range(1, len(in_shapes)):
tmp.append(int(in_shapes[i]))
return tmp
def pooling_op_params(blocks, current_op):
"""Getting params of pooling op
Args:
blocks: BlockDesc, current block
current_op: OpDesc, current op
Returns:
(list): op name and hyperparameters
"""
tmp, res = [], []
# op name
tmp.append('pooling')
# global pooling
tmp.append(int(current_op.attr('global_pooling')))
# batch size
tmp.append(1)
# channels, height, width
in_shapes = blocks.vars[current_op.input('X')[0]].shape
tmp = tmp + [int(in_shapes[1]), int(in_shapes[2]), int(in_shapes[3])]
# kernel size
ksize = current_op.attr('ksize')
tmp.append(int(ksize[0]))
if ksize[0] != ksize[1]:
res.append(int(ksize[1]))
# padding
paddings = current_op.attr('paddings')
tmp.append(int(paddings[0]))
if paddings[0] != paddings[1]:
res.append(int(paddings[1]))
# stride
strides = current_op.attr('strides')
tmp.append(int(strides[0]))
if strides[0] != strides[1]:
res.append(int(strides[1]))
# ceil mode
tmp.append(int(current_op.attr('ceil_mode')))
# pool type
pool_type = current_op.attr('pooling_type')
exclusive = current_op.attr('exclusive')
if pool_type == 'max' and (not exclusive):
tmp.append(1)
elif pool_type == 'avg' and (not exclusive):
tmp.append(2)
else:
tmp.append(3)
tmp = tmp + res
return tmp
def softmax_op_params(blocks, current_op):
"""Getting params of softmax op
Args:
blocks: BlockDesc, current block
current_op: OpDesc, current op
Returns:
(list): op name and hyperparameters
"""
# op name
tmp = ['softmax']
# axis
tmp.append(current_op.attr('axis'))
# batch size
tmp.append(1)
# input channels, height, width
in_shapes = blocks.vars[current_op.input('X')[0]].shape
while len(in_shapes) < 4:
in_shapes = in_shapes + (1, )
for i in range(1, len(in_shapes)):
tmp.append(int(in_shapes[i]))
return tmp
def fc_op_params(blocks, current_op):
"""Getting params of fc op
Note:
fc op is converted to conv op with 1x1 kernels
Args:
blocks: BlockDesc, current block
current_op: OpDesc, current op
Returns:
(list): op name and hyperparameters
"""
# op name
tmp = ['conv']
# flag bias
tmp.append(0)
# flag relu
tmp.append(0)
# batch size
tmp.append(1)
# input channels, height, width
channels = 1
in_shape = blocks.vars[current_op.input('X')[0]].shape
for i in range(1, len(in_shape)):
channels *= in_shape[i]
tmp = tmp + [int(channels), 1, 1]
# output channels
tmp.append(int(blocks.vars[current_op.output('Out')[0]].shape[1]))
# groups, kernel size, padding, stride, dilation
tmp = tmp + [1, 1, 0, 1, 1]
return tmp
def get_ops_from_program(program):
"""Getting ops params from a paddle program
Args:
program(Program): The program to get ops.
Returns:
(list): ops.
"""
blocks = program.global_block()
ops = []
i = 0
while i < len(blocks.ops):
current_op = blocks.ops[i]
if current_op.type in ['conv2d', 'depthwise_conv2d']:
tmp = conv_op_params(blocks, current_op)
elif current_op.type in [
'elementwise_add', 'elementwise_mul', 'elementwise_max'
]:
tmp = eltwise_op_params(blocks, current_op)
elif current_op.type in [
'relu', 'prelu', 'sigmoid', 'relu6', 'elu', 'brelu',
'leaky_relu'
]:
tmp = activation_op_params(blocks, current_op)
elif current_op.type == 'batch_norm':
tmp = batch_norm_op_params(blocks, current_op)
elif current_op.type == 'pool2d':
tmp = pooling_op_params(blocks, current_op)
elif current_op.type == 'softmax':
tmp = softmax_op_params(blocks, current_op)
elif current_op.type == 'mul':
tmp = fc_op_params(blocks, current_op)
else:
tmp = None
if tmp:
ops.append(tmp)
i += 1
return ops
......@@ -11,21 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Light-NAS space."""
import sys
import math
from paddle.fluid.contrib.slim.nas import SearchSpace
import paddle.fluid as fluid
import paddle
sys.path.append('..')
from models import LightNASNet
import reader
from get_ops_from_program import get_ops_from_program
total_images = 1281167
lr = 0.016
num_epochs = 350
batch_size = 512
lr_strategy = "exponential_decay_with_RMSProp"
l2_decay = 1e-5
momentum_rate = 0.9
image_shape = [3, 224, 224]
class_dim = 1000
......@@ -39,6 +41,7 @@ NAS_KERNEL_SIZE = [3, 5]
NAS_FILTERS_MULTIPLIER = [3, 4, 5, 6]
NAS_SHORTCUT = [0, 1]
NAS_SE = [0, 1]
LATENCY_LOOKUP_TABLE_PATH = 'latency_lookup_table.txt'
def get_bottleneck_params_list(var):
......@@ -67,9 +70,149 @@ def get_bottleneck_params_list(var):
return params_list
def ops_of_inverted_residual_unit(in_c,
in_shape,
expansion,
kernels,
num_filters,
s,
ifshortcut=True,
ifse=True):
"""Get ops of possible repeated inverted residual unit
Args:
in_c: list, a list of numbers of input channels
in_shape: int, size of input feature map
expansion: int, expansion factor
kernels: list, a list of possible kernel size
num_filters: list, a list of possible numbers of output filters
s: int, stride of depthwise conv
ifshortcut: bool
ifse: bool
Returns:
op_params: list, a list of op params
"""
op_params = []
for c in in_c:
for t in expansion:
# expansion
op_params.append(('conv', 0, 0, 1, c, in_shape, in_shape, c * t, 1,
1, 0, 1, 1))
op_params.append(('batch_norm', 'None', 1, c * t, in_shape,
in_shape))
op_params.append(('activation', 'relu6', 1, c * t, in_shape,
in_shape))
# depthwise
for k in kernels:
op_params.append(('conv', 0, 0, 1, c * t, in_shape, in_shape,
c * t, c * t, k, int(int(k - 1) / 2), s, 1))
op_params.append(('batch_norm', 'None', 1, c * t, int(in_shape / s),
int(in_shape / s)))
op_params.append(('activation', 'relu6', 1, c * t,
int(in_shape / s), int(in_shape / s)))
# shrink
for out_c in num_filters:
op_params.append(('conv', 0, 0, 1, c * t, int(in_shape / s),
int(in_shape / s), out_c, 1, 1, 0, 1, 1))
op_params.append(('batch_norm', 'None', 1, out_c,
int(in_shape / s), int(in_shape / s)))
# shortcut
if ifshortcut:
op_params.append(('eltwise', 2, 1, out_c, int(in_shape / s),
int(in_shape / s)))
if ifse:
op_params.append(('pooling', 1, 1, out_c, int(in_shape / s),
int(in_shape / s), 0, 0, 1, 0, 3))
op_params.append(('conv', 0, 0, 1, out_c, 1, 1,
int(out_c / 4), 1, 1, 0, 1, 1))
op_params.append(('eltwise', 2, 1, int(out_c / 4), 1, 1))
op_params.append(
('activation', 'relu', 1, int(out_c / 4), 1, 1))
op_params.append(('conv', 0, 0, 1, int(out_c / 4), 1, 1,
out_c, 1, 1, 0, 1, 1))
op_params.append(('eltwise', 2, 1, out_c, 1, 1))
op_params.append(('activation', 'sigmoid', 1, out_c, 1, 1))
op_params.append(('eltwise', 1, 1, out_c, int(in_shape / s),
int(in_shape / s)))
op_params.append(('activation', 'relu', 1, out_c,
int(in_shape / s), int(in_shape / s)))
return op_params
def get_all_ops(ifshortcut=True, ifse=True, strides=[1, 2, 2, 2, 1, 2, 1]):
"""Get all possible ops of current search space
Args:
ifshortcut: bool, shortcut or not
ifse: bool, se or not
strides: list, list of strides for bottlenecks
Returns:
op_params: list, a list of all possible params
"""
op_params = []
# conv1_1
op_params.append(('conv', 0, 0, 1, image_shape[0], image_shape[1],
image_shape[2], 32, 1, 3, 1, 2, 1))
op_params.append(('batch_norm', 'None', 1, 32, int(image_shape[1] / 2),
int(image_shape[2] / 2)))
op_params.append(('activation', 'relu6', 1, 32, int(image_shape[1] / 2),
int(image_shape[2] / 2)))
# bottlenecks, TODO: different h and w for images
in_c, in_shape = [32], int(image_shape[1] / 2)
for i in range(len(NAS_FILTER_SIZE) + 2):
if i == 0:
expansion, kernels, num_filters, s = [1], [3], [16], strides[i]
elif i == len(NAS_FILTER_SIZE) + 1:
expansion, kernels, num_filters, s = [6], [3], [320], strides[i]
else:
expansion, kernels, num_filters, s = NAS_FILTERS_MULTIPLIER, \
NAS_KERNEL_SIZE, \
NAS_FILTER_SIZE[i-1], \
strides[i]
# first block
tmp_ops = ops_of_inverted_residual_unit(
in_c, in_shape, expansion, kernels, num_filters, s, False, ifse)
op_params = op_params + tmp_ops
in_c, in_shape = num_filters, int(in_shape / s)
# repeated block: possibly more ops, but it is ok
tmp_ops = ops_of_inverted_residual_unit(in_c, in_shape, expansion,
kernels, num_filters, 1,
ifshortcut, ifse)
op_params = op_params + tmp_ops
# last conv
op_params.append(('conv', 0, 0, 1, 320, in_shape, in_shape, 1280, 1, 1, 0,
1, 1))
op_params.append(('batch_norm', 'None', 1, 1280, in_shape, in_shape))
op_params.append(('activation', 'relu6', 1, 1280, in_shape, in_shape))
op_params.append(('pooling', 1, 1, 1280, in_shape, in_shape, in_shape, 0, 1,
0, 3))
# fc, converted to 1x1 conv
op_params.append(('conv', 0, 0, 1, 1280, 1, 1, class_dim, 1, 1, 0, 1, 1))
op_params.append(('eltwise', 2, 1, class_dim, 1, 1))
return list(set(op_params))
class LightNASSpace(SearchSpace):
def __init__(self):
super(LightNASSpace, self).__init__()
if LATENCY_LOOKUP_TABLE_PATH:
self.init_latency_lookup_table(LATENCY_LOOKUP_TABLE_PATH)
def init_latency_lookup_table(self, latency_lookup_table_path):
"""Init lookup table.
Args:
latency_lookup_table_path: str, latency lookup table path.
"""
self._latency_lookup_table = dict()
for line in open(latency_lookup_table_path):
line = line.split()
self._latency_lookup_table[tuple(line[:-1])] = float(line[-1])
def init_tokens(self):
"""Get init tokens in search space.
......@@ -88,6 +231,18 @@ class LightNASSpace(SearchSpace):
2, 4, 3, 3, 2, 2, 2
]
def get_model_latency(self, program):
"""Get model latency according to program.
Args:
program(Program): The program to get latency.
Return:
(float): model latency.
"""
ops = get_ops_from_program(program)
latency = sum(
[self._latency_lookup_table[tuple(map(str, op))] for op in ops])
return latency
def create_net(self, tokens=None):
"""Create a network for training by tokens.
"""
......@@ -249,6 +404,43 @@ def optimizer_setting(params):
learning_rate=lr, step_each_epoch=step, epochs=num_epochs),
momentum=momentum_rate,
regularization=fluid.regularizer.L2Decay(l2_decay))
elif ls["name"] == "exponential_decay":
if "total_images" not in params:
total_images = IMAGENET1000
else:
total_images = params["total_images"]
batch_size = ls["batch_size"]
num_epochs = params["num_epochs"]
start_lr = params["lr"]
total_step = int((total_images / batch_size) * num_epochs)
decay_step = int((total_images / batch_size) * 2.4)
lr = start_lr
lr = fluid.layers.exponential_decay(
learning_rate=start_lr,
decay_steps=decay_step,
decay_rate=0.97,
staircase=True)
optimizer = fluid.optimizer.SGDOptimizer(learning_rate=lr)
elif ls["name"] == "exponential_decay_with_RMSProp":
if "total_images" not in params:
total_images = IMAGENET1000
else:
total_images = params["total_images"]
batch_size = ls["batch_size"]
step = int(math.ceil(float(total_images) / batch_size))
decay_step = int(2.4 * step)
lr = params["lr"]
num_epochs = params["num_epochs"]
optimizer = fluid.optimizer.RMSProp(
learning_rate=fluid.layers.exponential_decay(
learning_rate=lr,
decay_steps=decay_step,
decay_rate=0.97,
staircase=False),
regularization=fluid.regularizer.L2Decay(l2_decay),
momentum=0.9,
rho=0.9,
epsilon=0.001)
elif ls["name"] == "linear_decay":
if "total_images" not in params:
total_images = IMAGENET1000
......
......@@ -12,7 +12,6 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.core import Compressor
from light_nas_space import LightNASSpace
......