nas_api.md 3.2 KB
Newer Older
C
ceci3 已提交
1 2 3 4
# paddleslim.nas API文档

## SANAS API文档

C
ceci3 已提交
5 6 7 8 9 10
## class SANAS

---

>paddleslim.nas.SANAS(configs, server_addr, init_temperature, reduce_rate, search_steps, save_checkpoint, load_checkpoint, is_server)
SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火算法进行模型结构搜索的算法,一般用于离散搜索任务。
C
ceci3 已提交
11 12

**参数:**
C
ceci3 已提交
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
- **configs(list<tuple>):** 搜索空间配置列表,格式是`[(key, {input_size, output_size, block_num, block_mask})]`或者`[(key)]`(MobileNetV2、MobilenetV1和ResNet的搜索空间使用和原本网络结构相同的搜索空间,所以仅需指定`key`即可), `input_size``output_size`表示输入和输出的特征图的大小,`block_num`是指搜索网络中的block数量,`block_mask`是一组由0和1组成的列表,0代表不进行下采样的block,1代表下采样的block。 更多paddleslim提供的搜索空间配置可以参考。
- **server_addr(tuple):** SANAS的地址,包括server的ip地址和端口号,如果ip地址为None或者为""的话则默认使用本机ip。默认:("", 8881)。
- **init_temperature(float):** 基于模拟退火进行搜索的初始温度。默认:100。
- **reduce_rate(float):** 基于模拟退火进行搜索的衰减率。默认:0.85。
- **search_steps(int):** 搜索过程迭代的次数。默认:300。
- **save_checkpoint(str|None):** 保存checkpoint的文件目录,如果设置为None的话则不保存checkpoint。默认:`./nas_checkpoint`
- **load_checkpoint(str|None):** 加载checkpoint的文件目录,如果设置为None的话则不加载checkpoint。默认:None。
- **is_server(bool):** 当前实例是否要启动一个server。默认:True。

**返回:** 
一个SANAS类的实例

**示例代码:**
```
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
sanas = SANAS(config=config)
```

---

>tokens2arch(tokens)
C
ceci3 已提交
35 36 37
通过一组token得到实际的模型结构,一般用来把搜索到最优的token转换为模型结构用来做最后的训练。

**参数:**
C
ceci3 已提交
38
- **tokens(list):** 一组token。
C
ceci3 已提交
39 40

**返回**
C
ceci3 已提交
41
返回一个模型结构实例。
C
ceci3 已提交
42

C
ceci3 已提交
43
---
C
ceci3 已提交
44

C
ceci3 已提交
45
>next_archs():
C
ceci3 已提交
46 47 48
获取下一组模型结构。

**返回**
C
ceci3 已提交
49
返回模型结构实例的列表,形式为list。
C
ceci3 已提交
50

C
ceci3 已提交
51 52 53 54 55 56 57 58 59 60 61 62 63 64
**示例代码:**
```
import paddle.fluid as fluid
input = fluid.data(name='input', shape=[None, 1, 32, 32], dtype='float32')
archs = sanas.next_archs()
for arch in archs:
    output = arch(input)
    input = output
```

---

>reward(score):
把当前模型结构的得分情况回传。
C
ceci3 已提交
65 66

**参数:**
C
ceci3 已提交
67
score<float>:** 当前模型的得分,分数越大越好。
C
ceci3 已提交
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82

**返回**
模型结构更新成功或者失败,成功则返回`True`,失败则返回`False`


**代码示例**
```python
import paddleslim.nas.SANAS as SANAS

# 搜索空间配置
config=[('MobileNetV2Space')] 

# 实例化SANAS
sa_nas = SANAS(config, server_addr=("", 8887), init_temperature=10.24, reduce_rate=0.85, search_steps=100, is_server=True)

C
ceci3 已提交
83
# 构造输入数据
C
ceci3 已提交
84 85 86 87 88 89 90 91 92 93 94
input = fluid.data(name='input', shape=[None, 1, 32, 32], dtype='float32')
label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
for step in range(100):
    archs = sa_nas.next_archs()
    for arch in archs:
        input = arch(input)

    score = fluid.layer.
    sa_nas.reward(score)

```