nas_tutorial.md 6.2 KB
Newer Older
W
whs 已提交
1
# 网络结构搜索
C
ceci3 已提交
2

W
whs 已提交
3
该教程以图像分类模型MobileNetV2为例,说明如何在cifar10数据集上快速使用[网络结构搜索接口](https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/static/nas/nas_api.html)
C
ceci3 已提交
4 5 6 7 8 9 10 11
该示例包含以下步骤:

1. 导入依赖
2. 初始化SANAS搜索实例
3. 构建网络
4. 定义输入数据函数
5. 定义训练函数
6. 定义评估函数
C
ceci3 已提交
12 13 14 15 16 17
7. 启动搜索实验  
  7.1 获取模型结构  
  7.2 构造program  
  7.3 定义输入数据  
  7.4 训练模型  
  7.5 评估模型  
C
ceci3 已提交
18 19 20 21 22 23 24 25 26 27
  7.6 回传当前模型的得分
8. 完整示例


以下章节依次介绍每个步骤的内容。

## 1. 导入依赖
请确认已正确安装Paddle,导入需要的依赖包。
```python
import paddle
C
ceci3 已提交
28 29 30
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.static as static
C
ceci3 已提交
31 32 33 34 35 36
import paddleslim as slim
import numpy as np
```

## 2. 初始化SANAS搜索实例
```python
C
Chang Xu 已提交
37 38
port = np.random.randint(8337, 8773)
sanas = slim.nas.SANAS(configs=[('MobileNetV2Space')], server_addr=("", port), save_checkpoint=None)
C
ceci3 已提交
39 40 41 42 43
```

## 3. 构建网络
根据传入的网络结构构造训练program和测试program。
```python
C
ceci3 已提交
44
paddle.enable_static()
C
ceci3 已提交
45
def build_program(archs):
C
ceci3 已提交
46 47 48 49 50 51
    train_program = static.Program()
    startup_program = static.Program()
    with static.program_guard(train_program, startup_program):
        data = static.data(name='data', shape=[None, 3, 32, 32], dtype='float32')
        label = static.data(name='label', shape=[None, 1], dtype='int64')
        gt = paddle.reshape(label, [-1, 1])
C
ceci3 已提交
52
        output = archs(data)
C
ceci3 已提交
53
        output = static.nn.fc(output, size=10)
C
ceci3 已提交
54

C
ceci3 已提交
55
        softmax_out = F.softmax(output)
C
ceci3 已提交
56
        cost = F.cross_entropy(softmax_out, label=gt)
C
ceci3 已提交
57 58 59 60
        avg_cost = paddle.mean(cost)
        acc_top1 = paddle.metric.accuracy(input=softmax_out, label=gt, k=1)
        acc_top5 = paddle.metric.accuracy(input=softmax_out, label=gt, k=5)
        test_program = static.default_main_program().clone(for_test=True)
C
ceci3 已提交
61

C
ceci3 已提交
62
        optimizer = paddle.optimizer.Adam(learning_rate=0.1)
C
ceci3 已提交
63
        optimizer.minimize(avg_cost)
C
ceci3 已提交
64 65
        place = paddle.CPUPlace()
        exe = static.Executor(place)
C
ceci3 已提交
66 67 68 69 70
        exe.run(startup_program)
    return exe, train_program, test_program, (data, label), avg_cost, acc_top1, acc_top5
```

## 4. 定义输入数据函数
C
ceci3 已提交
71 72
为了快速执行该示例,我们使用的数据集为CIFAR10,Paddle框架的`paddle.vision.datasets.Cifar10`包定义了CIFAR10数据的下载和读取。 代码如下:

C
ceci3 已提交
73
```python
C
ceci3 已提交
74 75 76
import paddle.vision.transforms as T

def input_data(image, label):
C
ceci3 已提交
77 78
    transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
    train_dataset = paddle.vision.datasets.Cifar10(mode="train", transform=transform, backend='cv2')
C
ceci3 已提交
79 80 81 82 83 84 85
    train_loader = paddle.io.DataLoader(train_dataset,
                    places=paddle.CPUPlace(),
                    feed_list=[image, label],
                    drop_last=True,
                    batch_size=64,
                    return_list=False,
                    shuffle=True)
C
ceci3 已提交
86
    eval_dataset = paddle.vision.datasets.Cifar10(mode="test", transform=transform, backend='cv2')
C
ceci3 已提交
87 88 89 90 91 92 93 94
    eval_loader = paddle.io.DataLoader(eval_dataset,
                    places=paddle.CPUPlace(),
                    feed_list=[image, label],
                    drop_last=False,
                    batch_size=64,
                    return_list=False,
                    shuffle=False)
    return train_loader, eval_loader
C
ceci3 已提交
95 96 97 98 99
```

## 5. 定义训练函数
根据训练program和训练数据进行训练。
```python
C
ceci3 已提交
100
def start_train(program, data_loader):
C
ceci3 已提交
101
    outputs = [avg_cost.name, acc_top1.name, acc_top5.name]
C
ceci3 已提交
102 103
    for data in data_loader():
        batch_reward = exe.run(program, feed=data, fetch_list = outputs)
C
ceci3 已提交
104 105 106 107 108 109
        print("TRAIN: loss: {}, acc1: {}, acc5:{}".format(batch_reward[0], batch_reward[1], batch_reward[2]))
```

## 6. 定义评估函数
根据评估program和评估数据进行评估。
```python
C
ceci3 已提交
110
def start_eval(program, data_loader):
C
ceci3 已提交
111 112
    reward = []
    outputs = [avg_cost.name, acc_top1.name, acc_top5.name]
C
ceci3 已提交
113 114
    for data in data_loader():
        batch_reward = exe.run(program, feed=data, fetch_list = outputs)
C
ceci3 已提交
115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
        reward_avg = np.mean(np.array(batch_reward), axis=1)
        reward.append(reward_avg)
        print("TEST: loss: {}, acc1: {}, acc5:{}".format(batch_reward[0], batch_reward[1], batch_reward[2]))
    finally_reward = np.mean(np.array(reward), axis=0)
    print("FINAL TEST: avg_cost: {}, acc1: {}, acc5: {}".format(finally_reward[0], finally_reward[1], finally_reward[2]))
    return finally_reward
```

## 7. 启动搜索实验
以下步骤拆解说明了如何获得当前模型结构以及获得当前模型结构之后应该有的步骤,如果想要看如何启动搜索实验的完整示例可以看步骤9。

### 7.1 获取模型结构
调用`next_archs()`函数获取到下一个模型结构。
```python
archs = sanas.next_archs()[0]
```

### 7.2 构造program
调用步骤3中的函数,根据4.1中的模型结构构造相应的program。
```python
C
ceci3 已提交
135
exe, train_program, eval_program, (image, label), avg_cost, acc_top1, acc_top5 = build_program(archs)
C
ceci3 已提交
136 137 138 139
```

### 7.3 定义输入数据
```python
C
ceci3 已提交
140
train_loader, eval_loader = input_data(image, label)
C
ceci3 已提交
141 142 143 144 145
```

### 7.4 训练模型
根据上面得到的训练program和评估数据启动训练。
```python
C
ceci3 已提交
146
start_train(train_program, train_loader)
C
ceci3 已提交
147 148 149 150
```
### 7.5 评估模型
根据上面得到的评估program和评估数据启动评估。
```python
C
ceci3 已提交
151
finally_reward = start_eval(eval_program, eval_loader)
C
ceci3 已提交
152 153 154 155 156 157 158 159 160 161 162
```
### 7.6 回传当前模型的得分
```
sanas.reward(float(finally_reward[1]))
```

## 8. 完整示例
以下是一个完整的搜索实验示例,示例中使用FLOPs作为约束条件,搜索实验一共搜索3个step,表示搜索到3个满足条件的模型结构进行训练,每搜索到一个网络结构训练7个epoch。
```python
for step in range(3):
    archs = sanas.next_archs()[0]
C
Chang Xu 已提交
163 164
    exe, train_program, eval_program, (images,label), avg_cost, acc_top1, acc_top5 = build_program(archs)
    train_loader, eval_loader = input_data(images, label)
C
ceci3 已提交
165 166 167 168

    current_flops = slim.analysis.flops(train_program)
    if current_flops > 321208544:
        continue
C
ceci3 已提交
169

C
ceci3 已提交
170
    for epoch in range(7):
C
ceci3 已提交
171
        start_train(train_program, train_loader)
C
ceci3 已提交
172

C
ceci3 已提交
173
    finally_reward = start_eval(eval_program, eval_loader)
C
ceci3 已提交
174 175 176

    sanas.reward(float(finally_reward[1]))
```