Unverified commit 73019e56 authored by ceci3 and committed by GitHub

fix for python3.7 (#45)

* update fix

* fix init token

* fix bug

* update

* update doc
Parent 501ab9d4
......@@ -2,69 +2,25 @@
This example shows how to use the network architecture search (NAS) API to find a smaller or more accurate model. This document only covers how to use SANAS in PaddleSlim and how to obtain a model architecture with it; for the complete example code please refer to sa_nas_mobilenetv2.py or block_sa_nas_mobilenetv2.py.
## API introduction
Please refer to.
### 1. Configure the search space
For the detailed search space configuration, please refer to the <a href='../../../paddleslim/nas/nas_api.md'>NAS API documentation</a>
```
config = [('MobileNetV2Space')]
```
### 2. Initialize a SANAS instance with the search space
```
from paddleslim.nas import SANAS
sa_nas = SANAS(
config,
server_addr=("", 8881),
init_temperature=10.24,
reduce_rate=0.85,
search_steps=300,
is_server=True)
```
## Data preparation
This example uses the cifar10 dataset by default; cifar10 is downloaded automatically by the paddle API that is called, so no extra preparation is needed.
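As an illustrative sketch only (assuming the paddle 1.x `paddle.dataset.cifar` reader API, which this document does not show), the data readers could be built like this:
```python
import paddle

# cifar10 is downloaded automatically the first time the reader is used
train_reader = paddle.batch(
    paddle.reader.shuffle(paddle.dataset.cifar.train10(), buf_size=1024),
    batch_size=256)
test_reader = paddle.batch(paddle.dataset.cifar.test10(), batch_size=256)
```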
### 3. Get the current network architectures from the instantiated NAS
```
archs = sa_nas.next_archs()
```
### 4. Build the training and test programs from the obtained architectures and the input
```
import paddle.fluid as fluid
train_program = fluid.Program()
test_program = fluid.Program()
startup_program = fluid.Program()
## API introduction
Please refer to the <a href='../../docs/docs/api/nas_api.md'>NAS API documentation</a>
with fluid.program_guard(train_program, startup_program):
data = fluid.data(name='data', shape=[None, 3, 32, 32], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
for arch in archs:
data = arch(data)
output = fluid.layers.fc(data, 10)
softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
avg_cost = fluid.layers.mean(cost)
acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)
This example searches the MobileNetV2 search space for a model with fewer FLOPs.
## 1 Search space configuration
The default search space is `MobileNetV2`; for the detailed search space configuration please refer to the <a href='../../docs/docs/search_space.md'>search space documentation</a>
test_program = train_program.clone(for_test=True)
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(avg_cost)
```
## 2 Start training
### 5. Add constraints based on the constructed training program
### 2.1 Launch the experiment whose search space is built from the initial MobileNetV2 architecture
```shell
CUDA_VISIBLE_DEVICES=0 python sa_nas_mobilenetv2.py
```
from paddleslim.analysis import flops
if flops(train_program) > 321208544:
continue
```
### 6. Send back the score
```
sa_nas.reward(score)
### 2.2 Launch the experiment whose search space is built from MobileNetV2 blocks
```shell
CUDA_VISIBLE_DEVICES=0 python block_sa_nas_mobilenetv2.py
```
......@@ -32,7 +32,7 @@ paddleslim.nas.SANAS(configs, server_addr=("", 8881), init_temperature=None, red
```python
from paddleslim.nas import SANAS
config = [('MobileNetV2Space')]
sanas = SANAS(config=config)
sanas = SANAS(configs=config)
```
!!! note "Note"
......@@ -48,26 +48,6 @@ sanas = SANAS(config=config)
- If the initial tokens are generated randomly, they are likely to be rather poor tokens; the SA algorithm should then search in a relatively unstable phase and explore the candidate tokens as randomly as possible in order to find a good one. In this case set the initial temperature higher, for example 1000, and set the annealing (reduce) rate relatively smaller.
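A minimal sketch of such a configuration (the concrete values below are illustrative, not tuned recommendations):
```python
from paddleslim.nas import SANAS

config = [('MobileNetV2Space')]
# with randomly generated initial tokens, start hot and anneal slowly so the
# early search explores the token space as broadly as possible
sanas = SANAS(
    configs=config,
    init_temperature=1000,
    reduce_rate=0.85,
    search_steps=300,
    is_server=True)
```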
paddleslim.nas.SANAS.tokens2arch(tokens)
: Get the actual model architecture from a set of tokens; this is usually used to convert the best tokens found by the search into a model architecture for the final training. tokens is a list; it is mapped onto the search space and converted into the corresponding network architecture, and one set of tokens corresponds to exactly one network architecture.
**Parameters:**
- **tokens(list):** - a set of tokens.
**Returns:**
A model architecture instance built from the given tokens.
**Example:**
```python
import paddle.fluid as fluid
input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
archs = sanas.token2arch(tokens)
for arch in archs:
output = arch(input)
input = output
```
paddleslim.nas.SANAS.next_archs()
: Get the next group of model architectures.
......@@ -84,7 +64,6 @@ for arch in archs:
input = output
```
paddleslim.nas.SANAS.reward(score)
: Send back the score of the current model architecture.
......@@ -95,6 +74,27 @@ paddleslim.nas.SANAS.reward(score)
**Returns:**
Whether the model architecture update succeeded: returns `True` on success and `False` on failure.
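A minimal usage sketch (`top1_acc` is a placeholder for whatever evaluation metric was computed for the current architectures):
```python
# send back the score of the architectures obtained from next_archs();
# reward() returns True if the controller accepted the update
success = sanas.reward(float(top1_acc))
```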
paddleslim.nas.SANAS.tokens2arch(tokens)
: Get the actual model architecture from a set of tokens; this is usually used to convert the best tokens found by the search into a model architecture for the final training. tokens is a list; it is mapped onto the search space and converted into the corresponding network architecture, and one set of tokens corresponds to exactly one network architecture.
**Parameters:**
- **tokens(list):** - a set of tokens.
**Returns:**
A model architecture instance built from the given tokens.
**Example:**
```python
import paddle.fluid as fluid
input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
archs = sanas.token2arch(tokens)
for arch in archs:
output = arch(input)
input = output
```
paddleslim.nas.SANAS.current_info()
: Return the current tokens together with the best tokens and best reward found during the search.
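For illustration, a minimal sketch of reading the returned dictionary (the keys follow the server-side implementation added in this commit):
```python
info = sanas.current_info()
print(info['current_tokens'], info['best_tokens'], info['best_reward'])
```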
......
......@@ -71,3 +71,13 @@ class ControllerClient(object):
tokens = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")]
return tokens
def request_current_info(self):
"""
Request for current information.
"""
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
socket_client.send("current_info".encode())
current_info = socket_client.recv(1024).decode()
return eval(current_info)
......@@ -90,10 +90,18 @@ class ControllerServer(object):
(self._search_steps))) and not self._closed:
conn, addr = self._socket_server.accept()
message = conn.recv(1024).decode()
_logger.debug(message)
if message.strip("\n") == "next_tokens":
tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode())
elif message.strip("\n") == "current_info":
current_info = dict()
current_info['best_tokens'] = self._controller.best_tokens
current_info['best_reward'] = self._controller.max_reward
current_info[
'current_tokens'] = self._controller.current_tokens
conn.send(str(current_info).encode())
else:
_logger.debug("recv message from {}: [{}]".format(addr,
message))
......
......@@ -81,7 +81,7 @@ class SAController(EvolutionaryController):
self._iter = iters
self._checkpoints = checkpoints
self._searched = searched if searched != None else dict()
self._current_token = init_tokens
self._current_tokens = init_tokens
def __getstate__(self):
d = {}
......
......@@ -162,10 +162,7 @@ class SANAS(object):
Returns:
dict<name, value>: a dictionary including the best tokens, the best reward and the current tokens.
"""
current_dict = dict()
current_dict['best_tokens'] = self._controller.best_tokens
current_dict['best_reward'] = self._controller.max_reward
current_dict['current_tokens'] = self._controller.current_tokens
current_dict = self._controller_client.request_current_info()
return current_dict
def next_archs(self):
......
......@@ -19,7 +19,7 @@ from paddle.fluid.param_attr import ParamAttr
def conv_bn_layer(input,
filter_size,
num_filters,
stride,
stride=1,
padding='SAME',
num_groups=1,
act=None,
......@@ -52,9 +52,9 @@ def conv_bn_layer(input,
bias_attr=False)
bn_name = name + '_bn'
return fluid.layers.batch_norm(
input=conv,
act = act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(name=bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(name=bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
......@@ -58,7 +58,7 @@ class InceptionABlockSpace(SearchSpaceBase):
"""
The initial token.
"""
return get_random_tokens(self.range_table)
return get_random_tokens(self.range_table())
def range_table(self):
"""
......@@ -175,7 +175,7 @@ class InceptionABlockSpace(SearchSpaceBase):
input = self._inceptionA(
input,
A_tokens=filter_nums,
filter_size=filter_size,
filter_size=int(filter_size),
stride=stride,
pool_type=pool_type,
name='inceptionA_{}'.format(i + 1))
......@@ -287,7 +287,7 @@ class InceptionCBlockSpace(SearchSpaceBase):
"""
The initial token.
"""
return get_random_tokens(self.range_table)
return get_random_tokens(self.range_table())
def range_table(self):
"""
......@@ -408,13 +408,13 @@ class InceptionCBlockSpace(SearchSpaceBase):
pool_type = 'avg' if layer_setting[11] == 0 else 'max'
if stride == 2:
layer_count += 1
if check_points((layer_count - 1) in return_block):
if check_points((layer_count - 1), return_block):
mid_layer[layer_count - 1] = input
input = self._inceptionC(
input,
C_tokens=filter_nums,
filter_size=filter_size,
filter_size=int(filter_size),
stride=stride,
pool_type=pool_type,
name='inceptionC_{}'.format(i + 1))
......
......@@ -60,7 +60,7 @@ class MobileNetV2BlockSpace(SearchSpaceBase):
self.scale = scale
def init_tokens(self):
return get_random_tokens(self.range_table)
return get_random_tokens(self.range_table())
def range_table(self):
range_table_base = []
......@@ -153,7 +153,7 @@ class MobileNetV2BlockSpace(SearchSpaceBase):
c=int(c * self.scale),
n=n,
s=s,
k=k,
k=int(k),
name='mobilenetv2_' + str(i + 1))
in_c = int(c * self.scale)
......@@ -289,9 +289,11 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
scale=1.0):
super(MobileNetV1BlockSpace, self).__init__(input_size, output_size,
block_num, block_mask)
# use input_size and output_size to compute self.downsample_num
self.downsample_num = compute_downsample_num(self.input_size,
self.output_size)
if self.block_mask == None:
# use input_size and output_size to compute self.downsample_num
self.downsample_num = compute_downsample_num(self.input_size,
self.output_size)
if self.block_num != None:
assert self.downsample_num <= self.block_num, 'downsample number must be LESS THAN OR EQUAL TO block_num, but NOW: downsample number is {}, block_num is {}'.format(
self.downsample_num, self.block_num)
......@@ -305,7 +307,7 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
self.scale = scale
def init_tokens(self):
return get_random_tokens(self.range_table)
return get_random_tokens(self.range_table())
def range_table(self):
range_table_base = []
......@@ -383,7 +385,7 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
num_filters2=filter_num2,
stride=stride,
scale=self.scale,
kernel_size=kernel_size,
kernel_size=int(kernel_size),
name='mobilenetv1_{}'.format(str(i + 1)))
if return_mid_layer:
......
......@@ -191,7 +191,7 @@ class MobileNetV1Space(SearchSpaceBase):
num_groups=filter_num1,
stride=stride,
scale=self.scale,
kernel_size=kernel_size,
kernel_size=int(kernel_size),
name='mobilenetv1_{}'.format(str(i + 1)))
### return_block and end_points means block num
......
......@@ -182,7 +182,7 @@ class MobileNetV2Space(SearchSpaceBase):
c=int(c * self.scale),
n=n,
s=s,
k=k,
k=int(k),
name='mobilenetv2_conv' + str(i))
in_c = int(c * self.scale)
......
......@@ -32,9 +32,10 @@ class ResNetBlockSpace(SearchSpaceBase):
def __init__(self, input_size, output_size, block_num, block_mask=None):
super(ResNetBlockSpace, self).__init__(input_size, output_size,
block_num, block_mask)
# use input_size and output_size to compute self.downsample_num
self.downsample_num = compute_downsample_num(self.input_size,
self.output_size)
if self.block_mask == None:
# use input_size and output_size to compute self.downsample_num
self.downsample_num = compute_downsample_num(self.input_size,
self.output_size)
if self.block_num != None:
assert self.downsample_num <= self.block_num, 'downsample number must be LESS THAN OR EQUAL TO block_num, but NOW: downsample number is {}, block_num is {}'.format(
self.downsample_num, self.block_num)
......@@ -44,7 +45,7 @@ class ResNetBlockSpace(SearchSpaceBase):
self.k_size = np.array([3, 5])
def init_tokens(self):
return get_random_tokens(self.range_table)
return get_random_tokens(self.range_table())
def range_table(self):
range_table_base = []
......@@ -133,7 +134,7 @@ class ResNetBlockSpace(SearchSpaceBase):
num_filters1=filter_num1,
num_filters2=filter_num3,
num_filters3=filter_num3,
kernel_size=k_size,
kernel_size=int(k_size),
repeat1=repeat1,
repeat2=repeat2,
stride=stride,
......