diff --git a/docs/en/api_en/index_en.rst b/docs/en/api_en/index_en.rst
index 6e15635a74d55337363743128c2018fd3f8474ad..9fd0325a9142160f4a53c792ed430f48ca5f145a 100644
--- a/docs/en/api_en/index_en.rst
+++ b/docs/en/api_en/index_en.rst
@@ -15,5 +15,5 @@ API Documents
paddleslim.quant.rst
paddleslim.nas.rst
paddleslim.nas.one_shot.rst
- paddleslim.nas.search_space.rst
paddleslim.pantheon.rst
+ search_space_en.rst
diff --git a/docs/en/api_en/paddleslim.nas.rst b/docs/en/api_en/paddleslim.nas.rst
index 30eddca0baf1355dbd5872f7fcb97e71698a98c2..f6b17bcf87600e96c4b3b31573d9db3b0291a08c 100644
--- a/docs/en/api_en/paddleslim.nas.rst
+++ b/docs/en/api_en/paddleslim.nas.rst
@@ -1,18 +1,12 @@
paddleslim\.nas package
=======================
-.. automodule:: paddleslim.nas
- :members:
- :undoc-members:
- :show-inheritance:
-
Subpackages
-----------
.. toctree::
paddleslim.nas.one_shot
- paddleslim.nas.search_space
Submodules
----------
diff --git a/docs/en/api_en/paddleslim.nas.search_space.rst b/docs/en/api_en/paddleslim.nas.search_space.rst
deleted file mode 100644
index f078d4af7cc0d9d77a7ea05e3cff43ce3c112788..0000000000000000000000000000000000000000
--- a/docs/en/api_en/paddleslim.nas.search_space.rst
+++ /dev/null
@@ -1,108 +0,0 @@
-paddleslim\.nas\.search\_space package
-======================================
-
-.. automodule:: paddleslim.nas.search_space
- :members:
- :undoc-members:
- :show-inheritance:
-
-Submodules
-----------
-
-paddleslim\.nas\.search\_space\.base\_layer module
---------------------------------------------------
-
-.. automodule:: paddleslim.nas.search_space.base_layer
- :members:
- :undoc-members:
- :show-inheritance:
-
-paddleslim\.nas\.search\_space\.combine\_search\_space module
--------------------------------------------------------------
-
-.. automodule:: paddleslim.nas.search_space.combine_search_space
- :members:
- :undoc-members:
- :show-inheritance:
-
-paddleslim\.nas\.search\_space\.inception\_block module
--------------------------------------------------------
-
-.. automodule:: paddleslim.nas.search_space.inception_block
- :members:
- :undoc-members:
- :show-inheritance:
-
-paddleslim\.nas\.search\_space\.mobilenet\_block module
--------------------------------------------------------
-
-.. automodule:: paddleslim.nas.search_space.mobilenet_block
- :members:
- :undoc-members:
- :show-inheritance:
-
-paddleslim\.nas\.search\_space\.mobilenetv1 module
---------------------------------------------------
-
-.. automodule:: paddleslim.nas.search_space.mobilenetv1
- :members:
- :undoc-members:
- :show-inheritance:
-
-paddleslim\.nas\.search\_space\.mobilenetv2 module
---------------------------------------------------
-
-.. automodule:: paddleslim.nas.search_space.mobilenetv2
- :members:
- :undoc-members:
- :show-inheritance:
-
-paddleslim\.nas\.search\_space\.resnet module
----------------------------------------------
-
-.. automodule:: paddleslim.nas.search_space.resnet
- :members:
- :undoc-members:
- :show-inheritance:
-
-paddleslim\.nas\.search\_space\.resnet\_block module
-----------------------------------------------------
-
-.. automodule:: paddleslim.nas.search_space.resnet_block
- :members:
- :undoc-members:
- :show-inheritance:
-
-paddleslim\.nas\.search\_space\.search\_space\_base module
-----------------------------------------------------------
-
-.. automodule:: paddleslim.nas.search_space.search_space_base
- :members:
- :undoc-members:
- :show-inheritance:
-
-paddleslim\.nas\.search\_space\.search\_space\_factory module
--------------------------------------------------------------
-
-.. automodule:: paddleslim.nas.search_space.search_space_factory
- :members:
- :undoc-members:
- :show-inheritance:
-
-paddleslim\.nas\.search\_space\.search\_space\_registry module
---------------------------------------------------------------
-
-.. automodule:: paddleslim.nas.search_space.search_space_registry
- :members:
- :undoc-members:
- :show-inheritance:
-
-paddleslim\.nas\.search\_space\.utils module
---------------------------------------------
-
-.. automodule:: paddleslim.nas.search_space.utils
- :members:
- :undoc-members:
- :show-inheritance:
-
-
diff --git a/docs/en/api_en/search_space_en.rst b/docs/en/api_en/search_space_en.rst
new file mode 100644
index 0000000000000000000000000000000000000000..020bb08e092c404c5c5c4729e0625f2a70a42c97
--- /dev/null
+++ b/docs/en/api_en/search_space_en.rst
@@ -0,0 +1,115 @@
+search space
+========
+Search Space used in neural architecture search. Search Space is a collection of model architecture, the purpose of SANAS is to get a model which FLOPs or latency is smaller or percision is higher.
+
+search space which paddleslim.nas provided
+-------
+
+Based on origin model architecture:
+1. MobileNetV2Space
+ MobileNetV2's architecture can reference: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/mobilenet_v2.py#L29), [paper](https://arxiv.org/abs/1801.04381)
+
+2. MobileNetV1Space
+ MobilNetV1's architecture can reference: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/mobilenet_v1.py#L29), [paper](https://arxiv.org/abs/1704.04861)
+
+3. ResNetSpace
+ ResNetSpace's architecture can reference: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/resnet.py#L30), [paper](https://arxiv.org/pdf/1512.03385.pdf)
+
+
+Based on block from different model:
+1. MobileNetV1BlockSpace
+ MobileNetV1Block's architecture can reference: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/mobilenet_v1.py#L173)
+
+2. MobileNetV2BlockSpace
+ MobileNetV2Block's architecture can reference: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/mobilenet_v2.py#L174)
+
+3. ResNetBlockSpace
+ ResNetBlock's architecture can reference: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/resnet.py#L148)
+
+4. InceptionABlockSpace
+ InceptionABlock's architecture can reference: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/inception_v4.py#L140)
+
+5. InceptionCBlockSpace
+ InceptionCBlock's architecture can reference: [code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/inception_v4.py#L291)
+
+
+How to use search space
+--------
+1. Only need to specify the name of search space if use the space based on origin model architecture, such as configs for class SANAS is [('MobileNetV2Space')] if you want to use origin MobileNetV2 as search space.
+2. Use search space paddleslim.nas provided based on block:
+ 2.1 Use `input_size`, `output_size` and `block_num` to construct search space, such as configs for class SANAS is ('MobileNetV2BlockSpace', {'input_size': 224, 'output_size': 32, 'block_num': 10})].
+ 2.2 Use `block_mask` to construct search space, such as configs for class SANAS is [('MobileNetV2BlockSpace', {'block_mask': [0, 1, 1, 1, 1, 0, 1, 0]})].
+
+How to write yourself search space
+--------
+If you want to write yourself search space, you need to inherit base class named SearchSpaceBase and overwrite following functions:
+ 1. Function to get initial tokens(function `init_tokens`), set the initial tokens which you want, every token in tokens means index of search list, such as if tokens=[0, 3, 5], it means the list of channel of current model architecture is [8, 40, 128].
+ 2. Function about the length of every token in tokens(function `range_table`), range of every token in tokens.
+ 3. Function to get model architecture according to tokens(function `token2arch`), get model architecture according to tokens in the search process.
+
+For example, how to add a search space with resnet block. New search space can NOT has the same name with existing search space.
+
+```python
+### import necessary head file
+from .search_space_base import SearchSpaceBase
+from .search_space_registry import SEARCHSPACE
+import numpy as np
+
+### use decorator SEARCHSPACE.register to register yourself search space to search space NameSpace
+@SEARCHSPACE.register
+### define a search space class inherit the base class SearchSpaceBase
+class ResNetBlockSpace2(SearchSpaceBase):
+ def __init__(self, input_size, output_size, block_num, block_mask):
+ ### define the iterm you want to search, such as the numeber of channel, the number of convolution repeat, the size of kernel.
+ ### self.filter_num represents the search list about the numeber of channel.
+ self.filter_num = np.array([8, 16, 32, 40, 64, 128, 256, 512])
+
+ ### define initial tokens, the length of initial tokens according to block_num or block_mask.
+ def init_tokens(self):
+ return [0] * 3 * len(self.block_mask)
+
+ ### define the range of index in tokens.
+ def range_table(self):
+ return [len(self.filter_num)] * 3 * len(self.block_mask)
+
+ ### transform tokens to model architecture.
+ def token2arch(self, tokens=None):
+ if tokens == None:
+ tokens = self.init_tokens()
+
+ self.bottleneck_params_list = []
+ for i in range(len(self.block_mask)):
+ self.bottleneck_params_list.append(self.filter_num[tokens[i * 3 + 0]],
+ self.filter_num[tokens[i * 3 + 1]],
+ self.filter_num[tokens[i * 3 + 2]],
+ 2 if self.block_mask[i] == 1 else 1)
+
+ def net_arch(input):
+ for i, layer_setting in enumerate(self.bottleneck_params_list):
+ channel_num, stride = layer_setting[:-1], layer_setting[-1]
+ input = self._resnet_block(input, channel_num, stride, name='resnet_layer{}'.format(i+1))
+
+ return input
+
+ return net_arch
+
+ ### code to get block.
+ def _resnet_block(self, input, channel_num, stride, name=None):
+ shortcut_conv = self._shortcut(input, channel_num[2], stride, name=name)
+ input = self._conv_bn_layer(input=input, num_filters=channel_num[0], filter_size=1, act='relu', name=name + '_conv0')
+ input = self._conv_bn_layer(input=input, num_filters=channel_num[1], filter_size=3, stride=stride, act='relu', name=name + '_conv1')
+ input = self._conv_bn_layer(input=input, num_filters=channel_num[2], filter_size=1, name=name + '_conv2')
+ return fluid.layers.elementwise_add(x=shortcut_conv, y=input, axis=0, name=name+'_elementwise_add')
+
+ def _shortcut(self, input, channel_num, stride, name=None):
+ channel_in = input.shape[1]
+ if channel_in != channel_num or stride != 1:
+ return self.conv_bn_layer(input, num_filters=channel_num, filter_size=1, stride=stride, name=name+'_shortcut')
+ else:
+ return input
+
+ def _conv_bn_layer(self, input, num_filters, filter_size, stride=1, padding='SAME', act=None, name=None):
+ conv = fluid.layers.conv2d(input, num_filters, filter_size, stride, name=name+'_conv')
+ bn = fluid.layers.batch_norm(conv, act=act, name=name+'_bn')
+ return bn
+```
diff --git a/docs/en/quick_start/index_en.rst b/docs/en/quick_start/index_en.rst
index 97dfc1a8ec02ed84fda7d3161e267e5deb49f42a..57a08cb58262de444f14e6121f96ab2435eac3f8 100644
--- a/docs/en/quick_start/index_en.rst
+++ b/docs/en/quick_start/index_en.rst
@@ -6,6 +6,7 @@ Quick Start
:maxdepth: 1
pruning_tutorial_en.md
+ nas_tutorial_en.md
quant_aware_tutorial_en.md
quant_post_tutorial_en.md
diff --git a/docs/en/quick_start/nas_tutorial_en.md b/docs/en/quick_start/nas_tutorial_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..63d695b0b611728ba932c236f7b07cd6f419b1a9
--- /dev/null
+++ b/docs/en/quick_start/nas_tutorial_en.md
@@ -0,0 +1,155 @@
+# Nerual Architecture Search for Image Classification
+
+This tutorial shows how to use [API](../api/nas_api.md) about SANAS in PaddleSlim. We start experiment based on MobileNetV2 as example. The tutorial contains follow section.
+
+1. necessary imports
+2. initial SANAS instance
+3. define function about building program
+4. define function about input data
+5. define function about training
+6. define funciton about evaluation
+7. start search
+ 7.1 fetch model architecture
+ 7.2 build program
+ 7.3 define input data
+ 7.4 train model
+ 7.5 evaluate model
+ 7.6 reture score
+8. full example
+
+
+The following chapter describes each steps in order.
+
+## 1. import dependency
+Please make sure that you haved installed Paddle correctly, then do the necessary imports.
+```python
+import paddle
+import paddle.fluid as fluid
+import paddleslim as slim
+import numpy as np
+```
+
+## 2. initial SANAS instance
+```python
+sanas = slim.nas.SANAS(configs=[('MobileNetV2Space')], server_addr=("", 8337), save_checkpoint=None)
+```
+
+## 3. define function about building program
+Build program about training and evaluation according to the model architecture.
+```python
+def build_program(archs):
+ train_program = fluid.Program()
+ startup_program = fluid.Program()
+ with fluid.program_guard(train_program, startup_program):
+ data = fluid.data(name='data', shape=[None, 3, 32, 32], dtype='float32')
+ label = fluid.data(name='label', shape=[None, 1], dtype='int64')
+ output = archs(data)
+ output = fluid.layers.fc(input=output, size=10)
+
+ softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
+ cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
+ avg_cost = fluid.layers.mean(cost)
+ acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)
+ acc_top5 = fluid.layers.accuracy(input=softmax_out, label=label, k=5)
+ test_program = fluid.default_main_program().clone(for_test=True)
+
+ optimizer = fluid.optimizer.Adam(learning_rate=0.1)
+ optimizer.minimize(avg_cost)
+
+ place = fluid.CPUPlace()
+ exe = fluid.Executor(place)
+ exe.run(startup_program)
+ return exe, train_program, test_program, (data, label), avg_cost, acc_top1, acc_top5
+```
+
+## 4. define function about input data
+The dataset we used is cifar10, and `paddle.dataset.cifar` in Paddle including the download and pre-read about cifar.
+```python
+def input_data(inputs):
+ train_reader = paddle.batch(paddle.reader.shuffle(paddle.dataset.cifar.train10(cycle=False), buf_size=1024),batch_size=256)
+ train_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())
+ eval_reader = paddle.batch(paddle.dataset.cifar.test10(cycle=False), batch_size=256)
+ eval_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())
+ return train_reader, train_feeder, eval_reader, eval_feeder
+```
+
+## 5. define function about training
+Start training.
+```python
+def start_train(program, data_reader, data_feeder):
+ outputs = [avg_cost.name, acc_top1.name, acc_top5.name]
+ for data in data_reader():
+ batch_reward = exe.run(program, feed=data_feeder.feed(data), fetch_list = outputs)
+ print("TRAIN: loss: {}, acc1: {}, acc5:{}".format(batch_reward[0], batch_reward[1], batch_reward[2]))
+```
+
+## 6. define funciton about evaluation
+Start evaluating.
+```python
+def start_eval(program, data_reader, data_feeder):
+ reward = []
+ outputs = [avg_cost.name, acc_top1.name, acc_top5.name]
+ for data in data_reader():
+ batch_reward = exe.run(program, feed=data_feeder.feed(data), fetch_list = outputs)
+ reward_avg = np.mean(np.array(batch_reward), axis=1)
+ reward.append(reward_avg)
+ print("TEST: loss: {}, acc1: {}, acc5:{}".format(batch_reward[0], batch_reward[1], batch_reward[2]))
+ finally_reward = np.mean(np.array(reward), axis=0)
+ print("FINAL TEST: avg_cost: {}, acc1: {}, acc5: {}".format(finally_reward[0], finally_reward[1], finally_reward[2]))
+ return finally_reward
+```
+
+## 7. start search
+The following steps describes how to get current model architecture and what need to do after get the model architecture. If you want to start a full example directly, please jump to Step 9.
+
+### 7.1 fetch model architecture
+Get Next model architecture by `next_archs()`.
+```python
+archs = sanas.next_archs()[0]
+```
+
+### 7.2 build program
+Get program according to the function in Step3 and model architecture from Step 7.1.
+```python
+exe, train_program, eval_program, inputs, avg_cost, acc_top1, acc_top5 = build_program(archs)
+```
+
+### 7.3 define input data
+```python
+train_reader, train_feeder, eval_reader, eval_feeder = input_data(inputs)
+```
+
+### 7.4 train model
+Start training according to train program and data.
+```python
+start_train(train_program, train_reader, train_feeder)
+```
+### 7.5 evaluate model
+Start evaluation according to evaluation program and data.
+```python
+finally_reward = start_eval(eval_program, eval_reader, eval_feeder)
+```
+### 7.6 reture score
+```
+sanas.reward(float(finally_reward[1]))
+```
+
+## 8. full example
+The following is a full example about neural architecture search, it uses FLOPs as constraint and includes 3 steps, it means train 3 model architectures which is satisfied constraint, and train 7 epoch for each model architecture.
+```python
+for step in range(3):
+ archs = sanas.next_archs()[0]
+ exe, train_program, eval_progarm, inputs, avg_cost, acc_top1, acc_top5 = build_program(archs)
+ train_reader, train_feeder, eval_reader, eval_feeder = input_data(inputs)
+
+ current_flops = slim.analysis.flops(train_program)
+ if current_flops > 321208544:
+ continue
+
+ for epoch in range(7):
+ start_train(train_program, train_reader, train_feeder)
+
+ finally_reward = start_eval(eval_program, eval_reader, eval_feeder)
+
+ sanas.reward(float(finally_reward[1]))
+```
diff --git a/docs/zh_cn/api_cn/index.rst b/docs/zh_cn/api_cn/index.rst
index 722949487fcda71f65a48ae6873b76a85f8eb319..8151587784cc54aa2681fd46cc606552aea735ab 100644
--- a/docs/zh_cn/api_cn/index.rst
+++ b/docs/zh_cn/api_cn/index.rst
@@ -16,5 +16,5 @@ API文档
prune_api.rst
quantization_api.rst
single_distiller_api.rst
- search_space.md
+ search_space.rst
table_latency.md
diff --git a/docs/zh_cn/api_cn/nas_api.rst b/docs/zh_cn/api_cn/nas_api.rst
index b4b0d38abfdfd35b40e1f6fa5ab919aabd4a1407..f1f0214d5c6e8d7e1a1390a88adfea4612355772 100644
--- a/docs/zh_cn/api_cn/nas_api.rst
+++ b/docs/zh_cn/api_cn/nas_api.rst
@@ -125,7 +125,7 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火
**参数:**
- - **tokens(list):** - 一组tokens。tokens的长度和范取决于搜索空间。
+ - **tokens(list):** - 一组tokens。tokens的长度和范围取决于搜索空间。
**返回:**
根据传入的token得到一个模型结构实例。
diff --git a/docs/zh_cn/api_cn/search_space.md b/docs/zh_cn/api_cn/search_space.rst
similarity index 84%
rename from docs/zh_cn/api_cn/search_space.md
rename to docs/zh_cn/api_cn/search_space.rst
index 442607f23bc6220acdbcf7db1286f8a7028b7983..bbd0c3f52e360e499954ea30cae4d1edb985b77e 100644
--- a/docs/zh_cn/api_cn/search_space.md
+++ b/docs/zh_cn/api_cn/search_space.rst
@@ -1,11 +1,12 @@
-# 搜索空间
+搜索空间
+=========
+搜索空间是神经网络搜索中的一个概念。搜索空间是一系列模型结构的汇集, SANAS主要是利用模拟退火的思想在搜索空间中搜索到一个比较小的模型结构或者一个精度比较高的模型结构。
-## 搜索空间简介
-: 搜索空间是神经网络搜索中的一个概念。搜索空间是一系列模型结构的汇集, SANAS主要是利用模拟退火的思想在搜索空间中搜索到一个比较小的模型结构或者一个精度比较高的模型结构。
+paddleslim.nas 提供的搜索空间
+--------
-## paddleslim.nas 提供的搜索空间
+根据初始模型结构构造搜索空间:
-##### 根据初始模型结构构造搜索空间
1. MobileNetV2Space
MobileNetV2的网络结构可以参考:[代码](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/mobilenet_v2.py#L29),[论文](https://arxiv.org/abs/1801.04381)
@@ -16,7 +17,7 @@
ResNetSpace的网络结构可以参考:[代码](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/resnet.py#L30),[论文](https://arxiv.org/pdf/1512.03385.pdf)
-##### 根据相应模型的block构造搜索空间
+根据相应模型的block构造搜索空间:
1. MobileNetV1BlockSpace
MobileNetV1Block的结构可以参考:[代码](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/mobilenet_v1.py#L173)
@@ -33,20 +34,22 @@
InceptionCBlock结构可以参考:[代码](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/inception_v4.py#L291)
-## 搜索空间示例
+搜索空间使用示例
+--------
-1. 使用paddleslim中提供用初始的模型结构来构造搜索空间的话,仅需要指定搜索空间名字即可。例如:如果使用原本的MobileNetV2的搜索空间进行搜索的话,传入SANAS中的config直接指定为[('MobileNetV2Space')]。
+1. 使用paddleslim中提供用初始的模型结构来构造搜索空间的话,仅需要指定搜索空间名字即可。例如:如果使用原本的MobileNetV2的搜索空间进行搜索的话,传入SANAS中的configs直接指定为[('MobileNetV2Space')]。
2. 使用paddleslim中提供的block搜索空间构造搜索空间:
- 2.1 使用`input_size`, `output_size`和`block_num`来构造搜索空间。例如:传入SANAS的config可以指定为[('MobileNetV2BlockSpace', {'input_size': 224, 'output_size': 32, 'block_num': 10})]。
- 2.2 使用`block_mask`构造搜索空间。例如:传入SANAS的config可以指定为[('MobileNetV2BlockSpace', {'block_mask': [0, 1, 1, 1, 1, 0, 1, 0]})]。
+ 2.1 使用`input_size`, `output_size`和`block_num`来构造搜索空间。例如:传入SANAS的configs可以指定为[('MobileNetV2BlockSpace', {'input_size': 224, 'output_size': 32, 'block_num': 10})]。
+ 2.2 使用`block_mask`构造搜索空间。例如:传入SANAS的configs可以指定为[('MobileNetV2BlockSpace', {'block_mask': [0, 1, 1, 1, 1, 0, 1, 0]})]。
-## 自定义搜索空间(search space)
+自定义搜索空间(search space)
+--------
自定义搜索空间类需要继承搜索空间基类并重写以下几部分:
1. 初始化的tokens(`init_tokens`函数),可以设置为自己想要的tokens列表, tokens列表中的每个数字指的是当前数字在相应的搜索列表中的索引。例如本示例中若tokens=[0, 3, 5],则代表当前模型结构搜索到的通道数为[8, 40, 128]。
- 2. token中每个数字的搜索列表长度(`range_table`函数),tokens中每个token的索引范围。
- 3. 根据token产生模型结构(`token2arch`函数),根据搜索到的tokens列表产生模型结构。
+ 2. tokens中每个数字的搜索列表长度(`range_table`函数),tokens中每个token的索引范围。
+ 3. 根据tokens产生模型结构(`token2arch`函数),根据搜索到的tokens列表产生模型结构。
以新增reset block为例说明如何构造自己的search space。自定义的search space不能和已有的search space同名。
diff --git a/image_classification_nas_quick_start.ipynb b/image_classification_nas_quick_start.ipynb
deleted file mode 100644
index 15f8b6a9b06fa08ba587c33120715407947f5fdd..0000000000000000000000000000000000000000
--- a/image_classification_nas_quick_start.ipynb
+++ /dev/null
@@ -1,286 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 图像分类网络结构搜索-快速开始\n",
- "\n",
- "该教程以图像分类模型MobileNetV2为例,说明如何在cifar10数据集上快速使用[网络结构搜索接口](../api/nas_api.md)。\n",
- "该示例包含以下步骤:\n",
- "\n",
- "1. 导入依赖\n",
- "2. 初始化SANAS搜索实例\n",
- "3. 构建网络\n",
- "4. 启动搜索实验\n",
- "5. 定义输入数据\n",
- "6. 训练模型\n",
- "7. 评估模型\n",
- "8. 回传当前模型的得分\n",
- "9. 完整示例\n",
- "\n",
- "\n",
- "以下章节依次介绍每个步骤的内容。"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 1. 导入依赖\n",
- "请确认已正确安装Paddle,导入需要的依赖包。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import paddle\n",
- "import paddle.fluid as fluid\n",
- "import paddleslim as slim\n",
- "import numpy as np"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 2. 初始化SANAS搜索实例"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "sanas = slim.nas.SANAS(configs=[('MobileNetV2Space')], server_addr=(\"\", 8339))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 3. 构建网络\n",
- "根据传入的网络结构构造训练program和测试program。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def build_program(archs):\n",
- " train_program = fluid.Program()\n",
- " startup_program = fluid.Program()\n",
- " with fluid.program_guard(train_program, startup_program):\n",
- " data = fluid.data(name='data', shape=[None, 3, 32, 32], dtype='float32')\n",
- " label = fluid.data(name='label', shape=[None, 1], dtype='int64')\n",
- " output = archs(data)\n",
- " output = fluid.layers.fc(input=output, size=10)\n",
- "\n",
- " softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)\n",
- " cost = fluid.layers.cross_entropy(input=softmax_out, label=label)\n",
- " avg_cost = fluid.layers.mean(cost)\n",
- " acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)\n",
- " acc_top5 = fluid.layers.accuracy(input=softmax_out, label=label, k=5)\n",
- " test_program = fluid.default_main_program().clone(for_test=True)\n",
- " \n",
- " optimizer = fluid.optimizer.Adam(learning_rate=0.1)\n",
- " optimizer.minimize(avg_cost)\n",
- "\n",
- " place = fluid.CPUPlace()\n",
- " exe = fluid.Executor(place)\n",
- " exe.run(startup_program)\n",
- " return exe, train_program, test_program, (data, label), avg_cost, acc_top1, acc_top5"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 4. 启动搜索实验\n",
- "以下步骤拆解说明了如何获得当前模型结构以及获得当前模型结构之后应该有的步骤,如果想要看如何启动搜索实验的完整示例可以看步骤9。\n",
- "\n",
- "### 4.1 获取模型结构\n",
- "调用`next_archs()`函数获取到下一个模型结构。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "archs = sanas.next_archs()[0]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 4.2 构造program\n",
- "调用步骤3中的函数,根据5.1中的模型结构构造相应的program。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "exe, train_program, test_program, inputs, avg_cost, acc_top1, acc_top5 = build_program(archs)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 5. 定义输入数据\n",
- "使用的数据集为cifar10,paddle框架中`paddle.dataset.cifar`包括了cifar数据集的下载和读取,代码如下:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "train_reader = paddle.batch(paddle.reader.shuffle(paddle.dataset.cifar.train10(cycle=False), buf_size=1024),batch_size=256)\n",
- "train_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())\n",
- "test_reader = paddle.batch(paddle.dataset.cifar.test10(cycle=False), batch_size=256)\n",
- "test_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 6. 训练模型\n",
- "根据上面得到的训练program启动训练。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "outputs = [avg_cost.name, acc_top1.name, acc_top5.name]\n",
- "for data in train_reader():\n",
- " batch_reward = exe.run(train_program, feed=train_feeder.feed(data), fetch_list = outputs)\n",
- " print(\"TRAIN: loss: {}, acc1: {}, acc5:{}\".format(batch_reward[0], batch_reward[1], batch_reward[2]))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 7. 评估模型\n",
- "根据上面得到的评估program启动评估。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "reward = []\n",
- "for data in test_reader():\n",
- " batch_reward = exe.run(test_program, feed=test_feeder.feed(data), fetch_list = outputs)\n",
- " reward_avg = np.mean(np.array(batch_reward), axis=1)\n",
- " reward.append(reward_avg)\n",
- " print(\"TEST: loss: {}, acc1: {}, acc5:{}\".format(batch_reward[0], batch_reward[1], batch_reward[2]))\n",
- "finally_reward = np.mean(np.array(reward), axis=0)\n",
- "print(\"FINAL TEST: avg_cost: {}, acc1: {}, acc5: {}\".format(finally_reward[0], finally_reward[1], finally_reward[2]))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 8. 回传当前模型的得分"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "sanas.reward(float(finally_reward[1]))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 9. 完整示例\n",
- "以下是一个完整的搜索实验示例,示例中使用FLOPs作为约束条件,搜索实验一共搜索3个step,表示搜索到3个满足条件的模型结构进行训练,每搜>索到一个网络结构训练7个epoch。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "for step in range(3):\n",
- " archs = sanas.next_archs()[0]\n",
- " exe, train_program, test_progarm, inputs, avg_cost, acc_top1, acc_top5 = build_program(archs)\n",
- "\n",
- " current_flops = slim.analysis.flops(train_program)\n",
- " if current_flops > 321208544:\n",
- " continue\n",
- " \n",
- " train_reader = paddle.batch(paddle.reader.shuffle(paddle.dataset.cifar.train10(cycle=False), buf_size=1024),batch_size=256)\n",
- " train_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())\n",
- " test_reader = paddle.batch(paddle.dataset.cifar.test10(cycle=False),\n",
- " batch_size=256)\n",
- " test_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())\n",
- "\n",
- " outputs = [avg_cost.name, acc_top1.name, acc_top5.name]\n",
- " for epoch in range(7):\n",
- " for data in train_reader():\n",
- " loss, acc1, acc5 = exe.run(train_program, feed=train_feeder.feed(data), fetch_list = outputs)\n",
- " print(\"TRAIN: loss: {}, acc1: {}, acc5:{}\".format(loss, acc1, acc5))\n",
- "\n",
- " reward = []\n",
- " for data in test_reader():\n",
- " batch_reward = exe.run(test_program, feed=test_feeder.feed(data), fetch_list = outputs)\n",
- " reward_avg = np.mean(np.array(batch_reward), axis=1)\n",
- " reward.append(reward_avg)\n",
- " print(\"TEST: loss: {}, acc1: {}, acc5:{}\".format(batch_reward[0], batch_reward[1], batch_reward[2]))\n",
- " finally_reward = np.mean(np.array(reward), axis=0)\n",
- " print(\"FINAL TEST: avg_cost: {}, acc1: {}, acc5: {}\".format(finally_reward[0], finally_reward[1], finally_reward[2]))\n",
- "\n",
- " sanas.reward(float(finally_reward[1]))"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "language": "python",
- "name": "python2"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 2
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.12"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py
index 245f90480fa438246f4ee7302e240fea01efd3b6..679ecca46d3974f8eefa99403659f4da4ae72457 100644
--- a/paddleslim/nas/sa_nas.py
+++ b/paddleslim/nas/sa_nas.py
@@ -34,6 +34,66 @@ _logger = get_logger(__name__, level=logging.INFO)
class SANAS(object):
+ """
+ SANAS(Simulated Annealing Neural Architecture Search) is a neural architecture search algorithm
+ based on simulated annealing, used in discrete search task generally.
+
+ Args:
+ configs(list): A list of search space configuration with format [(key, {input_size,
+ output_size, block_num, block_mask})]. `key` is the name of search space
+ with data type str. `input_size` and `output_size` are input size and
+ output size of searched sub-network. `block_num` is the number of blocks
+ in searched network, `block_mask` is a list consists by 0 and 1, 0 means
+ normal block, 1 means reduction block.
+ server_addr(tuple): Server address, including ip and port of server. If ip is None or "", will
+ use host ip if is_server = True. Default: ("", 8881).
+ init_temperature(float): Initial temperature in SANAS. If init_temperature and init_tokens are None,
+ default initial temperature is 10.0, if init_temperature is None and
+ init_tokens is not None, default initial temperature is 1.0. The detail
+ configuration about the init_temperature please reference Note. Default: None.
+ reduce_rate(float): Reduce rate in SANAS. The detail configuration about the reduce_rate please
+ reference Note. Default: 0.85.
+ search_steps(int): The steps of searching. Default: 300.
+ init_tokens(list|None): Initial token. If init_tokens is None, SANAS will random generate initial
+ tokens. Default: None.
+ save_checkpoint(string|None): The directory of checkpoint to save, if set to None, not save checkpoint.
+ Default: 'nas_checkpoint'.
+ load_checkpoint(string|None): The directory of checkpoint to load, if set to None, not load checkpoint.
+ Default: None.
+ is_server(bool): Whether current host is controller server. Default: True.
+
+ .. note::
+ - Why need to set initial temperature and reduce rate:
+
+ - SA algorithm preserve a base token(initial token is the first base token, can be set by
+ yourself or random generate) and base score(initial score is -1), next token will be
+ generated based on base token. During the search, if the score which is obtained by the
+ model corresponding to the token is greater than the score which is saved in SA corresponding to
+ base token, current token saved as base token certainly; if score which is obtained by the model
+ corresponding to the token is less than the score which is saved in SA correspinding to base token,
+ current token saved as base token with a certain probability.
+ - For initial temperature, higher is more unstable, it means that SA has a strong possibility to save
+ current token as base token if current score is smaller than base score saved in SA.
+ - For initial temperature, lower is more stable, it means that SA has a small possibility to save
+ current token as base token if current score is smaller than base score saved in SA.
+ - For reduce rate, higher means SA algorithm has slower convergence.
+ - For reduce rate, lower means SA algorithm has faster convergence.
+
+ - How to set initial temperature and reduce rate:
+
+ - If there is a better initial token, and want to search based on this token, we suggest start search
+ experiment in the steady state of the SA algorithm, initial temperature can be set to a small value,
+ such as 1.0, and reduce rate can be set to a large value, such as 0.85. If you want to start search
+ experiment based on the better token with greedy algorithm, which only saved current token as base
+ token if current score higher than base score saved in SA algorithm, reduce rate can be set to a
+ extremely small value, such as 0.85 ** 10.
+
+ - If initial token is generated randomly, it means initial token is a worse token, we suggest start
+ search experiment in the unstable state of the SA algorithm, explore all random tokens as much as
+ possible, and get a better token. Initial temperature can be set a higher value, such as 1000.0,
+ and reduce rate can be set to a small value.
+ """
+
def __init__(self,
configs,
server_addr=("", 8881),
@@ -44,21 +104,6 @@ class SANAS(object):
save_checkpoint='nas_checkpoint',
load_checkpoint=None,
is_server=True):
- """
- Search a group of ratios used to prune program.
- Args:
- configs(list): A list of search space configuration with format [(key, {input_size, output_size, block_num, block_mask})].
- `key` is the name of search space with data type str. `input_size` and `output_size` are
- input size and output size of searched sub-network. `block_num` is the number of blocks in searched network, `block_mask` is a list consists by 0 and 1, 0 means normal block, 1 means reduction block.
- server_addr(tuple): A tuple of server ip and server port for controller server.
- init_temperature(float|None): The init temperature used in simulated annealing search strategy. Default: None.
- reduce_rate(float): The decay rate used in simulated annealing search strategy. Default: None.
- search_steps(int): The steps of searching. Default: 300.
- init_token(list): Init tokens user can set by yourself. Default: None.
- save_checkpoint(string|None): The directory of checkpoint to save, if set to None, not save checkpoint. Default: 'nas_checkpoint'.
- load_checkpoint(string|None): The directory of checkpoint to load, if set to None, not load checkpoint. Default: None.
- is_server(bool): Whether current host is controller server. Default: True.
- """
if not is_server:
assert server_addr[
0] != "", "You should set the IP and port of server when is_server is False."
@@ -149,9 +194,11 @@ class SANAS(object):
def tokens2arch(self, tokens):
"""
- Convert tokens to network architectures.
+ Convert tokens to model architectures.
+ Args
+ tokens: A list of token. The length and range based on search space.:
Returns:
- list: A list of functions that define networks.
+ list: A model architecture instance according to tokens.
"""
return self._search_space.token2arch(tokens)
@@ -166,9 +213,9 @@ class SANAS(object):
def next_archs(self):
"""
- Get next network architectures.
+ Get next model architectures.
Returns:
- list: A list of functions that define networks.
+ list: A list of instance of model architecture.
"""
self._current_tokens = self._controller_client.next_tokens()
_logger.info("current tokens: {}".format(self._current_tokens))
@@ -179,7 +226,7 @@ class SANAS(object):
"""
Return reward of current searched network.
Args:
- score(float): The score of current searched network.
+ score(float): The score of current searched network, bigger is better.
Returns:
bool: True means updating successfully while false means failure.
"""