网络结构搜索示例#
本示例介绍如何使用网络结构搜索接口,搜索到一个更小或者精度更高的模型,该文档仅介绍paddleslim中SANAS的使用及如何利用SANAS得到模型结构,完整示例代码请参考sa_nas_mobilenetv2.py或者block_sa_nas_mobilenetv2.py。
接口介绍#
请参考。
1. 配置搜索空间#
详细的搜索空间配置可以参考神经网络搜索API文档。
config = [('MobileNetV2Space')]
2. 利用搜索空间初始化SANAS实例#
from paddleslim.nas import SANAS sa_nas = SANAS( config, server_addr=("", 8881), init_temperature=10.24, reduce_rate=0.85, search_steps=300, is_server=True)
3. 根据实例化的NAS得到当前的网络结构#
archs = sa_nas.next_archs()
4. 根据得到的网络结构和输入构造训练和测试program#
import paddle.fluid as fluid train_program = fluid.Program() test_program = fluid.Program() startup_program = fluid.Program() with fluid.program_guard(train_program, startup_program): data = fluid.data(name='data', shape=[None, 3, 32, 32], dtype='float32') label = fluid.data(name='label', shape=[None, 1], dtype='int64') for arch in archs: data = arch(data) output = fluid.layers.fc(data, 10) softmax_out = fluid.layers.softmax(input=output, use_cudnn=False) cost = fluid.layers.cross_entropy(input=softmax_out, label=label) avg_cost = fluid.layers.mean(cost) acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1) test_program = train_program.clone(for_test=True) sgd = fluid.optimizer.SGD(learning_rate=1e-3) sgd.minimize(avg_cost)
5. 根据构造的训练program添加限制条件#
from paddleslim.analysis import flops if flops(train_program) > 321208544: continue
6. 回传score#
sa_nas.reward(score)