提交 66497ea6 编写于 作者: C ceci3 提交者: whs

update example (#53)

上级 73019e56
...@@ -57,11 +57,15 @@ paddleslim.nas.SANAS.next_archs() ...@@ -57,11 +57,15 @@ paddleslim.nas.SANAS.next_archs()
**示例代码:** **示例代码:**
```python ```python
import paddle.fluid as fluid import paddle.fluid as fluid
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
sanas = SANAS(configs=config)
input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32') input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
archs = sanas.next_archs() archs = sanas.next_archs()
for arch in archs: for arch in archs:
output = arch(input) output = arch(input)
input = output input = output
print(output)
``` ```
paddleslim.nas.SANAS.reward(score) paddleslim.nas.SANAS.reward(score)
...@@ -74,13 +78,26 @@ paddleslim.nas.SANAS.reward(score) ...@@ -74,13 +78,26 @@ paddleslim.nas.SANAS.reward(score)
**返回:** **返回:**
模型结构更新成功或者失败,成功则返回`True`,失败则返回`False` 模型结构更新成功或者失败,成功则返回`True`,失败则返回`False`
**示例代码:**
```python
import paddle.fluid as fluid
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
sanas = SANAS(configs=config)
archs = sanas.next_archs()
### 假设网络计算出来的score是1,实际代码中使用时需要返回真实score。
score=float(1.0)
sanas.reward(float(score))
```
paddlesim.nas.SANAS.tokens2arch(tokens) paddlesim.nas.SANAS.tokens2arch(tokens)
: 通过一组token得到实际的模型结构,一般用来把搜索到最优的token转换为模型结构用来做最后的训练。tokens的形式是一个列表,tokens映射到搜索空间转换成相应的网络结构,一组token对应唯一的一个网络结构。 : 通过一组tokens得到实际的模型结构,一般用来把搜索到最优的token转换为模型结构用来做最后的训练。tokens的形式是一个列表,tokens映射到搜索空间转换成相应的网络结构,一组tokens对应唯一的一个网络结构。
**参数:** **参数:**
- **tokens(list):** - 一组token。 - **tokens(list):** - 一组tokens。tokens的长度和范取决于搜索空间
**返回:** **返回:**
根据传入的token得到一个模型结构实例。 根据传入的token得到一个模型结构实例。
...@@ -88,11 +105,13 @@ paddlesim.nas.SANAS.tokens2arch(tokens) ...@@ -88,11 +105,13 @@ paddlesim.nas.SANAS.tokens2arch(tokens)
**示例代码:** **示例代码:**
```python ```python
import paddle.fluid as fluid import paddle.fluid as fluid
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
sanas = SANAS(configs=config)
input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32') input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
archs = sanas.token2arch(tokens) tokens = ([0] * 25)
for arch in archs: archs = sanas.tokens2arch(tokens)[0]
output = arch(input) print(archs(input))
input = output
``` ```
paddleslim.nas.SANAS.current_info() paddleslim.nas.SANAS.current_info()
...@@ -100,3 +119,12 @@ paddleslim.nas.SANAS.current_info() ...@@ -100,3 +119,12 @@ paddleslim.nas.SANAS.current_info()
**返回:** **返回:**
搜索过程中最好的token,reward和当前训练的token,形式为dict。 搜索过程中最好的token,reward和当前训练的token,形式为dict。
**示例代码:**
```python
import paddle.fluid as fluid
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
sanas = SANAS(configs=config)
print(sanas.current_info())
```
...@@ -135,7 +135,7 @@ class ControllerServer(object): ...@@ -135,7 +135,7 @@ class ControllerServer(object):
)) > 1: )) > 1:
self._client.pop(key_client) self._client.pop(key_client)
self._client_num -= 1 self._client_num -= 1
_logger.info( _logger.debug(
"client: {}, client_num: {}, compare_time: {}".format( "client: {}, client_num: {}, compare_time: {}".format(
self._client, self._client_num, self._client, self._client_num,
self._compare_time)) self._compare_time))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册