提交 66497ea6 编写于 作者: C ceci3 提交者: whs

update example (#53)

上级 73019e56
......@@ -57,11 +57,15 @@ paddleslim.nas.SANAS.next_archs()
**示例代码:**
```python
import paddle.fluid as fluid
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
sanas = SANAS(configs=config)
input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
archs = sanas.next_archs()
for arch in archs:
output = arch(input)
input = output
print(output)
```
paddleslim.nas.SANAS.reward(score)
......@@ -74,13 +78,26 @@ paddleslim.nas.SANAS.reward(score)
**返回:**
模型结构更新成功或者失败,成功则返回`True`,失败则返回`False`
**示例代码:**
```python
import paddle.fluid as fluid
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
sanas = SANAS(configs=config)
archs = sanas.next_archs()
### 假设网络计算出来的score是1,实际代码中使用时需要返回真实score。
score=float(1.0)
sanas.reward(float(score))
```
paddlesim.nas.SANAS.tokens2arch(tokens)
: 通过一组token得到实际的模型结构,一般用来把搜索到最优的token转换为模型结构用来做最后的训练。tokens的形式是一个列表,tokens映射到搜索空间转换成相应的网络结构,一组token对应唯一的一个网络结构。
: 通过一组tokens得到实际的模型结构,一般用来把搜索到最优的token转换为模型结构用来做最后的训练。tokens的形式是一个列表,tokens映射到搜索空间转换成相应的网络结构,一组tokens对应唯一的一个网络结构。
**参数:**
- **tokens(list):** - 一组token。
- **tokens(list):** - 一组tokens。tokens的长度和范取决于搜索空间
**返回:**
根据传入的token得到一个模型结构实例。
......@@ -88,11 +105,13 @@ paddlesim.nas.SANAS.tokens2arch(tokens)
**示例代码:**
```python
import paddle.fluid as fluid
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
sanas = SANAS(configs=config)
input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
archs = sanas.token2arch(tokens)
for arch in archs:
output = arch(input)
input = output
tokens = ([0] * 25)
archs = sanas.tokens2arch(tokens)[0]
print(archs(input))
```
paddleslim.nas.SANAS.current_info()
......@@ -100,3 +119,12 @@ paddleslim.nas.SANAS.current_info()
**返回:**
搜索过程中最好的token,reward和当前训练的token,形式为dict。
**示例代码:**
```python
import paddle.fluid as fluid
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
sanas = SANAS(configs=config)
print(sanas.current_info())
```
......@@ -135,7 +135,7 @@ class ControllerServer(object):
)) > 1:
self._client.pop(key_client)
self._client_num -= 1
_logger.info(
_logger.debug(
"client: {}, client_num: {}, compare_time: {}".format(
self._client, self._client_num,
self._compare_time))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册