网络结构搜索示例#

本示例介绍如何使用网络结构搜索接口,搜索到一个更小或者精度更高的模型,该文档仅介绍paddleslim中SANAS的使用及如何利用SANAS得到模型结构,完整示例代码请参考sa_nas_mobilenetv2.py或者block_sa_nas_mobilenetv2.py。

接口介绍#

请参考神经网络搜索API介绍

1. 配置搜索空间#

详细的搜索空间配置可以参考搜索空间介绍

config = [('MobileNetV2Space')]

2. 利用搜索空间初始化SANAS实例#

from paddleslim.nas import SANAS

sa_nas = SANAS(
    config,
    server_addr=("", 8881),
    init_temperature=10.24,
    reduce_rate=0.85,
    search_steps=300,
    is_server=True)

3. 根据实例化的NAS得到当前的网络结构#

archs = sa_nas.next_archs()

4. 根据得到的网络结构和输入构造训练和测试program#

import paddle.fluid as fluid

train_program = fluid.Program()
test_program = fluid.Program()
startup_program = fluid.Program()

with fluid.program_guard(train_program, startup_program):
    data = fluid.data(name='data', shape=[None, 3, 32, 32], dtype='float32')
    label = fluid.data(name='label', shape=[None, 1], dtype='int64')
    for arch in archs:
        data = arch(data)
    output = fluid.layers.fc(data, 10)
    softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
    cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
    avg_cost = fluid.layers.mean(cost)
    acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)

    test_program = train_program.clone(for_test=True)
    sgd = fluid.optimizer.SGD(learning_rate=1e-3)
    sgd.minimize(avg_cost)

5. 根据构造的训练program添加限制条件#

from paddleslim.analysis import flops

if flops(train_program) > 321208544:
    continue

6. 回传score#

sa_nas.reward(score)