未验证 提交 df7c9f2e 编写于 作者: C Chang Xu 提交者: GitHub

fix_docs_nas (#822)

Co-authored-by: Nceci3 <ceci3@users.noreply.github.com>
上级 cdcfa69a
......@@ -33,7 +33,7 @@ import numpy as np
Please set a unused port when build instance of SANAS.
```python
sanas = slim.nas.SANAS(configs=[('MobileNetV2Space')], server_addr=("", 8337), save_checkpoint=None)
sanas = slim.nas.SANAS(configs=[('MobileNetV2Space')], server_addr=("", 8911), save_checkpoint=None)
```
## 3. define function about building program
......
......@@ -85,8 +85,8 @@ PaddleSlim提供了三种方式构造超网络,下面分别介绍这三种方
models += [nn.Conv2D(4, 4, 3, groups=4)]
self.models = paddle.nn.Sequential(*models)
def forward(self, inputs):
return self.models(inputs)
def forward(self, inputs):
return self.models(inputs)
方式三
------------------
......
......@@ -88,7 +88,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
paddle.enable_static()
sanas = SANAS(configs=config, , server_addr=("",8822))
sanas = SANAS(configs=config, server_addr=("",8822))
input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
archs = sanas.next_archs()
for arch in archs:
......@@ -115,7 +115,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
paddle.enable_static()
sanas = SANAS(configs=config, server_addr=("", 8883))
sanas = SANAS(configs=config, server_addr=("", 8823))
archs = sanas.next_archs()
### 假设网络计算出来的score是1,实际代码中使用时需要返回真实score。
......@@ -142,7 +142,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
paddle.enable_static()
sanas = SANAS(configs=config, server_addr=("",8823))
sanas = SANAS(configs=config, server_addr=("", 8824))
input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
tokens = ([0] * 25)
archs = sanas.tokens2arch(tokens)[0]
......@@ -163,7 +163,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
paddle.enable_static()
sanas = SANAS(configs=config, server_addr=("", 8885))
sanas = SANAS(configs=config, server_addr=("", 8825))
print(sanas.current_info())
......@@ -233,7 +233,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
config = [('MobileNetV2Space')]
paddle.enable_static()
rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8824))
rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8826))
.. py:method:: next_archs(obs=None)
......@@ -255,7 +255,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
from paddleslim.nas import RLNAS
config = [('MobileNetV2Space')]
paddle.enable_static()
rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8825))
rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8827))
input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
archs = rlnas.next_archs(1)[0]
for arch in archs:
......@@ -280,7 +280,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
from paddleslim.nas import RLNAS
config = [('MobileNetV2Space')]
paddle.enable_static()
rlnas = RLNAS(key='lstm', configs=config, server_addr=("", 8888))
rlnas = RLNAS(key='lstm', configs=config, server_addr=("", 8828))
rlnas.next_archs(1)
rlnas.reward(1.0)
......@@ -307,7 +307,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
from paddleslim.nas import RLNAS
config = [('MobileNetV2Space')]
paddle.enable_static()
rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8826))
rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8829))
archs = rlnas.final_archs(1)
print(archs)
......@@ -330,7 +330,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
from paddleslim.nas import RLNAS
config = [('MobileNetV2Space')]
paddle.enable_static()
rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8827))
rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8830))
input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
tokens = ([0] * 25)
archs = rlnas.tokens2arch(tokens)[0]
......
......@@ -160,8 +160,8 @@ sanas.reward(float(finally_reward[1]))
```python
for step in range(3):
archs = sanas.next_archs()[0]
exe, train_program, eval_program, inputs, avg_cost, acc_top1, acc_top5 = build_program(archs)
train_loader, eval_loader = input_data(inputs)
exe, train_program, eval_program, (images,label), avg_cost, acc_top1, acc_top5 = build_program(archs)
train_loader, eval_loader = input_data(images, label)
current_flops = slim.analysis.flops(train_program)
if current_flops > 321208544:
......
......@@ -16,13 +16,13 @@ OFA的基本流程分为以下步骤:
PaddleSlim提供了三种获得超网络的方式,具体可以参考[超网络转换](https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/dygraph/ofa/convert_supernet_api.html)
```python
import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim.nas.ofa.convert_super import Convert, supernet
import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim.nas.ofa.convert_super import Convert, supernet
model = mobilenet_v1()
sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
sp_model = Convert(sp_net_config).convert(model)
model = mobilenet_v1()
sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
sp_model = Convert(sp_net_config).convert(model)
```
### 2. 训练配置
......
......@@ -262,14 +262,14 @@ sa_nas.reward(float(valid_top1_list[-1] + valid_top1_list[-2]) / 2)
### 10. 利用demo下的脚本启动搜索
搜索文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/demo/nas/sanas_darts_space.py),搜索过程中限制模型参数量为不大于3.77M。
```python
```shell
cd demo/nas/
python darts_nas.py
```
### 11. 利用demo下的脚本启动最终实验
最终实验文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/demo/nas/sanas_darts_space.py),最终实验需要训练600epoch。以下示例输入token为`[5, 5, 0, 5, 5, 10, 7, 7, 5, 7, 7, 11, 10, 12, 10, 0, 5, 3, 10, 8]`
```python
```shell
cd demo/nas/
python darts_nas.py --token 5 5 0 5 5 10 7 7 5 7 7 11 10 12 10 0 5 3 10 8 --retain_epoch 600
```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册