提交 8f685fbe 编写于 作者: C ceci3

update doc

上级 7afcdc58
...@@ -2,49 +2,73 @@ ...@@ -2,49 +2,73 @@
## SANAS API文档 ## SANAS API文档
### paddleslim.nas.SANAS(configs, server_addr, init_temperature, reduce_rate, search_steps, save_checkpoint, load_checkpoint, is_server) ## class SANAS
初始化一个sanas实例。
---
>paddleslim.nas.SANAS(configs, server_addr, init_temperature, reduce_rate, search_steps, save_checkpoint, load_checkpoint, is_server)
SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火算法进行模型结构搜索的算法,一般用于离散搜索任务。
**参数:** **参数:**
- **configs(list<tuple>): 搜索空间配置列表,格式是[(key, {input_size, output_size, block_num, block_mask})], `input_size``output_size`表示输入和输出的特征图的大小,`block_num`是指搜索网络中的block数量,`block_mask`是一组由0和1组成的列表,0代表不进行下采样的block,1代表下采样的block。 更多paddleslim提供的搜索空间配置可以参考。 - **configs(list<tuple>):** 搜索空间配置列表,格式是`[(key, {input_size, output_size, block_num, block_mask})]`或者`[(key)]`(MobileNetV2、MobilenetV1和ResNet的搜索空间使用和原本网络结构相同的搜索空间,所以仅需指定`key`即可), `input_size``output_size`表示输入和输出的特征图的大小,`block_num`是指搜索网络中的block数量,`block_mask`是一组由0和1组成的列表,0代表不进行下采样的block,1代表下采样的block。 更多paddleslim提供的搜索空间配置可以参考。
- **server_addr(tuple): SANAS的地址,包括server的ip地址和端口号,如果ip地址为None或者为""的话则默认使用本机ip。默认:("", 8881)。 - **server_addr(tuple):** SANAS的地址,包括server的ip地址和端口号,如果ip地址为None或者为""的话则默认使用本机ip。默认:("", 8881)。
- **init_temperature(float): 基于模拟退火进行搜索的初始温度。默认:100。 - **init_temperature(float):** 基于模拟退火进行搜索的初始温度。默认:100。
- **reduce_rate(float): 基于模拟退火进行搜索的衰减率。默认:0.85。 - **reduce_rate(float):** 基于模拟退火进行搜索的衰减率。默认:0.85。
- **search_steps(int): 搜索过程迭代的次数。默认:300。 - **search_steps(int):** 搜索过程迭代的次数。默认:300。
- **save_checkpoint(str|None): 保存checkpoint的文件目录,如果设置为None的话则不保存checkpoint。默认:nas_checkpoint。 - **save_checkpoint(str|None):** 保存checkpoint的文件目录,如果设置为None的话则不保存checkpoint。默认:`./nas_checkpoint`
- **load_checkpoint(str|None): 加载checkpoint的文件目录,如果设置为None的话则不加载checkpoint。默认:None。 - **load_checkpoint(str|None):** 加载checkpoint的文件目录,如果设置为None的话则不加载checkpoint。默认:None。
- **is_server(bool): 当前实例是否要启动一个server。默认:True。 - **is_server(bool):** 当前实例是否要启动一个server。默认:True。
### paddleslim.nas.SANAS.tokens2arch(tokens) **返回:**
一个SANAS类的实例
**示例代码:**
```
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
sanas = SANAS(config=config)
```
---
>tokens2arch(tokens)
通过一组token得到实际的模型结构,一般用来把搜索到最优的token转换为模型结构用来做最后的训练。 通过一组token得到实际的模型结构,一般用来把搜索到最优的token转换为模型结构用来做最后的训练。
**参数:** **参数:**
- **tokens(list): 搜索出来的token。 - **tokens(list):** 一组token。
**返回** **返回**
返回一个模型模型结构实例。 返回一个模型结构实例。
**返回类型** ---
function
### paddleslim.nas.SANAS.next_archs(): >next_archs():
获取下一组模型结构。 获取下一组模型结构。
**返回** **返回**
返回模型结构实例的列表,形式为list<model_arch> 返回模型结构实例的列表,形式为list。
### paddleslim.nas.SANAS.reward(score): **示例代码:**
把当前模型结构的得分情况回传给server,server根据得分判断是否是最优得分。 ```
import paddle.fluid as fluid
input = fluid.data(name='input', shape=[None, 1, 32, 32], dtype='float32')
archs = sanas.next_archs()
for arch in archs:
output = arch(input)
input = output
```
---
>reward(score):
把当前模型结构的得分情况回传。
**参数:** **参数:**
score<float>: 当前模型的得分,分数越大越好。 score<float>:** 当前模型的得分,分数越大越好。
**返回** **返回**
模型结构更新成功或者失败,成功则返回`True`,失败则返回`False` 模型结构更新成功或者失败,成功则返回`True`,失败则返回`False`
**返回类型**
bool类型
**代码示例** **代码示例**
```python ```python
...@@ -56,6 +80,7 @@ config=[('MobileNetV2Space')] ...@@ -56,6 +80,7 @@ config=[('MobileNetV2Space')]
# 实例化SANAS # 实例化SANAS
sa_nas = SANAS(config, server_addr=("", 8887), init_temperature=10.24, reduce_rate=0.85, search_steps=100, is_server=True) sa_nas = SANAS(config, server_addr=("", 8887), init_temperature=10.24, reduce_rate=0.85, search_steps=100, is_server=True)
# 构造输入数据
input = fluid.data(name='input', shape=[None, 1, 32, 32], dtype='float32') input = fluid.data(name='input', shape=[None, 1, 32, 32], dtype='float32')
label = fluid.data(name='label', shape=[-1, 1], dtype='int64') label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
for step in range(100): for step in range(100):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册