未验证 提交 df7c9f2e 编写于 作者: C Chang Xu 提交者: GitHub

fix_docs_nas (#822)

Co-authored-by: Nceci3 <ceci3@users.noreply.github.com>
上级 cdcfa69a
...@@ -33,7 +33,7 @@ import numpy as np ...@@ -33,7 +33,7 @@ import numpy as np
Please set a unused port when build instance of SANAS. Please set a unused port when build instance of SANAS.
```python ```python
sanas = slim.nas.SANAS(configs=[('MobileNetV2Space')], server_addr=("", 8337), save_checkpoint=None) sanas = slim.nas.SANAS(configs=[('MobileNetV2Space')], server_addr=("", 8911), save_checkpoint=None)
``` ```
## 3. define function about building program ## 3. define function about building program
......
...@@ -88,7 +88,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火 ...@@ -88,7 +88,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
from paddleslim.nas import SANAS from paddleslim.nas import SANAS
config = [('MobileNetV2Space')] config = [('MobileNetV2Space')]
paddle.enable_static() paddle.enable_static()
sanas = SANAS(configs=config, , server_addr=("",8822)) sanas = SANAS(configs=config, server_addr=("",8822))
input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32') input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
archs = sanas.next_archs() archs = sanas.next_archs()
for arch in archs: for arch in archs:
...@@ -115,7 +115,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火 ...@@ -115,7 +115,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
from paddleslim.nas import SANAS from paddleslim.nas import SANAS
config = [('MobileNetV2Space')] config = [('MobileNetV2Space')]
paddle.enable_static() paddle.enable_static()
sanas = SANAS(configs=config, server_addr=("", 8883)) sanas = SANAS(configs=config, server_addr=("", 8823))
archs = sanas.next_archs() archs = sanas.next_archs()
### 假设网络计算出来的score是1,实际代码中使用时需要返回真实score。 ### 假设网络计算出来的score是1,实际代码中使用时需要返回真实score。
...@@ -142,7 +142,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火 ...@@ -142,7 +142,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
from paddleslim.nas import SANAS from paddleslim.nas import SANAS
config = [('MobileNetV2Space')] config = [('MobileNetV2Space')]
paddle.enable_static() paddle.enable_static()
sanas = SANAS(configs=config, server_addr=("",8823)) sanas = SANAS(configs=config, server_addr=("", 8824))
input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32') input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
tokens = ([0] * 25) tokens = ([0] * 25)
archs = sanas.tokens2arch(tokens)[0] archs = sanas.tokens2arch(tokens)[0]
...@@ -163,7 +163,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火 ...@@ -163,7 +163,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
from paddleslim.nas import SANAS from paddleslim.nas import SANAS
config = [('MobileNetV2Space')] config = [('MobileNetV2Space')]
paddle.enable_static() paddle.enable_static()
sanas = SANAS(configs=config, server_addr=("", 8885)) sanas = SANAS(configs=config, server_addr=("", 8825))
print(sanas.current_info()) print(sanas.current_info())
...@@ -233,7 +233,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习 ...@@ -233,7 +233,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
config = [('MobileNetV2Space')] config = [('MobileNetV2Space')]
paddle.enable_static() paddle.enable_static()
rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8824)) rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8826))
.. py:method:: next_archs(obs=None) .. py:method:: next_archs(obs=None)
...@@ -255,7 +255,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习 ...@@ -255,7 +255,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
from paddleslim.nas import RLNAS from paddleslim.nas import RLNAS
config = [('MobileNetV2Space')] config = [('MobileNetV2Space')]
paddle.enable_static() paddle.enable_static()
rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8825)) rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8827))
input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32') input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
archs = rlnas.next_archs(1)[0] archs = rlnas.next_archs(1)[0]
for arch in archs: for arch in archs:
...@@ -280,7 +280,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习 ...@@ -280,7 +280,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
from paddleslim.nas import RLNAS from paddleslim.nas import RLNAS
config = [('MobileNetV2Space')] config = [('MobileNetV2Space')]
paddle.enable_static() paddle.enable_static()
rlnas = RLNAS(key='lstm', configs=config, server_addr=("", 8888)) rlnas = RLNAS(key='lstm', configs=config, server_addr=("", 8828))
rlnas.next_archs(1) rlnas.next_archs(1)
rlnas.reward(1.0) rlnas.reward(1.0)
...@@ -307,7 +307,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习 ...@@ -307,7 +307,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
from paddleslim.nas import RLNAS from paddleslim.nas import RLNAS
config = [('MobileNetV2Space')] config = [('MobileNetV2Space')]
paddle.enable_static() paddle.enable_static()
rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8826)) rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8829))
archs = rlnas.final_archs(1) archs = rlnas.final_archs(1)
print(archs) print(archs)
...@@ -330,7 +330,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习 ...@@ -330,7 +330,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
from paddleslim.nas import RLNAS from paddleslim.nas import RLNAS
config = [('MobileNetV2Space')] config = [('MobileNetV2Space')]
paddle.enable_static() paddle.enable_static()
rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8827)) rlnas = RLNAS(key='lstm', configs=config, server_addr=("",8830))
input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32') input = paddle.static.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
tokens = ([0] * 25) tokens = ([0] * 25)
archs = rlnas.tokens2arch(tokens)[0] archs = rlnas.tokens2arch(tokens)[0]
......
...@@ -160,8 +160,8 @@ sanas.reward(float(finally_reward[1])) ...@@ -160,8 +160,8 @@ sanas.reward(float(finally_reward[1]))
```python ```python
for step in range(3): for step in range(3):
archs = sanas.next_archs()[0] archs = sanas.next_archs()[0]
exe, train_program, eval_program, inputs, avg_cost, acc_top1, acc_top5 = build_program(archs) exe, train_program, eval_program, (images,label), avg_cost, acc_top1, acc_top5 = build_program(archs)
train_loader, eval_loader = input_data(inputs) train_loader, eval_loader = input_data(images, label)
current_flops = slim.analysis.flops(train_program) current_flops = slim.analysis.flops(train_program)
if current_flops > 321208544: if current_flops > 321208544:
......
...@@ -16,13 +16,13 @@ OFA的基本流程分为以下步骤: ...@@ -16,13 +16,13 @@ OFA的基本流程分为以下步骤:
PaddleSlim提供了三种获得超网络的方式,具体可以参考[超网络转换](https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/dygraph/ofa/convert_supernet_api.html) PaddleSlim提供了三种获得超网络的方式,具体可以参考[超网络转换](https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/dygraph/ofa/convert_supernet_api.html)
```python ```python
import paddle import paddle
from paddle.vision.models import mobilenet_v1 from paddle.vision.models import mobilenet_v1
from paddleslim.nas.ofa.convert_super import Convert, supernet from paddleslim.nas.ofa.convert_super import Convert, supernet
model = mobilenet_v1() model = mobilenet_v1()
sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4]) sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
sp_model = Convert(sp_net_config).convert(model) sp_model = Convert(sp_net_config).convert(model)
``` ```
### 2. 训练配置 ### 2. 训练配置
......
...@@ -262,14 +262,14 @@ sa_nas.reward(float(valid_top1_list[-1] + valid_top1_list[-2]) / 2) ...@@ -262,14 +262,14 @@ sa_nas.reward(float(valid_top1_list[-1] + valid_top1_list[-2]) / 2)
### 10. 利用demo下的脚本启动搜索 ### 10. 利用demo下的脚本启动搜索
搜索文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/demo/nas/sanas_darts_space.py),搜索过程中限制模型参数量为不大于3.77M。 搜索文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/demo/nas/sanas_darts_space.py),搜索过程中限制模型参数量为不大于3.77M。
```python ```shell
cd demo/nas/ cd demo/nas/
python darts_nas.py python darts_nas.py
``` ```
### 11. 利用demo下的脚本启动最终实验 ### 11. 利用demo下的脚本启动最终实验
最终实验文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/demo/nas/sanas_darts_space.py),最终实验需要训练600epoch。以下示例输入token为`[5, 5, 0, 5, 5, 10, 7, 7, 5, 7, 7, 11, 10, 12, 10, 0, 5, 3, 10, 8]` 最终实验文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/demo/nas/sanas_darts_space.py),最终实验需要训练600epoch。以下示例输入token为`[5, 5, 0, 5, 5, 10, 7, 7, 5, 7, 7, 11, 10, 12, 10, 0, 5, 3, 10, 8]`
```python ```shell
cd demo/nas/ cd demo/nas/
python darts_nas.py --token 5 5 0 5 5 10 7 7 5 7 7 11 10 12 10 0 5 3 10 8 --retain_epoch 600 python darts_nas.py --token 5 5 0 5 5 10 7 7 5 7 7 11 10 12 10 0 5 3 10 8 --retain_epoch 600
``` ```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册