Unverified Commit 02144bca authored by W whs, committed by GitHub

Add one-shot NAS API and mnasnet based search space. (#17)

Parent 1664a758
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import ast
import numpy as np
from PIL import Image
import os
import paddle
import paddle.fluid as fluid
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.dygraph.base import to_variable
from paddleslim.nas.one_shot import SuperMnasnet
from paddleslim.nas.one_shot import OneShotSearch
def parse_args():
parser = argparse.ArgumentParser("Training for Mnist.")
parser.add_argument(
"--use_data_parallel",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to use data parallel mode to train the model."
)
parser.add_argument("-e", "--epoch", default=5, type=int, help="set epoch")
parser.add_argument("--ce", action="store_true", help="run ce")
args = parser.parse_args()
return args
class SimpleImgConv(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
conv_stride=1,
conv_padding=0,
conv_dilation=1,
conv_groups=1,
act=None,
use_cudnn=False,
param_attr=None,
bias_attr=None):
super(SimpleImgConv, self).__init__()
self._conv2d = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=conv_stride,
padding=conv_padding,
dilation=conv_dilation,
groups=conv_groups,
            param_attr=param_attr,
            bias_attr=bias_attr,
act=act,
use_cudnn=use_cudnn)
def forward(self, inputs):
x = self._conv2d(inputs)
return x
class MNIST(fluid.dygraph.Layer):
def __init__(self):
super(MNIST, self).__init__()
self._simple_img_conv_pool_1 = SimpleImgConv(1, 20, 2, act="relu")
self.arch = SuperMnasnet(
name_scope="super_net", input_channels=20, out_channels=20)
self._simple_img_conv_pool_2 = SimpleImgConv(20, 50, 2, act="relu")
self.pool_2_shape = 50 * 13 * 13
SIZE = 10
scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5
self._fc = Linear(
self.pool_2_shape,
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
act="softmax")
def forward(self, inputs, label=None, tokens=None):
x = self._simple_img_conv_pool_1(inputs)
        x = self.arch(x, tokens=tokens)  # searchable SuperMnasnet block; tokens selects the sub-network
x = self._simple_img_conv_pool_2(x)
x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape])
x = self._fc(x)
if label is not None:
acc = fluid.layers.accuracy(input=x, label=label)
return x, acc
else:
return x
def test_mnist(model, tokens=None):
acc_set = []
avg_loss_set = []
batch_size = 64
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size, drop_last=True)
for batch_id, data in enumerate(test_reader()):
dy_x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(batch_size, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
label.stop_gradient = True
prediction, acc = model.forward(img, label, tokens=tokens)
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
acc_set.append(float(acc.numpy()))
avg_loss_set.append(float(avg_loss.numpy()))
if batch_id % 100 == 0:
print("Test - batch_id: {}".format(batch_id))
# get test acc and loss
acc_val_mean = np.array(acc_set).mean()
avg_loss_val_mean = np.array(avg_loss_set).mean()
return acc_val_mean
def train_mnist(args, model, tokens=None):
epoch_num = args.epoch
BATCH_SIZE = 64
adam = AdamOptimizer(
learning_rate=0.001, parameter_list=model.parameters())
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
if args.use_data_parallel:
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
dy_x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
label.stop_gradient = True
cost, acc = model.forward(img, label, tokens=tokens)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
if args.use_data_parallel:
avg_loss = model.scale_loss(avg_loss)
avg_loss.backward()
model.apply_collective_grads()
else:
avg_loss.backward()
adam.minimize(avg_loss)
            model.clear_gradients()
if batch_id % 1 == 0:
print("Loss at epoch {} step {}: {:}".format(epoch, batch_id,
avg_loss.numpy()))
model.eval()
test_acc = test_mnist(model, tokens=tokens)
model.train()
print("Loss at epoch {} , acc is: {}".format(epoch, test_acc))
save_parameters = (not args.use_data_parallel) or (
args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters:
fluid.save_dygraph(model.state_dict(), "save_temp")
print("checkpoint saved")
if __name__ == '__main__':
args = parse_args()
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
model = MNIST()
        # step 1: training super net
        train_mnist(args, model)
        # step 2: search
        best_tokens = OneShotSearch(model, test_mnist)
        # step 3: final training
        train_mnist(args, model, best_tokens)
## OneShotSearch
paddleslim.nas.one_shot.OneShotSearch(model, eval_func, strategy='sa', search_steps=100)[code]()
: Search for the best sub-network within a super network.
**Parameters:**
- **model(fluid.dygraph.Layer):** A dygraph module built by adding some modules before and/or after an `OneShotSuperNet`. Since `OneShotSuperNet` is a super network, `model` is a super network as well; in other words, at least one sub-module of `model` is an instance of `OneShotSuperNet`. This method searches for the best sub-network within the super network `model`. The super network `model` must be trained beforehand; see [OneShotSuperNet]() for details.
- **eval_func:** A callback function used to evaluate the performance of a sub-network. It accepts `model` as an argument and calls the `forward` method of `model` to run the evaluation.
- **strategy(str):** The name of the search strategy. Default: 'sa'. Currently only 'sa' is supported.
- **search_steps(int):** The number of search steps. Default: 100.
**Returns:**
- **best_tokens:** The encoding (tokens) of the best sub-network.
**Example:**
See the [one-shot NAS example]() and the sketch below.
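As a quick reference, a minimal sketch of a search call, assuming the trained super network `model` and the `test_mnist` evaluation callback from the example later in this document:
```
from paddleslim.nas.one_shot import OneShotSearch

# `test_mnist(model, tokens)` runs the sub-network selected by `tokens` and
# returns its accuracy, which the search controller uses as the reward.
best_tokens = OneShotSearch(model, test_mnist, strategy='sa', search_steps=100)
```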
## OneShotSuperNet
The base class of super networks used by the `OneShot` search strategy. All super network implementations must inherit from this class; a minimal subclass sketch is given at the end of this section.
paddleslim.nas.one_shot.OneShotSuperNet(name_scope)
: Constructor.
**Parameters:**
- **name_scope(str):** The name scope of the super network.
**Returns:**
- **super_net:** An `OneShotSuperNet` instance.
init_tokens()
: Get the encoding of the initial sub-network of the current super network; mainly used for searching.
**Returns:**
- **tokens(list<int>):** The encoding of a sub-network.
range_table()
: Each sub-network of the super network is represented by a list of integers (its encoding); this method returns the range of valid values at each position of the encoding.
**Returns:**
- **range_table(tuple):** The range of valid values at each position of a sub-network encoding, in the format `(min_values, max_values)`, where `min_values` is a list of integers giving the minimum value allowed at each encoding position and `max_values` giving the maximum value, as illustrated below.
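For instance, for an encoding with three positions where each token may take a value in `[0, 6)`, `range_table()` would return the following (a hypothetical illustration; judging from `_random_tokens`, which samples tokens with `np.random.randint(min_v, max_v)`, the maximum appears to be exclusive):
```
min_values = [0, 0, 0]
max_values = [6, 6, 6]
# token[i] satisfies min_values[i] <= token[i] < max_values[i]
range_table = (min_values, max_values)
```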
_forward_impl(input, tokens)
: The forward computation. Subclasses of `OneShotSuperNet` must implement this method.
**Parameters:**
- **input(Variable):** The input of the super network.
- **tokens(list<int>):** The encoding of the sub-network used in the forward pass. Default: `None`, which selects a random sub-network.
**Returns:**
- **output(Variable):** The output of the forward pass.
forward(self, input, tokens=None)
: Run the forward computation.
**Parameters:**
- **input(Variable):** The input of the super network.
- **tokens(list<int>):** The encoding of the sub-network used in the forward pass. Default: `None`, which selects a random sub-network.
**Returns:**
- **output(Variable):** The output of the forward pass.
_random_tokens()
: Randomly select a sub-network and return its encoding.
**Returns:**
- **tokens(list<int>):** The encoding of a sub-network.
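To make this contract concrete, here is a minimal, hypothetical subclass sketch (not part of PaddleSlim) with two searchable positions, each choosing one of three kernel sizes:
```
import paddle.fluid as fluid
from paddleslim.nas.one_shot import OneShotSuperNet

class TinySuperNet(OneShotSuperNet):
    def __init__(self, name_scope):
        super(TinySuperNet, self).__init__(name_scope)
        self.ops = []
        for i in range(2):  # two searchable positions
            candidates = []
            for k in [3, 5, 7]:  # three kernel-size choices per position
                conv = fluid.dygraph.Conv2D(16, 16, k, padding=(k - 1) // 2)
                self.add_sublayer("pos{}_k{}".format(i, k), conv)
                candidates.append(conv)
            self.ops.append(candidates)

    def init_tokens(self):
        return [0, 0]  # start the search from the smallest kernels

    def range_table(self):
        return ([0, 0], [3, 3])  # each token lies in [0, 3)

    def _forward_impl(self, input, tokens=None):
        x = input
        for pos, token in enumerate(tokens):
            # Only the chosen candidate at each position is executed.
            x = self.ops[pos][token](x)
        return x
```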
## SuperMnasnet
A super network modified from [Mnasnet](https://arxiv.org/abs/1807.11626); this class inherits from `OneShotSuperNet`.
paddleslim.nas.one_shot.SuperMnasnet(name_scope, input_channels=3, out_channels=1280, repeat_times=[6, 6, 6, 6, 6, 6], stride=[1, 1, 1, 1, 2, 1], channels=[16, 24, 40, 80, 96, 192, 320], use_auxhead=False)
: Constructor.
**Parameters:**
- **name_scope(str):** The name scope.
- **input_channels(int):** The number of channels of the super network's input feature map.
- **out_channels(int):** The number of channels of the super network's output feature map.
- **repeat_times(list):** The number of times each kind of `block` is repeated.
- **stride(list):** Blocks of one kind are stacked into a `repeat_block`; `stride` gives the downsampling ratio of each `repeat_block`.
- **channels(list):** `channels[i]` and `channels[i+1]` are the numbers of input and output channels of the i-th `repeat_block`, respectively.
- **use_auxhead(bool):** Whether to produce an auxiliary feature map. If `True`, `SuperMnasnet` returns an auxiliary feature map in addition to the output feature map, as sketched below. Default: False.
**Returns:**
- **instance(SuperMnasnet):** A `SuperMnasnet` instance.
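When `use_auxhead=True`, the forward pass returns both the output feature map and an auxiliary feature map; a minimal sketch (tensor names are placeholders):
```
arch = SuperMnasnet("super_net", input_channels=20, out_channels=20, use_auxhead=True)
out, aux = arch(x, tokens=tokens)  # `aux` can drive an auxiliary loss during super-network training
```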
**Example:**
```
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddleslim.nas.one_shot import SuperMnasnet
class MNIST(fluid.dygraph.Layer):
def __init__(self):
super(MNIST, self).__init__()
self.arch = SuperMnasnet(
name_scope="super_net", input_channels=20, out_channels=20)
self.pool_2_shape = 50 * 13 * 13
SIZE = 10
scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5
self._fc = Linear(
self.pool_2_shape,
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
act="softmax")
def forward(self, inputs, label=None, tokens=None):
x = self.arch(inputs, tokens=tokens)
x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape])
x = self._fc(x)
if label is not None:
acc = fluid.layers.accuracy(input=x, label=label)
return x, acc
else:
return x
```
# One Shot NAS Example
> This example requires Paddle 1.7.0 or the Paddle develop branch.
This example uses the MNIST dataset to show how to search for a classification network with PaddleSlim's OneShotNAS API. OneShotNAS only supports dynamic graph mode, so the example is written entirely in Paddle dygraph mode.
## Key code walkthrough
The one-shot architecture search strategy consists of the following steps:
1. Define the super network
2. Train the super network
3. Search for a sub-network based on the super network
4. Train the best sub-network
The key code of each step is introduced below in order.
### Defining the super network
Following the dygraph tutorial, define a classification module with four sub-modules: `_simple_img_conv_pool_1`, `_simple_img_conv_pool_2`, `super_net` and `fc`, where `super_net` is an instance of `SuperMnasnet`.
In the forward pass, the input image goes through `_simple_img_conv_pool_1`, `super_net`, `_simple_img_conv_pool_2` and `fc` in sequence.
The code is shown below:
```
class MNIST(fluid.dygraph.Layer):
def __init__(self):
super(MNIST, self).__init__()
self._simple_img_conv_pool_1 = SimpleImgConv(1, 20, 2, act="relu")
self.arch = SuperMnasnet(
name_scope="super_net", input_channels=20, out_channels=20)
self._simple_img_conv_pool_2 = SimpleImgConv(20, 50, 2, act="relu")
self.pool_2_shape = 50 * 13 * 13
SIZE = 10
scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5
self._fc = Linear(
self.pool_2_shape,
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
act="softmax")
def forward(self, inputs, label=None, tokens=None):
x = self._simple_img_conv_pool_1(inputs)
        x = self.arch(x, tokens=tokens)  # searchable SuperMnasnet block; tokens selects the sub-network
x = self._simple_img_conv_pool_2(x)
x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape])
x = self._fc(x)
if label is not None:
acc = fluid.layers.accuracy(input=x, label=label)
return x, acc
else:
return x
```
The forward method of the dygraph module MNIST accepts a `tokens` argument that specifies which sub-network to use in the forward pass; if `tokens` is None, a random sub-network is selected. Both calling modes are sketched below.
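A minimal sketch of the two calling modes (`best_tokens` stands for an encoding produced by the search and is a placeholder here):
```
# Super-network training: tokens is None, so a random sub-network runs on each call.
prediction, acc = model(img, label)
# Search / final training: a fixed sub-network selected by its encoding.
prediction, acc = model(img, label, tokens=best_tokens)
```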
### Training the super network
The training logic is defined in the `train_mnist` function. Setting the `tokens` argument to None trains the super network, i.e., a randomly selected sub-network is trained on each batch.
The code is shown below:
```
with fluid.dygraph.guard(place):
model = MNIST()
train_mnist(args, model)
```
### Searching for the best sub-network
Search for the best sub-network with PaddleSlim's `OneShotSearch` API, passing in the defined and trained super-network instance `model` and the callback function `test_mnist` used to evaluate sub-networks.
The code is as follows:
```
best_tokens = OneShotSearch(model, test_mnist)
```
### Training the best sub-network
After obtaining the encoding `best_tokens` of the best sub-network, call the previously defined `train_mnist` function to train the sub-network. The code is as follows:
```
train_mnist(args, model, best_tokens)
```
## Running the example
Run the example with:
```
python train.py
```
Run `python train.py --help` to see more configurable options.
## FAQ
@@ -9,6 +9,7 @@ nav:
 - 量化训练: tutorials/quant_aware_demo.md
 - Embedding量化: tutorials/quant_embedding_demo.md
 - SA搜索: tutorials/nas_demo.md
+- One-shot搜索: tutorials/one_shot_nas_demo.md
 - 搜索空间: search_space.md
 - 知识蒸馏: tutorials/distillation_demo.md
 - API:
@@ -17,6 +18,8 @@ nav:
 - 模型分析: api/analysis_api.md
 - 知识蒸馏: api/single_distiller_api.md
 - SA搜索: api/nas_api.md
+- One-shot搜索: api/one_shot_api.md
+- 搜索空间: search_space.md
 - 硬件延时评估表: table_latency.md
 - 算法原理: algo/algo.md
...
@@ -11,8 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import absolute_import
+from ..nas import search_space
 from .search_space import *
-from .sa_nas import SANAS
-__all__ = ['SANAS']
+from ..nas import sa_nas
+from .sa_nas import *
+__all__ = []
+__all__ += sa_nas.__all__
+__all__ += search_space.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from ..one_shot import one_shot_nas
from .one_shot_nas import *
from ..one_shot import super_mnasnet
from .super_mnasnet import *
__all__ = []
__all__ += one_shot_nas.__all__
__all__ += super_mnasnet.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
from ...common import SAController
__all__ = ['OneShotSuperNet', 'OneShotSearch']
def OneShotSearch(model, eval_func, strategy='sa', search_steps=100):
"""
Search a best tokens which represents a sub-network.
Archs:
model(fluid.dygraph.Layer): A dynamic graph module whose sub-modules should contain
one instance of `OneShotSuperNet` at least.
eval_func(function): A callback function which accept model and tokens as arguments.
strategy(str): The name of strategy used to search. Default: 'sa'.
search_steps(int): The total steps for searching.
Returns:
tokens(list): The best tokens searched.
"""
super_net = None
for layer in model.sublayers(include_sublayers=False):
print("layer: {}".format(layer))
if isinstance(layer, OneShotSuperNet):
super_net = layer
break
assert super_net is not None
    controller = None
    if strategy == "sa":
        controller = SAController(
            range_table=super_net.range_table(),
            init_tokens=super_net.init_tokens())
    assert controller is not None, "Unsupported searching strategy."
    for i in range(search_steps):
        tokens = controller.next_tokens()
        reward = eval_func(model, tokens)
        controller.update(tokens, reward, i)
    return controller.best_tokens()
class OneShotSuperNet(fluid.dygraph.Layer):
"""
The base class of super net used in one-shot searching strategy.
A super net is a dygraph layer.
Args:
name_scope(str): The name scope of super net.
"""
def __init__(self, name_scope):
super(OneShotSuperNet, self).__init__(name_scope)
def init_tokens(self):
"""Get init tokens in search space.
Return:
            tokens(list): The initial tokens, which is a list of integers.
"""
raise NotImplementedError('Abstract method.')
def range_table(self):
"""Get range table of current search space.
Return:
range_table(tuple): The maximum value and minimum value in each position of tokens
with format `(min_values, max_values)`. The `min_values` is
a list of integers indicating the minimum values while `max_values`
indicating the maximum values.
"""
raise NotImplementedError('Abstract method.')
def _forward_impl(self, *inputs, **kwargs):
"""
Defines the computation performed at every call.
Should be overridden by all subclasses.
Args:
inputs(tuple): unpacked tuple arguments
kwargs(dict): unpacked dict arguments
"""
raise NotImplementedError('Abstract method.')
def forward(self, input, tokens=None):
"""
Defines the computation performed at every call.
Args:
input(variable): The input of super net.
tokens(list): The tokens used to generate a sub-network.
None means computing in super net training mode.
Otherwise, it will execute the sub-network generated by tokens.
The `tokens` should be set in searching stage and final training stage.
Default: None.
Returns:
            output(variable): The output of the super net.
"""
        if tokens is None:
tokens = self._random_tokens()
return self._forward_impl(input, tokens=tokens)
def _random_tokens(self):
tokens = []
for min_v, max_v in zip(self.range_table()[0], self.range_table()[1]):
tokens.append(np.random.randint(min_v, max_v))
return tokens
import paddle
from paddle import fluid
from paddle.fluid.layer_helper import LayerHelper
import numpy as np
from .one_shot_nas import OneShotSuperNet
__all__ = ['SuperMnasnet']
class DConvBlock(fluid.dygraph.Layer):
def __init__(self,
name_scope,
in_channels,
channels,
expansion,
stride,
kernel_size=3,
padding=1):
super(DConvBlock, self).__init__(name_scope)
self.expansion = expansion
self.in_channels = in_channels
self.channels = channels
self.stride = stride
self.flops = 0
self.flops_calculated = False
self.expand = fluid.dygraph.Conv2D(
in_channels,
num_filters=in_channels * expansion,
filter_size=1,
stride=1,
padding=0,
act=None,
bias_attr=False)
self.expand_bn = fluid.dygraph.BatchNorm(
num_channels=in_channels * expansion, act='relu6')
self.dconv = fluid.dygraph.Conv2D(
in_channels * expansion,
num_filters=in_channels * expansion,
filter_size=kernel_size,
stride=stride,
padding=padding,
act=None,
bias_attr=False,
groups=in_channels * expansion,
use_cudnn=False)
self.dconv_bn = fluid.dygraph.BatchNorm(
num_channels=in_channels * expansion, act='relu6')
self.project = fluid.dygraph.Conv2D(
in_channels * expansion,
num_filters=channels,
filter_size=1,
stride=1,
padding=0,
act=None,
bias_attr=False)
self.project_bn = fluid.dygraph.BatchNorm(
num_channels=channels, act=None)
self.shortcut = fluid.dygraph.Conv2D(
in_channels,
num_filters=channels,
filter_size=1,
stride=1,
padding=0,
act=None,
bias_attr=False)
self.shortcut_bn = fluid.dygraph.BatchNorm(
num_channels=channels, act=None)
def get_flops(self, input, output, op):
if not self.flops_calculated:
flops = input.shape[1] * output.shape[1] * (
op._filter_size**2) * output.shape[2] * output.shape[3]
if op._groups:
flops /= op._groups
self.flops += flops
def forward(self, inputs):
expand_x = self.expand_bn(self.expand(inputs))
self.get_flops(inputs, expand_x, self.expand)
dconv_x = self.dconv_bn(self.dconv(expand_x))
self.get_flops(expand_x, dconv_x, self.dconv)
proj_x = self.project_bn(self.project(dconv_x))
self.get_flops(dconv_x, proj_x, self.project)
if self.in_channels != self.channels and self.stride == 1:
shortcut = self.shortcut_bn(self.shortcut(inputs))
self.get_flops(inputs, shortcut, self.shortcut)
elif self.stride == 1:
shortcut = inputs
self.flops_calculated = True
if self.stride == 1:
out = fluid.layers.elementwise_add(x=proj_x, y=shortcut)
return out
return proj_x
class SearchBlock(fluid.dygraph.Layer):
def __init__(self,
name_scope,
in_channels,
channels,
stride,
kernel_size=3,
padding=1):
super(SearchBlock, self).__init__(name_scope)
self._stride = stride
self.block_list = []
self.flops = [0 for i in range(10)]
self.flops_calculated = [False if i < 6 else True for i in range(10)]
kernels = [3, 5, 7]
expansions = [3, 6]
for k in kernels:
for e in expansions:
self.block_list.append(
DConvBlock(self.full_name(), in_channels, channels, e,
stride, k, (k - 1) // 2))
self.add_sublayer("expansion_{}_kernel_{}".format(e, k),
self.block_list[-1])
def forward(self, inputs, arch):
if arch >= 6:
return inputs
out = self.block_list[arch](inputs)
if not self.flops_calculated[arch]:
self.flops[arch] = self.block_list[arch].flops
self.flops_calculated[arch] = True
return out
class AuxiliaryHead(fluid.dygraph.Layer):
def __init__(self, name_scope, num_classes):
super(AuxiliaryHead, self).__init__(name_scope)
self.pool1 = fluid.dygraph.Pool2D(
5, 'avg', pool_stride=3, pool_padding=0)
self.conv1 = fluid.dygraph.Conv2D(128, 1, bias_attr=False)
self.bn1 = fluid.dygraph.BatchNorm(128, act='relu6')
self.conv2 = fluid.dygraph.Conv2D(768, 2, bias_attr=False)
self.bn2 = fluid.dygraph.BatchNorm(768, act='relu6')
self.classifier = fluid.dygraph.FC(num_classes, act='softmax')
self.layer_helper = LayerHelper(self.full_name(), act='relu6')
def forward(self, inputs): #pylint: disable=arguments-differ
inputs = self.layer_helper.append_activation(inputs)
inputs = self.pool1(inputs)
inputs = self.conv1(inputs)
inputs = self.bn1(inputs)
inputs = self.conv2(inputs)
inputs = self.bn2(inputs)
inputs = self.classifier(inputs)
return inputs
class SuperMnasnet(OneShotSuperNet):
def __init__(self,
name_scope,
input_channels=3,
out_channels=1280,
repeat_times=[6, 6, 6, 6, 6, 6],
stride=[1, 1, 1, 1, 2, 1],
channels=[16, 24, 40, 80, 96, 192, 320],
use_auxhead=False):
super(SuperMnasnet, self).__init__(name_scope)
self.flops = 0
self.repeat_times = repeat_times
self.flops_calculated = False
self.last_tokens = None
self._conv = fluid.dygraph.Conv2D(
input_channels, 32, 3, 1, 1, act=None, bias_attr=False)
self._bn = fluid.dygraph.BatchNorm(32, act='relu6')
self._sep_conv = fluid.dygraph.Conv2D(
32,
32,
3,
1,
1,
groups=32,
act=None,
use_cudnn=False,
bias_attr=False)
self._sep_conv_bn = fluid.dygraph.BatchNorm(32, act='relu6')
self._sep_project = fluid.dygraph.Conv2D(
32, 16, 1, 1, 0, act=None, bias_attr=False)
self._sep_project_bn = fluid.dygraph.BatchNorm(16, act='relu6')
self._final_conv = fluid.dygraph.Conv2D(
320, out_channels, 1, 1, 0, act=None, bias_attr=False)
self._final_bn = fluid.dygraph.BatchNorm(out_channels, act='relu6')
self.stride = stride
self.block_list = []
self.use_auxhead = use_auxhead
for _iter, _stride in enumerate(self.stride):
repeat_block = []
for _ind in range(self.repeat_times[_iter]):
if _ind == 0:
block = SearchBlock(self.full_name(), channels[_iter],
channels[_iter + 1], _stride)
else:
block = SearchBlock(self.full_name(), channels[_iter + 1],
channels[_iter + 1], 1)
self.add_sublayer("block_{}_{}".format(_iter, _ind), block)
repeat_block.append(block)
self.block_list.append(repeat_block)
if self.use_auxhead:
self.auxhead = AuxiliaryHead(self.full_name(), 10)
def init_tokens(self):
return [
3, 3, 6, 6, 6, 6, 3, 3, 3, 6, 6, 6, 3, 3, 3, 3, 6, 6, 3, 3, 3, 6,
6, 6, 3, 3, 3, 6, 6, 6, 3, 6, 6, 6, 6, 6
]
def range_table(self):
max_v = [
6, 6, 10, 10, 10, 10, 6, 6, 6, 10, 10, 10, 6, 6, 6, 6, 10, 10, 6,
6, 6, 10, 10, 10, 6, 6, 6, 10, 10, 10, 6, 10, 10, 10, 10, 10
]
return (len(max_v) * [0], max_v)
def get_flops(self, input, output, op):
if not self.flops_calculated:
flops = input.shape[1] * output.shape[1] * (
op._filter_size**2) * output.shape[2] * output.shape[3]
if op._groups:
flops /= op._groups
self.flops += flops
def _forward_impl(self, inputs, tokens=None):
if isinstance(tokens, np.ndarray) and not (tokens == self.last_tokens).all()\
or not isinstance(tokens, np.ndarray) and not tokens == self.last_tokens:
self.flops_calculated = False
self.flops = 0
self.last_tokens = tokens
x = self._bn(self._conv(inputs))
self.get_flops(inputs, x, self._conv)
sep_x = self._sep_conv_bn(self._sep_conv(x))
self.get_flops(x, sep_x, self._sep_conv)
proj_x = self._sep_project_bn(self._sep_project(sep_x))
self.get_flops(sep_x, proj_x, self._sep_project)
x = proj_x
for ind in range(len(self.block_list)):
for b_ind, block in enumerate(self.block_list[ind]):
x = fluid.layers.dropout(block(x, tokens[ind * 6 + b_ind]), 0.)
if not self.flops_calculated:
self.flops += block.flops[tokens[ind * 6 + b_ind]]
if ind == len(self.block_list) * 2 // 3 - 1 and self.use_auxhead:
fc_aux = self.auxhead(x)
final_x = self._final_bn(self._final_conv(x))
self.get_flops(x, final_x, self._final_conv)
# x = self.global_pooling(final_x)
self.flops_calculated = True
if self.use_auxhead:
return final_x, fc_aux
return final_x
@@ -21,7 +21,6 @@ from .inception_block import InceptionABlockSpace, InceptionCBlockSpace
 from .search_space_registry import SEARCHSPACE
 from .search_space_factory import SearchSpaceFactory
 from .search_space_base import SearchSpaceBase
 __all__ = [
     'MobileNetV1Space', 'MobileNetV2Space', 'ResNetSpace',
     'MobileNetV1BlockSpace', 'MobileNetV2BlockSpace', 'ResNetBlockSpace',
...
@@ -19,6 +19,7 @@ __all__ = ['SearchSpaceBase']
 _logger = get_logger(__name__, level=logging.INFO)
 class SearchSpaceBase(object):
     """Controller for Neural Architecture Search.
     """
@@ -56,3 +57,7 @@ class SearchSpaceBase(object):
         model arch
         """
         raise NotImplementedError('Abstract method.')
+
+    def super_net(self):
+        """This function is just used in one shot NAS strategy. Return a super graph."""
+        raise NotImplementedError('Abstract method.')