From 66497ea6f3193c4e7524360d24a402fe2fe1bfbd Mon Sep 17 00:00:00 2001 From: ceci3 Date: Wed, 22 Jan 2020 15:23:54 +0800 Subject: [PATCH] update example (#53) --- docs/docs/api/nas_api.md | 40 ++++++++++++++++++++++---- paddleslim/common/controller_server.py | 2 +- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/docs/docs/api/nas_api.md b/docs/docs/api/nas_api.md index 23c5c4f1..a6501ffa 100644 --- a/docs/docs/api/nas_api.md +++ b/docs/docs/api/nas_api.md @@ -57,11 +57,15 @@ paddleslim.nas.SANAS.next_archs() **示例代码:** ```python import paddle.fluid as fluid +from paddleslim.nas import SANAS +config = [('MobileNetV2Space')] +sanas = SANAS(configs=config) input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32') archs = sanas.next_archs() for arch in archs: output = arch(input) input = output +print(output) ``` paddleslim.nas.SANAS.reward(score) @@ -74,13 +78,26 @@ paddleslim.nas.SANAS.reward(score) **返回:** 模型结构更新成功或者失败,成功则返回`True`,失败则返回`False`。 +**示例代码:** +```python +import paddle.fluid as fluid +from paddleslim.nas import SANAS +config = [('MobileNetV2Space')] +sanas = SANAS(configs=config) +archs = sanas.next_archs() + +### 假设网络计算出来的score是1,实际代码中使用时需要返回真实score。 +score=float(1.0) +sanas.reward(float(score)) +``` + paddlesim.nas.SANAS.tokens2arch(tokens) -: 通过一组token得到实际的模型结构,一般用来把搜索到最优的token转换为模型结构用来做最后的训练。tokens的形式是一个列表,tokens映射到搜索空间转换成相应的网络结构,一组token对应唯一的一个网络结构。 +: 通过一组tokens得到实际的模型结构,一般用来把搜索到最优的token转换为模型结构用来做最后的训练。tokens的形式是一个列表,tokens映射到搜索空间转换成相应的网络结构,一组tokens对应唯一的一个网络结构。 **参数:** -- **tokens(list):** - 一组token。 +- **tokens(list):** - 一组tokens。tokens的长度和范取决于搜索空间。 **返回:** 根据传入的token得到一个模型结构实例。 @@ -88,11 +105,13 @@ paddlesim.nas.SANAS.tokens2arch(tokens) **示例代码:** ```python import paddle.fluid as fluid +from paddleslim.nas import SANAS +config = [('MobileNetV2Space')] +sanas = SANAS(configs=config) input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32') -archs = sanas.token2arch(tokens) -for arch in archs: - output = arch(input) - input = output +tokens = ([0] * 25) +archs = sanas.tokens2arch(tokens)[0] +print(archs(input)) ``` paddleslim.nas.SANAS.current_info() @@ -100,3 +119,12 @@ paddleslim.nas.SANAS.current_info() **返回:** 搜索过程中最好的token,reward和当前训练的token,形式为dict。 + +**示例代码:** +```python +import paddle.fluid as fluid +from paddleslim.nas import SANAS +config = [('MobileNetV2Space')] +sanas = SANAS(configs=config) +print(sanas.current_info()) +``` diff --git a/paddleslim/common/controller_server.py b/paddleslim/common/controller_server.py index 40943556..fb4d76bc 100644 --- a/paddleslim/common/controller_server.py +++ b/paddleslim/common/controller_server.py @@ -135,7 +135,7 @@ class ControllerServer(object): )) > 1: self._client.pop(key_client) self._client_num -= 1 - _logger.info( + _logger.debug( "client: {}, client_num: {}, compare_time: {}".format( self._client, self._client_num, self._compare_time)) -- GitLab