Unverified commit abb30ef2 authored by Bai Yifan, committed by GitHub

Add DARTS series methods (#213)

Parent: 108c1587
# Differentiable Architecture Search (DARTS) Usage Example
This example shows how to run differentiable architecture search with PaddlePaddle. The [DARTS](https://arxiv.org/abs/1806.09055) and [PC-DARTS](https://arxiv.org/abs/1907.05737) methods work out of the box, and the code can also be modified to run other differentiable architecture search algorithms.
## Dependencies
> PaddlePaddle >= 1.7.0, graphviz >= 0.11.1
## Datasets
This example searches architectures on the `CIFAR10` dataset; the searched architectures can be evaluated on either `CIFAR10` or `ImageNet`.
`CIFAR10` is downloaded automatically during search or evaluation. `ImageNet` must be downloaded manually; see this [tutorial](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87).
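For reference, the ImageNet pipeline in reader.py reads `train_list.txt` and `val_list.txt` from the data directory (`dataset/ILSVRC2012` by default in train_imagenet.py), where each line holds a relative image path and an integer label. An illustrative layout (the file names below are examples, not requirements):

```
dataset/ILSVRC2012/
├── train/...                 # training images
├── val/...                   # validation images
├── train_list.txt            # each line: "train/n01440764/xxx.jpeg 0"
└── val_list.txt              # each line: "val/ILSVRC2012_val_00000001.jpeg 65"
```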
## Architecture Search
The search script supports the DARTS first-order and second-order approximations as well as the PC-DARTS method:
``` bash
python search.py                      # DARTS first-order approximation
python search.py --unrolled=True      # DARTS second-order approximation
python search.py --method='PC-DARTS'  # PC-DARTS
```
Figure 1 shows how the architecture changes over the course of the search. Note that the accuracy (Acc) in the figure is not the final accuracy of that architecture; to obtain the best accuracy of a given architecture, run evaluation training on the resulting genotype.
![networks](images/networks.gif)
<p align="center">
Figure 1: Architecture evolution during the search on CIFAR10; the upper half is the reduction cell and the lower half is the normal cell
</p>
The genotypes found by the three search methods have been added to genotypes.py: `DARTS_V1`, `DARTS_V2`, and `PC-DARTS` are the architectures found with the DARTS first-order approximation, the DARTS second-order approximation, and PC-DARTS, respectively.
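For orientation, `Genotype` is a plain namedtuple (see genotypes.py): each entry in `normal`/`reduce` is an `(op_name, input_index)` pair, where inputs 0 and 1 are the outputs of the two preceding cells and indices 2+ refer to earlier intermediate nodes of the current cell; `normal_concat`/`reduce_concat` list the nodes concatenated into the cell output. A minimal illustrative genotype (`toy` below is made up for illustration, not one of the shipped architectures):

```python
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

# Toy genotype with two intermediate nodes per cell, for illustration only.
toy = Genotype(
    normal=[('sep_conv_3x3', 0), ('skip_connect', 1),
            ('sep_conv_3x3', 0), ('dil_conv_3x3', 2)],
    normal_concat=range(2, 4),
    reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1),
            ('skip_connect', 2), ('max_pool_3x3', 0)],
    reduce_concat=range(2, 4))

print(toy.normal[0])  # ('sep_conv_3x3', 0)
```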
## Architecture Evaluation Training
After obtaining a searched Genotype, you can run evaluation training on it to measure its real performance on a specific dataset:
```bash
python train.py --arch='PC-DARTS'           # evaluation training of a searched architecture on CIFAR10
python train_imagenet.py --arch='PC-DARTS'  # evaluation training of a searched architecture on ImageNet
```
Evaluation training results for the searched `DARTS_V1`, `DARTS_V2`, and `PC-DARTS` architectures:

| Architecture                | Dataset  | Accuracy    |
| --------------------------- | -------- | ----------- |
| DARTS_V1                    | CIFAR10  | 97.01%      |
| DARTS (first-order, paper)  | CIFAR10  | 97.00±0.14% |
| DARTS_V2                    | CIFAR10  | 97.26%      |
| DARTS (second-order, paper) | CIFAR10  | 97.24±0.09% |
| DARTS_V2                    | ImageNet | 74.12%      |
| DARTS (second-order, paper) | ImageNet | 73.30%      |
| PC-DARTS                    | CIFAR10  | 97.41%      |
| PC-DARTS (paper)            | CIFAR10  | 97.43±0.07% |
## Custom Datasets and Search Spaces
### Using a Different Dataset
This example searches on CIFAR10 by default. To switch to a custom dataset, only a small change to reader.py is needed:
```python
def train_search(batch_size, train_portion, is_shuffle, args):
    datasets = cifar10_reader(  # replace this reader
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'data_batch', is_shuffle, args)
```
Replace the default `cifar10_reader` with a reader for your own dataset.
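As a rough sketch (assuming your data can be expressed as flat `uint8` image buffers like CIFAR10's), a replacement reader only needs to return a list of `(image, label)` samples in the format `cifar10_reader` produces; `my_dataset_reader` and `load_my_samples` below are hypothetical names:

```python
import random
import numpy as np

def my_dataset_reader(data_dir, is_shuffle, args):
    # Hypothetical custom reader. Each sample must match what preprocess()
    # in reader.py expects: a flat uint8 array of length
    # IMAGE_DEPTH * IMAGE_SIZE * IMAGE_SIZE, plus an integer label.
    datasets = []
    for image, label in load_my_samples(data_dir):  # your own loading logic
        datasets.append((np.asarray(image, dtype=np.uint8).flatten(), label))
    if is_shuffle:
        random.shuffle(datasets)
    return datasets
```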
### Customizing the Search Space
This example implements the DARTS and PC-DARTS methods, both defined in model_search.py.
To customize the search space, modify `class Network` in model_search.py and search over it with paddleslim.nas.DARTSearch; a condensed usage sketch follows below.
After the search finishes, make the corresponding changes in model.py before running evaluation training.
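For reference, here is a condensed sketch of what search.py does with a (possibly customized) `Network`, assuming a single GPU; the keyword values are simply the defaults from search.py, and `args` carries only the fields that reader.train_search() reads:

```python
import argparse

import paddle.fluid as fluid

import reader
from model_search import Network
from paddleslim.nas.darts import DARTSearch

# Only the fields consumed by reader.train_search()/preprocess().
args = argparse.Namespace(
    cutout=False, cutout_length=16, use_multiprocess=False, num_workers=4)
train_reader, valid_reader = reader.train_search(
    batch_size=64, train_portion=0.5, is_shuffle=True, args=args)

with fluid.dygraph.guard(fluid.CUDAPlace(0)):
    # Swap in your customized super-network here.
    model = Network(c_in=16, num_classes=10, layers=8, method='DARTS')
    searcher = DARTSearch(
        model,
        train_reader,
        valid_reader,
        learning_rate=0.025,
        batchsize=64,
        num_imgs=50000,
        arch_learning_rate=3e-4,
        unrolled=False,
        method='DARTS',
        num_epochs=50,
        use_data_parallel=False,
        log_freq=50)
    searcher.train()
```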
## Visualizing Searched Architectures
Use the following command to visualize a searched Genotype:
```bash
python visualize.py PC-DARTS
```
Here `PC-DARTS` refers to a Genotype, which must be defined in genotypes.py beforehand.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
PRIMITIVES = [
'none', 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect', 'sep_conv_3x3',
'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5'
]
NASNet = Genotype(
normal=[
('sep_conv_5x5', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 0),
('sep_conv_3x3', 0),
('avg_pool_3x3', 1),
('skip_connect', 0),
('avg_pool_3x3', 0),
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('skip_connect', 1),
],
normal_concat=[2, 3, 4, 5, 6],
reduce=[
('sep_conv_5x5', 1),
('sep_conv_7x7', 0),
('max_pool_3x3', 1),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('sep_conv_5x5', 0),
('skip_connect', 3),
('avg_pool_3x3', 2),
('sep_conv_3x3', 2),
('max_pool_3x3', 1),
],
reduce_concat=[4, 5, 6], )
AmoebaNet = Genotype(
normal=[
('avg_pool_3x3', 0),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 2),
('sep_conv_3x3', 0),
('avg_pool_3x3', 3),
('sep_conv_3x3', 1),
('skip_connect', 1),
('skip_connect', 0),
('avg_pool_3x3', 1),
],
normal_concat=[4, 5, 6],
reduce=[
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_7x7', 2),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('conv_7x1_1x7', 0),
('sep_conv_3x3', 5),
],
reduce_concat=[3, 4, 6])
DARTS_V1 = Genotype(
normal=[('sep_conv_5x5', 0), ('dil_conv_3x3', 1), ('sep_conv_3x3', 2),
('sep_conv_5x5', 0), ('sep_conv_5x5', 0), ('dil_conv_3x3', 3),
('sep_conv_3x3', 0), ('max_pool_3x3', 1)],
normal_concat=range(2, 6),
reduce=[('max_pool_3x3', 1), ('max_pool_3x3', 0), ('dil_conv_3x3', 2),
('sep_conv_5x5', 0), ('max_pool_3x3', 0), ('dil_conv_3x3', 3),
('avg_pool_3x3', 3), ('avg_pool_3x3', 4)],
reduce_concat=range(2, 6))
DARTS_V2 = Genotype(
normal=[('dil_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
('skip_connect', 0), ('sep_conv_3x3', 1)],
normal_concat=range(2, 6),
reduce=[('skip_connect', 1), ('max_pool_3x3', 0), ('max_pool_3x3', 1),
('skip_connect', 2), ('skip_connect', 2), ('dil_conv_5x5', 3),
('skip_connect', 2), ('max_pool_3x3', 1)],
reduce_concat=range(2, 6))
PC_DARTS = Genotype(
normal=[('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_5x5', 0),
('dil_conv_5x5', 2), ('sep_conv_5x5', 0), ('sep_conv_3x3', 2),
('sep_conv_3x3', 0), ('dil_conv_3x3', 1)],
normal_concat=range(2, 6),
reduce=[('avg_pool_3x3', 0), ('sep_conv_3x3', 1), ('skip_connect', 2),
('avg_pool_3x3', 0), ('dil_conv_5x5', 3), ('skip_connect', 2),
('skip_connect', 2), ('avg_pool_3x3', 0)],
reduce_concat=range(2, 6))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import ConstantInitializer, MSRAInitializer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
from genotypes import PRIMITIVES
from genotypes import Genotype
from operations import *
class ConvBN(fluid.dygraph.Layer):
def __init__(self, c_curr, c_out, kernel_size, padding, stride, name=None):
super(ConvBN, self).__init__()
self.conv = Conv2D(
num_channels=c_curr,
num_filters=c_out,
filter_size=kernel_size,
stride=stride,
padding=padding,
param_attr=fluid.ParamAttr(
name=name + "_conv" if name is not None else None,
initializer=MSRAInitializer()),
bias_attr=False)
self.bn = BatchNorm(
num_channels=c_out,
param_attr=fluid.ParamAttr(
name=name + "_bn_scale" if name is not None else None,
initializer=ConstantInitializer(value=1)),
bias_attr=fluid.ParamAttr(
name=name + "_bn_offset" if name is not None else None,
initializer=ConstantInitializer(value=0)),
moving_mean_name=name + "_bn_mean" if name is not None else None,
moving_variance_name=name + "_bn_variance"
if name is not None else None)
def forward(self, x):
conv = self.conv(x)
bn = self.bn(conv)
return bn
class Classifier(fluid.dygraph.Layer):
def __init__(self, input_dim, num_classes, name=None):
super(Classifier, self).__init__()
self.pool2d = Pool2D(pool_type='avg', global_pooling=True)
self.fc = Linear(
input_dim=input_dim,
output_dim=num_classes,
param_attr=fluid.ParamAttr(
name=name + "_fc_weights" if name is not None else None,
initializer=MSRAInitializer()),
bias_attr=fluid.ParamAttr(
name=name + "_fc_bias" if name is not None else None,
initializer=MSRAInitializer()))
def forward(self, x):
x = self.pool2d(x)
x = fluid.layers.squeeze(x, axes=[2, 3])
out = self.fc(x)
return out
def drop_path(x, drop_prob):
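    # Per-sample "drop path": zero an entire sample's activations with
    # probability drop_prob and rescale the survivors by 1/keep_prob so the
    # expected activation magnitude is unchanged.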
if drop_prob > 0:
keep_prob = 1. - drop_prob
mask = 1 - np.random.binomial(
1, drop_prob, size=[x.shape[0]]).astype(np.float32)
mask = to_variable(mask)
x = fluid.layers.elementwise_mul(x / keep_prob, mask, axis=0)
return x
class Cell(fluid.dygraph.Layer):
def __init__(self, genotype, c_prev_prev, c_prev, c_curr, reduction,
reduction_prev):
super(Cell, self).__init__()
print(c_prev_prev, c_prev, c_curr)
if reduction_prev:
self.preprocess0 = FactorizedReduce(c_prev_prev, c_curr)
else:
self.preprocess0 = ReLUConvBN(c_prev_prev, c_curr, 1, 1, 0)
self.preprocess1 = ReLUConvBN(c_prev, c_curr, 1, 1, 0)
if reduction:
op_names, indices = zip(*genotype.reduce)
concat = genotype.reduce_concat
else:
op_names, indices = zip(*genotype.normal)
concat = genotype.normal_concat
multiplier = len(concat)
self._multiplier = multiplier
self._compile(c_curr, op_names, indices, multiplier, reduction)
def _compile(self, c_curr, op_names, indices, multiplier, reduction):
assert len(op_names) == len(indices)
self._steps = len(op_names) // 2
ops = []
edge_index = 0
for op_name, index in zip(op_names, indices):
stride = 2 if reduction and index < 2 else 1
op = OPS[op_name](c_curr, stride, True)
ops += [op]
edge_index += 1
self._ops = fluid.dygraph.LayerList(ops)
self._indices = indices
def forward(self, s0, s1, drop_prob, training):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
h1 = states[self._indices[2 * i]]
h2 = states[self._indices[2 * i + 1]]
op1 = self._ops[2 * i]
op2 = self._ops[2 * i + 1]
h1 = op1(h1)
h2 = op2(h2)
if training and drop_prob > 0.:
if not isinstance(op1, Identity):
h1 = drop_path(h1, drop_prob)
if not isinstance(op2, Identity):
h2 = drop_path(h2, drop_prob)
states += [h1 + h2]
out = fluid.layers.concat(input=states[-self._multiplier:], axis=1)
return out
class AuxiliaryHeadCIFAR(fluid.dygraph.Layer):
def __init__(self, C, num_classes):
super(AuxiliaryHeadCIFAR, self).__init__()
self.avgpool = Pool2D(
pool_size=5, pool_stride=3, pool_padding=0, pool_type='avg')
self.conv_bn1 = ConvBN(
c_curr=C,
c_out=128,
kernel_size=1,
padding=0,
stride=1,
name='aux_conv_bn1')
self.conv_bn2 = ConvBN(
c_curr=128,
c_out=768,
kernel_size=2,
padding=0,
stride=1,
name='aux_conv_bn2')
self.classifier = Classifier(768, num_classes, 'aux')
def forward(self, x):
x = fluid.layers.relu(x)
x = self.avgpool(x)
conv1 = self.conv_bn1(x)
conv1 = fluid.layers.relu(conv1)
conv2 = self.conv_bn2(conv1)
conv2 = fluid.layers.relu(conv2)
out = self.classifier(conv2)
return out
class NetworkCIFAR(fluid.dygraph.Layer):
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkCIFAR, self).__init__()
self._layers = layers
self._auxiliary = auxiliary
stem_multiplier = 3
c_curr = stem_multiplier * C
self.stem = ConvBN(
c_curr=3, c_out=c_curr, kernel_size=3, padding=1, stride=1)
c_prev_prev, c_prev, c_curr = c_curr, c_curr, C
cells = []
reduction_prev = False
for i in range(layers):
if i in [layers // 3, 2 * layers // 3]:
c_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(genotype, c_prev_prev, c_prev, c_curr, reduction,
reduction_prev)
reduction_prev = reduction
cells += [cell]
c_prev_prev, c_prev = c_prev, cell._multiplier * c_curr
if i == 2 * layers // 3:
c_to_auxiliary = c_prev
self.cells = fluid.dygraph.LayerList(cells)
if auxiliary:
self.auxiliary_head = AuxiliaryHeadCIFAR(c_to_auxiliary,
num_classes)
self.classifier = Classifier(c_prev, num_classes)
def forward(self, input, drop_path_prob, training):
logits_aux = None
s0 = s1 = self.stem(input)
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, drop_path_prob, training)
if i == 2 * self._layers // 3:
if self._auxiliary and training:
logits_aux = self.auxiliary_head(s1)
logits = self.classifier(s1)
return logits, logits_aux
class AuxiliaryHeadImageNet(fluid.dygraph.Layer):
def __init__(self, C, num_classes):
super(AuxiliaryHeadImageNet, self).__init__()
self.avgpool = Pool2D(
pool_size=5, pool_stride=2, pool_padding=0, pool_type='avg')
self.conv_bn1 = ConvBN(
c_curr=C,
c_out=128,
kernel_size=1,
padding=0,
stride=1,
name='aux_conv_bn1')
self.conv_bn2 = ConvBN(
c_curr=128,
c_out=768,
kernel_size=2,
padding=0,
stride=1,
name='aux_conv_bn2')
self.classifier = Classifier(768, num_classes, 'aux')
def forward(self, x):
x = fluid.layers.relu(x)
x = self.avgpool(x)
conv1 = self.conv_bn1(x)
conv1 = fluid.layers.relu(conv1)
conv2 = self.conv_bn2(conv1)
conv2 = fluid.layers.relu(conv2)
out = self.classifier(conv2)
return out
class NetworkImageNet(fluid.dygraph.Layer):
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkImageNet, self).__init__()
self._layers = layers
self._auxiliary = auxiliary
self.stem_a0 = ConvBN(
c_curr=3, c_out=C // 2, kernel_size=3, padding=1, stride=2)
self.stem_a1 = ConvBN(
c_curr=C // 2, c_out=C, kernel_size=3, padding=1, stride=2)
self.stem_b = ConvBN(
c_curr=C, c_out=C, kernel_size=3, padding=1, stride=2)
c_prev_prev, c_prev, c_curr = C, C, C
cells = []
reduction_prev = True
for i in range(layers):
if i in [layers // 3, 2 * layers // 3]:
c_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(genotype, c_prev_prev, c_prev, c_curr, reduction,
reduction_prev)
reduction_prev = reduction
cells += [cell]
c_prev_prev, c_prev = c_prev, cell._multiplier * c_curr
if i == 2 * layers // 3:
c_to_auxiliary = c_prev
self.cells = fluid.dygraph.LayerList(cells)
if auxiliary:
self.auxiliary_head = AuxiliaryHeadImageNet(c_to_auxiliary,
num_classes)
self.classifier = Classifier(c_prev, num_classes)
def forward(self, input, training):
logits_aux = None
s0 = self.stem_a0(input)
s0 = fluid.layers.relu(s0)
s0 = self.stem_a1(s0)
s1 = fluid.layers.relu(s0)
s1 = self.stem_b(s1)
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, 0, training)
if i == 2 * self._layers // 3:
if self._auxiliary and training:
logits_aux = self.auxiliary_head(s1)
logits = self.classifier(s1)
return logits, logits_aux
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import NormalInitializer, MSRAInitializer, ConstantInitializer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
from genotypes import PRIMITIVES
from genotypes import Genotype
from operations import *
def channel_shuffle(x, groups):
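    # Channel shuffle (as in ShuffleNet): interleave the channel groups so
    # that, under PC-DARTS partial channel connections, the slice processed
    # by the candidate ops and the bypassed slice get mixed between cells.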
batchsize, num_channels, height, width = x.shape
channels_per_group = num_channels // groups
# reshape
x = fluid.layers.reshape(
x, [batchsize, groups, channels_per_group, height, width])
x = fluid.layers.transpose(x, [0, 2, 1, 3, 4])
# flatten
x = fluid.layers.reshape(x, [batchsize, num_channels, height, width])
return x
class MixedOp(fluid.dygraph.Layer):
def __init__(self, c_cur, stride, method):
super(MixedOp, self).__init__()
self._method = method
self._k = 4 if self._method == "PC-DARTS" else 1
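        # PC-DARTS partial channel connections: only 1/k of the channels
        # (k=4, following the paper) are sent through the candidate ops;
        # the remaining channels bypass them and are concatenated back in
        # forward().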
self.mp = Pool2D(
pool_size=2,
pool_stride=2,
pool_type='max', )
ops = []
for primitive in PRIMITIVES:
op = OPS[primitive](c_cur // self._k, stride, False)
if 'pool' in primitive:
gama = ParamAttr(
initializer=fluid.initializer.Constant(value=1),
trainable=False)
beta = ParamAttr(
initializer=fluid.initializer.Constant(value=0),
trainable=False)
BN = BatchNorm(
c_cur // self._k, param_attr=gama, bias_attr=beta)
op = fluid.dygraph.Sequential(op, BN)
ops.append(op)
self._ops = fluid.dygraph.LayerList(ops)
def forward(self, x, weights):
if self._method == "PC-DARTS":
dim_2 = x.shape[1]
xtemp = x[:, :dim_2 // self._k, :, :]
xtemp2 = x[:, dim_2 // self._k:, :, :]
temp1 = fluid.layers.sums(
[weights[i] * op(xtemp) for i, op in enumerate(self._ops)])
if temp1.shape[2] == x.shape[2]:
out = fluid.layers.concat([temp1, xtemp2], axis=1)
else:
out = fluid.layers.concat([temp1, self.mp(xtemp2)], axis=1)
out = channel_shuffle(out, self._k)
else:
out = fluid.layers.sums(
[weights[i] * op(x) for i, op in enumerate(self._ops)])
return out
class Cell(fluid.dygraph.Layer):
def __init__(self, steps, multiplier, c_prev_prev, c_prev, c_cur,
reduction, reduction_prev, method):
super(Cell, self).__init__()
self.reduction = reduction
if reduction_prev:
self.preprocess0 = FactorizedReduce(c_prev_prev, c_cur, False)
else:
self.preprocess0 = ReLUConvBN(c_prev_prev, c_cur, 1, 1, 0, False)
self.preprocess1 = ReLUConvBN(c_prev, c_cur, 1, 1, 0, affine=False)
self._steps = steps
self._multiplier = multiplier
self._method = method
ops = []
for i in range(self._steps):
for j in range(2 + i):
stride = 2 if reduction and j < 2 else 1
op = MixedOp(c_cur, stride, method)
ops.append(op)
self._ops = fluid.dygraph.LayerList(ops)
def forward(self, s0, s1, weights, weights2=None):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
offset = 0
for i in range(self._steps):
if self._method == "PC-DARTS":
s = fluid.layers.sums([
weights2[offset + j] *
self._ops[offset + j](h, weights[offset + j])
for j, h in enumerate(states)
])
else:
s = fluid.layers.sums([
self._ops[offset + j](h, weights[offset + j])
for j, h in enumerate(states)
])
offset += len(states)
states.append(s)
out = fluid.layers.concat(input=states[-self._multiplier:], axis=1)
return out
class Network(fluid.dygraph.Layer):
def __init__(self,
c_in,
num_classes,
layers,
method,
steps=4,
multiplier=4,
stem_multiplier=3):
super(Network, self).__init__()
self._c_in = c_in
self._num_classes = num_classes
self._layers = layers
self._steps = steps
self._multiplier = multiplier
self._method = method
c_cur = stem_multiplier * c_in
self.stem = fluid.dygraph.Sequential(
Conv2D(
num_channels=3,
num_filters=c_cur,
filter_size=3,
padding=1,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
bias_attr=False),
BatchNorm(
num_channels=c_cur,
param_attr=fluid.ParamAttr(
initializer=ConstantInitializer(value=1)),
bias_attr=fluid.ParamAttr(
initializer=ConstantInitializer(value=0))))
c_prev_prev, c_prev, c_cur = c_cur, c_cur, c_in
cells = []
reduction_prev = False
for i in range(layers):
if i in [layers // 3, 2 * layers // 3]:
c_cur *= 2
reduction = True
else:
reduction = False
cell = Cell(steps, multiplier, c_prev_prev, c_prev, c_cur,
reduction, reduction_prev, method)
reduction_prev = reduction
cells.append(cell)
c_prev_prev, c_prev = c_prev, multiplier * c_cur
self.cells = fluid.dygraph.LayerList(cells)
self.global_pooling = Pool2D(pool_type='avg', global_pooling=True)
self.classifier = Linear(
input_dim=c_prev,
output_dim=num_classes,
param_attr=ParamAttr(initializer=MSRAInitializer()),
bias_attr=ParamAttr(initializer=MSRAInitializer()))
self._initialize_alphas()
def forward(self, input):
s0 = s1 = self.stem(input)
weights2 = None
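        # For PC-DARTS, the beta logits are normalized per intermediate node:
        # the groups of 2, 3, 4, ... incoming edges are softmaxed separately
        # and concatenated into the edge weights `weights2`.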
for i, cell in enumerate(self.cells):
if cell.reduction:
weights = fluid.layers.softmax(self.alphas_reduce)
if self._method == "PC-DARTS":
n = 3
start = 2
weights2 = fluid.layers.softmax(self.betas_reduce[0:2])
                    for _ in range(self._steps - 1):
end = start + n
tw2 = fluid.layers.softmax(self.betas_reduce[start:
end])
start = end
n += 1
weights2 = fluid.layers.concat([weights2, tw2])
else:
weights = fluid.layers.softmax(self.alphas_normal)
if self._method == "PC-DARTS":
n = 3
start = 2
weights2 = fluid.layers.softmax(self.betas_normal[0:2])
                    for _ in range(self._steps - 1):
end = start + n
tw2 = fluid.layers.softmax(self.betas_normal[start:
end])
start = end
n += 1
weights2 = fluid.layers.concat([weights2, tw2])
s0, s1 = s1, cell(s0, s1, weights, weights2)
out = self.global_pooling(s1)
out = fluid.layers.squeeze(out, axes=[2, 3])
logits = self.classifier(out)
return logits
def _loss(self, input, target):
logits = self(input)
loss = fluid.layers.reduce_mean(
fluid.layers.softmax_with_cross_entropy(logits, target))
return loss
def new(self):
model_new = Network(self._c_in, self._num_classes, self._layers,
self._method)
return model_new
def _initialize_alphas(self):
k = sum(1 for i in range(self._steps) for n in range(2 + i))
num_ops = len(PRIMITIVES)
self.alphas_normal = fluid.layers.create_parameter(
shape=[k, num_ops],
dtype="float32",
default_initializer=NormalInitializer(
loc=0.0, scale=1e-3))
self.alphas_reduce = fluid.layers.create_parameter(
shape=[k, num_ops],
dtype="float32",
default_initializer=NormalInitializer(
loc=0.0, scale=1e-3))
self._arch_parameters = [
self.alphas_normal,
self.alphas_reduce,
]
if self._method == "PC-DARTS":
self.betas_normal = fluid.layers.create_parameter(
shape=[k],
dtype="float32",
default_initializer=NormalInitializer(
loc=0.0, scale=1e-3))
self.betas_reduce = fluid.layers.create_parameter(
shape=[k],
dtype="float32",
default_initializer=NormalInitializer(
loc=0.0, scale=1e-3))
self._arch_parameters += [self.betas_normal, self.betas_reduce]
def arch_parameters(self):
return self._arch_parameters
def genotype(self):
def _parse(weights, weights2=None):
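            # For each intermediate node, keep the two incoming edges whose
            # strongest non-'none' operation has the largest weight, then
            # record that operation for each kept edge.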
gene = []
n = 2
start = 0
for i in range(self._steps):
end = start + n
W = weights[start:end].copy()
if self._method == "PC-DARTS":
W2 = weights2[start:end].copy()
for j in range(n):
W[j, :] = W[j, :] * W2[j]
                edges = sorted(
                    range(i + 2),
                    key=lambda x: -max(W[x][k] for k in range(len(W[x]))
                                       if k != PRIMITIVES.index('none')))[:2]
for j in edges:
k_best = None
for k in range(len(W[j])):
if k != PRIMITIVES.index('none'):
if k_best is None or W[j][k] > W[j][k_best]:
k_best = k
gene.append((PRIMITIVES[k_best], j))
start = end
n += 1
return gene
weightsr2 = None
weightsn2 = None
if self._method == "PC-DARTS":
n = 3
start = 2
weightsr2 = fluid.layers.softmax(self.betas_reduce[0:2])
weightsn2 = fluid.layers.softmax(self.betas_normal[0:2])
for i in range(self._steps - 1):
end = start + n
tw2 = fluid.layers.softmax(self.betas_reduce[start:end])
tn2 = fluid.layers.softmax(self.betas_normal[start:end])
start = end
n += 1
weightsr2 = fluid.layers.concat([weightsr2, tw2])
weightsn2 = fluid.layers.concat([weightsn2, tn2])
weightsr2 = weightsr2.numpy()
weightsn2 = weightsn2.numpy()
gene_normal = _parse(
fluid.layers.softmax(self.alphas_normal).numpy(), weightsn2)
gene_reduce = _parse(
fluid.layers.softmax(self.alphas_reduce).numpy(), weightsr2)
concat = range(2 + self._steps - self._multiplier, self._steps + 2)
genotype = Genotype(
normal=gene_normal,
normal_concat=concat,
reduce=gene_reduce,
reduce_concat=concat)
return genotype
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import ConstantInitializer, MSRAInitializer
OPS = {
'none':
lambda C, stride, affine: Zero(stride),
'avg_pool_3x3':
lambda C, stride, affine: Pool2D(
pool_size=3,
pool_type="avg",
pool_stride=stride,
pool_padding=1),
'max_pool_3x3':
lambda C, stride, affine: Pool2D(
pool_size=3,
pool_type="max",
pool_stride=stride,
pool_padding=1),
'skip_connect':
lambda C, stride, affine: Identity()
if stride == 1 else FactorizedReduce(C, C, affine),
'sep_conv_3x3':
lambda C, stride, affine: SepConv(C, C, 3, stride, 1,
affine),
'sep_conv_5x5':
lambda C, stride, affine: SepConv(C, C, 5, stride, 2,
affine),
'sep_conv_7x7':
lambda C, stride, affine: SepConv(C, C, 7, stride, 3,
affine),
'dil_conv_3x3':
lambda C, stride, affine: DilConv(C, C, 3, stride, 2,
2, affine),
'dil_conv_5x5':
lambda C, stride, affine: DilConv(C, C, 5, stride, 4,
2, affine),
'conv_7x1_1x7':
lambda C, stride, affine: Conv_7x1_1x7(
C, C, stride, affine),
}
def bn_param_config(affine=False):
gama = ParamAttr(
initializer=ConstantInitializer(value=1), trainable=affine)
beta = ParamAttr(
initializer=ConstantInitializer(value=0), trainable=affine)
return gama, beta
class Zero(fluid.dygraph.Layer):
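    # The 'none' candidate operation: returns zeros shaped like the input
    # (downsampled by the stride-2 pooling when stride != 1), so the edge
    # contributes nothing to the weighted sum.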
def __init__(self, stride):
super(Zero, self).__init__()
self.stride = stride
self.pool = Pool2D(pool_size=1, pool_stride=2)
def forward(self, x):
pooled = self.pool(x)
x = fluid.layers.zeros_like(
x) if self.stride == 1 else fluid.layers.zeros_like(pooled)
return x
class Identity(fluid.dygraph.Layer):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class FactorizedReduce(fluid.dygraph.Layer):
def __init__(self, c_in, c_out, affine=True):
super(FactorizedReduce, self).__init__()
assert c_out % 2 == 0
self.conv1 = Conv2D(
num_channels=c_in,
num_filters=c_out // 2,
filter_size=1,
stride=2,
padding=0,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
bias_attr=False)
self.conv2 = Conv2D(
num_channels=c_in,
num_filters=c_out // 2,
filter_size=1,
stride=2,
padding=0,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
bias_attr=False)
gama, beta = bn_param_config(affine)
self.bn = BatchNorm(
num_channels=c_out, param_attr=gama, bias_attr=beta)
def forward(self, x):
x = fluid.layers.relu(x)
out = fluid.layers.concat(
input=[self.conv1(x), self.conv2(x[:, :, 1:, 1:])], axis=1)
out = self.bn(out)
return out
class SepConv(fluid.dygraph.Layer):
def __init__(self, c_in, c_out, kernel_size, stride, padding, affine=True):
super(SepConv, self).__init__()
self.conv1 = Conv2D(
num_channels=c_in,
num_filters=c_in,
filter_size=kernel_size,
stride=stride,
padding=padding,
groups=c_in,
use_cudnn=False,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
bias_attr=False)
self.conv2 = Conv2D(
num_channels=c_in,
num_filters=c_in,
filter_size=1,
stride=1,
padding=0,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
bias_attr=False)
gama, beta = bn_param_config(affine)
self.bn1 = BatchNorm(
num_channels=c_in, param_attr=gama, bias_attr=beta)
self.conv3 = Conv2D(
num_channels=c_in,
num_filters=c_in,
filter_size=kernel_size,
stride=1,
padding=padding,
groups=c_in,
use_cudnn=False,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
bias_attr=False)
self.conv4 = Conv2D(
num_channels=c_in,
num_filters=c_out,
filter_size=1,
stride=1,
padding=0,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
bias_attr=False)
gama, beta = bn_param_config(affine)
self.bn2 = BatchNorm(
num_channels=c_out, param_attr=gama, bias_attr=beta)
def forward(self, x):
x = fluid.layers.relu(x)
x = self.conv1(x)
x = self.conv2(x)
bn1 = self.bn1(x)
x = fluid.layers.relu(bn1)
x = self.conv3(x)
x = self.conv4(x)
bn2 = self.bn2(x)
return bn2
class DilConv(fluid.dygraph.Layer):
def __init__(self,
c_in,
c_out,
kernel_size,
stride,
padding,
dilation,
affine=True):
super(DilConv, self).__init__()
self.conv1 = Conv2D(
num_channels=c_in,
num_filters=c_in,
filter_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=c_in,
use_cudnn=False,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
bias_attr=False)
self.conv2 = Conv2D(
num_channels=c_in,
num_filters=c_out,
filter_size=1,
padding=0,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
bias_attr=False)
gama, beta = bn_param_config(affine)
self.bn1 = BatchNorm(
num_channels=c_out, param_attr=gama, bias_attr=beta)
def forward(self, x):
x = fluid.layers.relu(x)
x = self.conv1(x)
x = self.conv2(x)
out = self.bn1(x)
return out
class Conv_7x1_1x7(fluid.dygraph.Layer):
def __init__(self, c_in, c_out, stride, affine=True):
super(Conv_7x1_1x7, self).__init__()
        self.conv1 = Conv2D(
            num_channels=c_in,
            num_filters=c_out,
            filter_size=(1, 7),
            stride=(1, stride),
            padding=(0, 3),
            param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
            bias_attr=False)
        self.conv2 = Conv2D(
            num_channels=c_out,
            num_filters=c_out,
            filter_size=(7, 1),
            stride=(stride, 1),
            padding=(3, 0),
            param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
            bias_attr=False)
gama, beta = bn_param_config(affine)
self.bn1 = BatchNorm(
num_channels=c_out, param_attr=gama, bias_attr=beta)
def forward(self, x):
x = fluid.layers.relu(x)
x = self.conv1(x)
x = self.conv2(x)
out = self.bn1(x)
return out
class ReLUConvBN(fluid.dygraph.Layer):
def __init__(self, c_in, c_out, kernel_size, stride, padding, affine=True):
super(ReLUConvBN, self).__init__()
self.conv = Conv2D(
num_channels=c_in,
num_filters=c_out,
filter_size=kernel_size,
stride=stride,
padding=padding,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
bias_attr=False)
gama, beta = bn_param_config(affine)
self.bn = BatchNorm(
num_channels=c_out, param_attr=gama, bias_attr=beta)
def forward(self, x):
x = fluid.layers.relu(x)
x = self.conv(x)
out = self.bn(x)
return out
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from PIL import Image
from PIL import ImageOps
import os
import math
import random
import tarfile
import functools
import numpy as np
from PIL import Image, ImageEnhance
import paddle
# for python2/python3 compatibility
try:
    import cPickle
except ImportError:
    import _pickle as cPickle
IMAGE_SIZE = 32
IMAGE_DEPTH = 3
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
paddle.dataset.common.DATA_HOME = "dataset/"
THREAD = 16
BUF_SIZE = 10240
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
IMAGENET_STD = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
IMAGENET_DIM = 224
def preprocess(sample, is_training, args):
image_array = sample.reshape(IMAGE_DEPTH, IMAGE_SIZE, IMAGE_SIZE)
rgb_array = np.transpose(image_array, (1, 2, 0))
img = Image.fromarray(rgb_array, 'RGB')
if is_training:
        # pad, random crop, random horizontal flip
img = ImageOps.expand(img, (4, 4, 4, 4), fill=0)
left_top = np.random.randint(8, size=2)
img = img.crop((left_top[1], left_top[0], left_top[1] + IMAGE_SIZE,
left_top[0] + IMAGE_SIZE))
if np.random.randint(2):
img = img.transpose(Image.FLIP_LEFT_RIGHT)
img = np.array(img).astype(np.float32)
img_float = img / 255.0
img = (img_float - CIFAR_MEAN) / CIFAR_STD
if is_training and args.cutout:
center = np.random.randint(IMAGE_SIZE, size=2)
offset_width = max(0, center[0] - args.cutout_length // 2)
offset_height = max(0, center[1] - args.cutout_length // 2)
target_width = min(center[0] + args.cutout_length // 2, IMAGE_SIZE)
target_height = min(center[1] + args.cutout_length // 2, IMAGE_SIZE)
for i in range(offset_height, target_height):
for j in range(offset_width, target_width):
img[i][j][:] = 0.0
img = np.transpose(img, (2, 0, 1))
return img
def reader_generator(datasets, batch_size, is_training, is_shuffle, args):
def read_batch(datasets, args):
if is_shuffle:
random.shuffle(datasets)
for im, label in datasets:
im = preprocess(im, is_training, args)
yield im, [int(label)]
def reader():
batch_data = []
batch_label = []
for data in read_batch(datasets, args):
batch_data.append(data[0])
batch_label.append(data[1])
if len(batch_data) == batch_size:
batch_data = np.array(batch_data, dtype='float32')
batch_label = np.array(batch_label, dtype='int64')
batch_out = [batch_data, batch_label]
yield batch_out
batch_data = []
batch_label = []
return reader
def cifar10_reader(file_name, data_name, is_shuffle, args):
with tarfile.open(file_name, mode='r') as f:
names = [
each_item.name for each_item in f if data_name in each_item.name
]
names.sort()
datasets = []
for name in names:
print("Reading file " + name)
try:
batch = cPickle.load(
f.extractfile(name), encoding='iso-8859-1')
            except TypeError:  # Python 2's cPickle.load() has no encoding argument
batch = cPickle.load(f.extractfile(name))
data = batch['data']
labels = batch.get('labels', batch.get('fine_labels', None))
assert labels is not None
dataset = zip(data, labels)
datasets.extend(dataset)
if is_shuffle:
random.shuffle(datasets)
return datasets
def train_search(batch_size, train_portion, is_shuffle, args):
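    # DARTS bi-level setup: the CIFAR10 training set is split into two
    # halves, one for updating network weights and one for updating
    # architecture parameters, so both halves use training-time
    # preprocessing.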
datasets = cifar10_reader(
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'data_batch', is_shuffle, args)
split_point = int(np.floor(train_portion * len(datasets)))
train_datasets = datasets[:split_point]
val_datasets = datasets[split_point:]
train_readers = []
val_readers = []
    n = int(math.ceil(len(train_datasets) / args.num_workers)) \
        if args.use_multiprocess else len(train_datasets)
train_datasets_lists = [
train_datasets[i:i + n] for i in range(0, len(train_datasets), n)
]
val_datasets_lists = [
val_datasets[i:i + n] for i in range(0, len(val_datasets), n)
]
for pid in range(len(train_datasets_lists)):
train_readers.append(
reader_generator(train_datasets_lists[pid], batch_size, True, True,
args))
val_readers.append(
reader_generator(val_datasets_lists[pid], batch_size, True, True,
args))
if args.use_multiprocess:
reader = [
paddle.reader.multiprocess_reader(train_readers, False),
paddle.reader.multiprocess_reader(val_readers, False)
]
else:
reader = [train_readers[0], val_readers[0]]
return reader
def train_valid(batch_size, is_train, is_shuffle, args):
name = 'data_batch' if is_train else 'test_batch'
datasets = cifar10_reader(
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
name, is_shuffle, args)
    n = int(math.ceil(len(datasets) / args.num_workers)) \
        if args.use_multiprocess else len(datasets)
datasets_lists = [datasets[i:i + n] for i in range(0, len(datasets), n)]
multi_readers = []
for pid in range(len(datasets_lists)):
multi_readers.append(
reader_generator(datasets_lists[pid], batch_size, is_train,
is_shuffle, args))
if args.use_multiprocess:
reader = paddle.reader.multiprocess_reader(multi_readers, False)
else:
reader = multi_readers[0]
return reader
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
    if center:
        w_start = (width - size) // 2
        h_start = (height - size) // 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), Image.LANCZOS)
return img
def random_crop(img, size, scale=[0.08, 1.0], ratio=[3. / 4., 4. / 3.]):
aspect_ratio = math.sqrt(np.random.uniform(*ratio))
w = 1. * aspect_ratio
h = 1. / aspect_ratio
bound = min((float(img.size[0]) / img.size[1]) / (w**2),
(float(img.size[1]) / img.size[0]) / (h**2))
scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound)
target_area = img.size[0] * img.size[1] * np.random.uniform(scale_min,
scale_max)
target_size = math.sqrt(target_area)
w = int(target_size * w)
h = int(target_size * h)
i = np.random.randint(0, img.size[0] - w + 1)
j = np.random.randint(0, img.size[1] - h + 1)
img = img.crop((i, j, i + w, j + h))
img = img.resize((size, size), Image.BILINEAR)
return img
def distort_color(img):
def random_brightness(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Brightness(img).enhance(e)
def random_contrast(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Contrast(img).enhance(e)
def random_color(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Color(img).enhance(e)
ops = [random_brightness, random_contrast, random_color]
np.random.shuffle(ops)
img = ops[0](img)
img = ops[1](img)
img = ops[2](img)
return img
def process_image(sample, mode, color_jitter, rotate):
img_path = sample[0]
img = Image.open(img_path)
if mode == 'train':
img = random_crop(img, IMAGENET_DIM)
if np.random.randint(0, 2) == 1:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if color_jitter:
img = distort_color(img)
else:
img = resize_short(img, target_size=256)
img = crop_image(img, target_size=IMAGENET_DIM, center=True)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
img -= IMAGENET_MEAN
img /= IMAGENET_STD
if mode == 'train' or mode == 'val':
return img, np.array([sample[1]], dtype='int64')
elif mode == 'test':
return [img]
def _reader_creator(file_list,
mode,
shuffle=False,
color_jitter=False,
rotate=False,
data_dir=None):
def reader():
try:
with open(file_list) as flist:
full_lines = [line.strip() for line in flist]
if shuffle:
np.random.shuffle(full_lines)
lines = full_lines
for line in lines:
if mode == 'train' or mode == 'val':
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
yield img_path, int(label)
elif mode == 'test':
img_path = os.path.join(data_dir, line)
yield [img_path]
except Exception as e:
print("Reader failed!\n{}".format(str(e)))
os._exit(1)
mapper = functools.partial(
process_image, mode=mode, color_jitter=color_jitter, rotate=rotate)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def imagenet_reader(data_dir, mode):
    if mode == 'train':
shuffle = True
suffix = 'train_list.txt'
    elif mode == 'val':
shuffle = False
suffix = 'val_list.txt'
file_list = os.path.join(data_dir, suffix)
return _reader_creator(file_list, mode, shuffle=shuffle, data_dir=data_dir)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import ast
import argparse
import functools
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
import reader
from model_search import Network
from paddleslim.nas.darts import DARTSearch
sys.path[0] = os.path.join(os.path.dirname(__file__), os.path.pardir)
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('log_freq', int, 50, "Log frequency.")
add_arg('use_multiprocess', bool, True, "Whether use multiprocess reader.")
add_arg('num_workers', int, 4, "The multiprocess reader number.")
add_arg('data', str, 'dataset/cifar10',"The dir of dataset.")
add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('learning_rate', float, 0.025, "The start learning rate.")
add_arg('momentum', float, 0.9, "Momentum.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('epochs', int, 50, "Epoch number.")
add_arg('init_channels', int, 16, "Init channel number.")
add_arg('layers', int, 8, "Total number of layers.")
add_arg('class_num', int, 10, "Class number of dataset.")
add_arg('trainset_num', int, 50000, "images number of trainset.")
add_arg('model_save_dir', str, 'search_cifar', "The path to save model.")
add_arg('grad_clip', float, 5, "Gradient clipping.")
add_arg('arch_learning_rate',float, 3e-4, "Learning rate for arch encoding.")
add_arg('method', str, 'DARTS', "The search method you would like to use")
add_arg('cutout_length', int, 16, "Cutout length.")
add_arg('cutout', ast.literal_eval, False, "Whether use cutout.")
add_arg('unrolled', ast.literal_eval, False, "Use one-step unrolled validation loss")
add_arg('use_data_parallel', ast.literal_eval, False, "The flag indicating whether to use data parallel mode to train the model.")
# yapf: enable
def main(args):
if not args.use_gpu:
place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
train_reader, valid_reader = reader.train_search(
batch_size=args.batch_size,
train_portion=0.5,
is_shuffle=True,
args=args)
with fluid.dygraph.guard(place):
model = Network(args.init_channels, args.class_num, args.layers,
args.method)
searcher = DARTSearch(
model,
train_reader,
valid_reader,
learning_rate=args.learning_rate,
batchsize=args.batch_size,
num_imgs=args.trainset_num,
arch_learning_rate=args.arch_learning_rate,
unrolled=args.unrolled,
method=args.method,
num_epochs=args.epochs,
use_data_parallel=args.use_data_parallel,
log_freq=args.log_freq)
searcher.train()
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
main(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import ast
import argparse
import functools
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from model import NetworkCIFAR as Network
from paddleslim.common import AvgrageMeter
import genotypes
import reader
sys.path[0] = os.path.join(os.path.dirname(__file__), os.path.pardir)
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_multiprocess', bool, True, "Whether use multiprocess reader.")
add_arg('num_workers', int, 4, "The multiprocess reader number.")
add_arg('data', str, 'dataset/cifar10',"The dir of dataset.")
add_arg('batch_size', int, 96, "Minibatch size.")
add_arg('learning_rate', float, 0.025, "The start learning rate.")
add_arg('momentum', float, 0.9, "Momentum.")
add_arg('weight_decay', float, 3e-4, "Weight_decay.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('epochs', int, 600, "Epoch number.")
add_arg('init_channels', int, 36, "Init channel number.")
add_arg('layers', int, 20, "Total number of layers.")
add_arg('class_num', int, 10, "Class number of dataset.")
add_arg('trainset_num', int, 50000, "images number of trainset.")
add_arg('model_save_dir', str, 'eval_cifar', "The path to save model.")
add_arg('cutout', bool, True, 'Whether use cutout.')
add_arg('cutout_length', int, 16, "Cutout length.")
add_arg('auxiliary', bool, True, 'Use auxiliary tower.')
add_arg('auxiliary_weight', float, 0.4, "Weight for auxiliary loss.")
add_arg('drop_path_prob', float, 0.2, "Drop path probability.")
add_arg('grad_clip', float, 5, "Gradient clipping.")
add_arg('arch', str, 'DARTS_V2', "Which architecture to use")
add_arg('report_freq', int, 50, 'Report frequency')
add_arg('use_data_parallel', ast.literal_eval, False, "The flag indicating whether to use data parallel mode to train the model.")
# yapf: enable
def train(model, train_reader, optimizer, epoch, drop_path_prob, args):
objs = AvgrageMeter()
top1 = AvgrageMeter()
top5 = AvgrageMeter()
model.train()
for step_id, data in enumerate(train_reader()):
image_np, label_np = data
image = to_variable(image_np)
label = to_variable(label_np)
label.stop_gradient = True
logits, logits_aux = model(image, drop_path_prob, True)
prec1 = fluid.layers.accuracy(input=logits, label=label, k=1)
prec5 = fluid.layers.accuracy(input=logits, label=label, k=5)
loss = fluid.layers.reduce_mean(
fluid.layers.softmax_with_cross_entropy(logits, label))
if args.auxiliary:
loss_aux = fluid.layers.reduce_mean(
fluid.layers.softmax_with_cross_entropy(logits_aux, label))
loss = loss + args.auxiliary_weight * loss_aux
if args.use_data_parallel:
loss = model.scale_loss(loss)
loss.backward()
model.apply_collective_grads()
else:
loss.backward()
grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
args.grad_clip)
optimizer.minimize(loss, grad_clip=grad_clip)
model.clear_gradients()
n = image.shape[0]
objs.update(loss.numpy(), n)
top1.update(prec1.numpy(), n)
top5.update(prec5.numpy(), n)
if step_id % args.report_freq == 0:
logger.info(
"Train Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[0]))
return top1.avg[0]
def valid(model, valid_reader, epoch, args):
objs = AvgrageMeter()
top1 = AvgrageMeter()
top5 = AvgrageMeter()
model.eval()
for step_id, data in enumerate(valid_reader()):
image_np, label_np = data
image = to_variable(image_np)
label = to_variable(label_np)
logits, _ = model(image, 0, False)
prec1 = fluid.layers.accuracy(input=logits, label=label, k=1)
prec5 = fluid.layers.accuracy(input=logits, label=label, k=5)
loss = fluid.layers.reduce_mean(
fluid.layers.softmax_with_cross_entropy(logits, label))
n = image.shape[0]
objs.update(loss.numpy(), n)
top1.update(prec1.numpy(), n)
top5.update(prec5.numpy(), n)
if step_id % args.report_freq == 0:
logger.info(
"Valid Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[0]))
return top1.avg[0]
def main(args):
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
genotype = eval("genotypes.%s" % args.arch)
model = Network(
C=args.init_channels,
num_classes=args.class_num,
layers=args.layers,
auxiliary=args.auxiliary,
genotype=genotype)
step_per_epoch = int(args.trainset_num / args.batch_size)
learning_rate = fluid.dygraph.CosineDecay(args.learning_rate,
step_per_epoch, args.epochs)
optimizer = fluid.optimizer.MomentumOptimizer(
learning_rate,
momentum=args.momentum,
regularization=fluid.regularizer.L2Decay(args.weight_decay),
parameter_list=model.parameters())
if args.use_data_parallel:
model = fluid.dygraph.parallel.DataParallel(model, strategy)
train_loader = fluid.io.DataLoader.from_generator(
capacity=64,
use_double_buffer=True,
iterable=True,
return_list=True)
valid_loader = fluid.io.DataLoader.from_generator(
capacity=64,
use_double_buffer=True,
iterable=True,
return_list=True)
train_reader = reader.train_valid(
batch_size=args.batch_size,
is_train=True,
is_shuffle=True,
args=args)
valid_reader = reader.train_valid(
batch_size=args.batch_size,
is_train=False,
is_shuffle=False,
args=args)
train_loader.set_batch_generator(train_reader, places=place)
valid_loader.set_batch_generator(valid_reader, places=place)
if args.use_data_parallel:
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)
save_parameters = (not args.use_data_parallel) or (
args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
best_acc = 0
for epoch in range(args.epochs):
drop_path_prob = args.drop_path_prob * epoch / args.epochs
logger.info('Epoch {}, lr {:.6f}'.format(
epoch, optimizer.current_step_lr()))
train_top1 = train(model, train_loader, optimizer, epoch,
drop_path_prob, args)
logger.info("Epoch {}, train_acc {:.6f}".format(epoch, train_top1))
valid_top1 = valid(model, valid_loader, epoch, args)
if valid_top1 > best_acc:
best_acc = valid_top1
if save_parameters:
fluid.save_dygraph(model.state_dict(),
args.model_save_dir + "/best_model")
logger.info("Epoch {}, valid_acc {:.6f}, best_valid_acc {:.6f}".
format(epoch, valid_top1, best_acc))
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
main(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import ast
import argparse
import functools
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from model import NetworkImageNet as Network
from paddleslim.common import AvgrageMeter
import genotypes
import reader
sys.path[0] = os.path.join(os.path.dirname(__file__), os.path.pardir)
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_multiprocess', bool, True, "Whether use multiprocess reader.")
add_arg('num_workers', int, 4, "The multiprocess reader number.")
add_arg('data_dir', str, 'dataset/ILSVRC2012',"The dir of dataset.")
add_arg('batch_size', int, 128, "Minibatch size.")
add_arg('learning_rate', float, 0.1, "The start learning rate.")
add_arg('decay_rate', float, 0.97, "The lr decay rate.")
add_arg('momentum', float, 0.9, "Momentum.")
add_arg('weight_decay', float, 3e-5, "Weight_decay.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('epochs', int, 250, "Epoch number.")
add_arg('init_channels', int, 48, "Init channel number.")
add_arg('layers', int, 14, "Total number of layers.")
add_arg('class_num', int, 1000, "Class number of dataset.")
add_arg('trainset_num', int, 1281167, "Images number of trainset.")
add_arg('model_save_dir', str, 'eval_imagenet', "The path to save model.")
add_arg('auxiliary', bool, True, 'Use auxiliary tower.')
add_arg('auxiliary_weight', float, 0.4, "Weight for auxiliary loss.")
add_arg('drop_path_prob', float, 0.0, "Drop path probability.")
add_arg('dropout', float, 0.0, "Dropout probability.")
add_arg('grad_clip', float, 5, "Gradient clipping.")
add_arg('label_smooth', float, 0.1, "Label smoothing.")
add_arg('arch', str, 'DARTS_V2', "Which architecture to use")
add_arg('report_freq', int, 100, 'Report frequency')
add_arg('use_data_parallel', ast.literal_eval, False, "The flag indicating whether to use data parallel mode to train the model.")
# yapf: enable
def cross_entropy_label_smooth(preds, targets, epsilon):
preds = fluid.layers.softmax(preds)
targets_one_hot = fluid.layers.one_hot(input=targets, depth=args.class_num)
targets_smooth = fluid.layers.label_smooth(
targets_one_hot, epsilon=epsilon, dtype="float32")
loss = fluid.layers.cross_entropy(
input=preds, label=targets_smooth, soft_label=True)
return loss
def train(model, train_reader, optimizer, epoch, args):
objs = AvgrageMeter()
top1 = AvgrageMeter()
top5 = AvgrageMeter()
model.train()
for step_id, data in enumerate(train_reader()):
image_np, label_np = data
image = to_variable(image_np)
label = to_variable(label_np)
label.stop_gradient = True
logits, logits_aux = model(image, True)
prec1 = fluid.layers.accuracy(input=logits, label=label, k=1)
prec5 = fluid.layers.accuracy(input=logits, label=label, k=5)
loss = fluid.layers.reduce_mean(
cross_entropy_label_smooth(logits, label, args.label_smooth))
if args.auxiliary:
loss_aux = fluid.layers.reduce_mean(
cross_entropy_label_smooth(logits_aux, label,
args.label_smooth))
loss = loss + args.auxiliary_weight * loss_aux
if args.use_data_parallel:
loss = model.scale_loss(loss)
loss.backward()
model.apply_collective_grads()
else:
loss.backward()
grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
args.grad_clip)
optimizer.minimize(loss, grad_clip=grad_clip)
model.clear_gradients()
n = image.shape[0]
objs.update(loss.numpy(), n)
top1.update(prec1.numpy(), n)
top5.update(prec5.numpy(), n)
if step_id % args.report_freq == 0:
logger.info(
"Train Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[0]))
return top1.avg[0], top5.avg[0]
def valid(model, valid_reader, epoch, args):
objs = AvgrageMeter()
top1 = AvgrageMeter()
top5 = AvgrageMeter()
model.eval()
for step_id, data in enumerate(valid_reader()):
image_np, label_np = data
image = to_variable(image_np)
label = to_variable(label_np)
logits, _ = model(image, False)
prec1 = fluid.layers.accuracy(input=logits, label=label, k=1)
prec5 = fluid.layers.accuracy(input=logits, label=label, k=5)
loss = fluid.layers.reduce_mean(
cross_entropy_label_smooth(logits, label, args.label_smooth))
n = image.shape[0]
objs.update(loss.numpy(), n)
top1.update(prec1.numpy(), n)
top5.update(prec5.numpy(), n)
if step_id % args.report_freq == 0:
logger.info(
"Valid Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[0]))
return top1.avg[0], top5.avg[0]
def main(args):
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
genotype = eval("genotypes.%s" % args.arch)
model = Network(
C=args.init_channels,
num_classes=args.class_num,
layers=args.layers,
auxiliary=args.auxiliary,
genotype=genotype)
step_per_epoch = int(args.trainset_num / args.batch_size)
learning_rate = fluid.dygraph.ExponentialDecay(
args.learning_rate,
step_per_epoch,
args.decay_rate,
staircase=True)
optimizer = fluid.optimizer.MomentumOptimizer(
learning_rate,
momentum=args.momentum,
regularization=fluid.regularizer.L2Decay(args.weight_decay),
parameter_list=model.parameters())
if args.use_data_parallel:
model = fluid.dygraph.parallel.DataParallel(model, strategy)
train_loader = fluid.io.DataLoader.from_generator(
capacity=64,
use_double_buffer=True,
iterable=True,
return_list=True)
valid_loader = fluid.io.DataLoader.from_generator(
capacity=64,
use_double_buffer=True,
iterable=True,
return_list=True)
train_reader = fluid.io.batch(
reader.imagenet_reader(args.data_dir, 'train'),
batch_size=args.batch_size,
drop_last=True)
valid_reader = fluid.io.batch(
reader.imagenet_reader(args.data_dir, 'val'),
batch_size=args.batch_size)
train_loader.set_sample_list_generator(train_reader, places=place)
valid_loader.set_sample_list_generator(valid_reader, places=place)
if args.use_data_parallel:
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)
save_parameters = (not args.use_data_parallel) or (
args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
best_top1 = 0
for epoch in range(args.epochs):
logging.info('Epoch {}, lr {:.6f}'.format(
epoch, optimizer.current_step_lr()))
train_top1, train_top5 = train(model, train_loader, optimizer,
epoch, args)
logger.info("Epoch {}, train_top1 {:.6f}, train_top5 {:.6f}".
format(epoch, train_top1, train_top5))
valid_top1, valid_top5 = valid(model, valid_loader, epoch, args)
if valid_top1 > best_top1:
best_top1 = valid_top1
if save_parameters:
fluid.save_dygraph(model.state_dict(),
args.model_save_dir + "/best_model")
            logger.info(
                "Epoch {}, valid_top1 {:.6f}, valid_top5 {:.6f}, best_valid_top1 {:.6f}".
                format(epoch, valid_top1, valid_top5, best_top1))
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
main(args)
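# Example invocation (the dataset path is a placeholder; flags follow the
# argparse definitions above):
#   python train_imagenet.py --arch='DARTS_V2' --data_dir=/path/to/imagenet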
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import genotypes
from graphviz import Digraph
def plot(genotype_normal, genotype_reduce, filename):
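    """Render the normal and reduce cells of a genotype into one PNG.

    Each genotype is a list of (op, input_index) pairs, two per intermediate
    node; indices 0 and 1 are the cell inputs c_{k-2} and c_{k-1}, and larger
    indices refer to earlier intermediate nodes.
    """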
g = Digraph(
format='png',
edge_attr=dict(fontname="times"),
node_attr=dict(
style='filled',
shape='ellipse',
align='center',
height='0.5',
width='0.5',
penwidth='2',
fontname="times"),
engine='dot')
g.body.extend(['rankdir=LR'])
g.node("reduce_c_{k-2}", fillcolor='darkseagreen2')
g.node("reduce_c_{k-1}", fillcolor='darkseagreen2')
g.node("normal_c_{k-2}", fillcolor='darkseagreen2')
g.node("normal_c_{k-1}", fillcolor='darkseagreen2')
assert len(genotype_normal) % 2 == 0
steps = len(genotype_normal) // 2
for i in range(steps):
g.node('n_' + str(i), fillcolor='lightblue')
for i in range(steps):
g.node('r_' + str(i), fillcolor='lightblue')
for i in range(steps):
for k in [2 * i, 2 * i + 1]:
op, j = genotype_normal[k]
if j == 0:
u = "normal_c_{k-2}"
elif j == 1:
u = "normal_c_{k-1}"
else:
u = 'n_' + str(j - 2)
v = 'n_' + str(i)
g.edge(u, v, label=op, fillcolor="gray")
for i in range(steps):
for k in [2 * i, 2 * i + 1]:
op, j = genotype_reduce[k]
if j == 0:
u = "reduce_c_{k-2}"
elif j == 1:
u = "reduce_c_{k-1}"
else:
u = 'r_' + str(j - 2)
v = 'r_' + str(i)
g.edge(u, v, label=op, fillcolor="gray")
g.node("r_c_{k}", fillcolor='palegoldenrod')
for i in range(steps):
g.edge('r_' + str(i), "r_c_{k}", fillcolor="gray")
g.node("n_c_{k}", fillcolor='palegoldenrod')
for i in range(steps):
g.edge('n_' + str(i), "n_c_{k}", fillcolor="gray")
g.render(filename, view=False)
if __name__ == '__main__':
if len(sys.argv) != 2:
print("usage:\n python {} ARCH_NAME".format(sys.argv[0]))
sys.exit(1)
genotype_name = sys.argv[1]
    try:
        genotype = getattr(genotypes, genotype_name)
    except AttributeError:
        print("{} is not specified in genotypes.py".format(genotype_name))
        sys.exit(1)
plot(genotype.normal, genotype.reduce, genotype_name)
@@ -18,8 +18,9 @@ from .controller_server import ControllerServer
from .controller_client import ControllerClient
from .lock import lock, unlock
from .cached_reader import cached_reader
from .meter import AvgrageMeter
__all__ = [
'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer',
'ControllerClient', 'lock', 'unlock', 'cached_reader'
'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter'
]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['AvgrageMeter']
class AvgrageMeter(object):
def __init__(self):
self.reset()
def reset(self):
self.avg = 0
self.sum = 0
self.cnt = 0
def update(self, val, n=1):
self.sum += val * n
self.cnt += n
self.avg = self.sum / self.cnt
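# Minimal usage sketch: `update` keeps a running mean weighted by sample
# count. With numpy inputs, `avg` is itself an array, which is why the
# training loops elsewhere in this package index it as `avg[0]`:
#   meter = AvgrageMeter()
#   meter.update(loss.numpy(), n=batch_size)
#   print(meter.avg[0])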
@@ -16,7 +16,10 @@ from ..nas import search_space
from .search_space import *
from ..nas import sa_nas
from .sa_nas import *
from ..nas import darts
from .darts import *
__all__ = []
__all__ += sa_nas.__all__
__all__ += search_space.__all__
__all__ += darts.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from ..darts import train_search
from .train_search import *
__all__ = []
__all__ += train_search.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
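# The Architect implements the bilevel optimization step of DARTS. With model
# weights w and architecture parameters alpha, each call to `step` updates
# alpha to reduce the validation loss:
#   first order  (unrolled=False): grad_alpha L_val(w, alpha)
#   second order (unrolled=True):  grad_alpha L_val(w', alpha)
#                                  - eta * H_{alpha,w} . grad_w' L_val(w', alpha)
# where w' = w - eta * grad_w L_train(w, alpha) is a one-step "virtual" update
# of a copy of the model, and the Hessian-vector product is approximated by
# finite differences (see _hessian_vector_product below).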
class Architect(object):
def __init__(self, model, eta, arch_learning_rate, place, unrolled):
self.network_momentum = 0.9
self.network_weight_decay = 3e-4
self.eta = eta
self.model = model
self.optimizer = fluid.optimizer.Adam(
arch_learning_rate,
0.5,
0.999,
regularization=fluid.regularizer.L2Decay(1e-3),
parameter_list=self.model.arch_parameters())
self.place = place
self.unrolled = unrolled
if self.unrolled:
self.unrolled_model = self.model.new()
self.unrolled_model_params = [
p for p in self.unrolled_model.parameters()
if p.name not in [
a.name for a in self.unrolled_model.arch_parameters()
] and p.trainable
]
self.unrolled_optimizer = fluid.optimizer.MomentumOptimizer(
self.eta,
self.network_momentum,
regularization=fluid.regularizer.L2DecayRegularizer(
self.network_weight_decay),
parameter_list=self.unrolled_model_params)
def step(self, input_train, target_train, input_valid, target_valid):
if self.unrolled:
params_grads = self._backward_step_unrolled(
input_train, target_train, input_valid, target_valid)
self.optimizer.apply_gradients(params_grads)
else:
loss = self._backward_step(input_valid, target_valid)
self.optimizer.minimize(loss)
self.optimizer.clear_gradients()
def _backward_step(self, input_valid, target_valid):
loss = self.model._loss(input_valid, target_valid)
loss.backward()
return loss
def _backward_step_unrolled(self, input_train, target_train, input_valid,
target_valid):
self._compute_unrolled_model(input_train, target_train)
unrolled_loss = self.unrolled_model._loss(input_valid, target_valid)
unrolled_loss.backward()
vector = [
to_variable(param._grad_ivar().numpy())
for param in self.unrolled_model_params
]
arch_params_grads = [
(alpha, to_variable(ualpha._grad_ivar().numpy()))
for alpha, ualpha in zip(self.model.arch_parameters(),
self.unrolled_model.arch_parameters())
]
self.unrolled_model.clear_gradients()
implicit_grads = self._hessian_vector_product(vector, input_train,
target_train)
for (p, g), ig in zip(arch_params_grads, implicit_grads):
new_g = g - (ig * self.unrolled_optimizer.current_step_lr())
g.value().get_tensor().set(new_g.numpy(), self.place)
return arch_params_grads
def _compute_unrolled_model(self, input, target):
for x, y in zip(self.unrolled_model.parameters(),
self.model.parameters()):
x.value().get_tensor().set(y.numpy(), self.place)
loss = self.unrolled_model._loss(input, target)
loss.backward()
self.unrolled_optimizer.minimize(loss)
self.unrolled_model.clear_gradients()
def _hessian_vector_product(self, vector, input, target, r=1e-2):
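        # Finite-difference approximation of the Hessian-vector product from
        # the DARTS paper:
        #   H . v ~= (grad_alpha L(w+, alpha) - grad_alpha L(w-, alpha)) / (2R)
        # with w+- = w +- R * v and R = r / ||v||_2, which is what the rsqrt
        # over the summed squared gradients below computes.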
R = r * fluid.layers.rsqrt(
fluid.layers.sum([
fluid.layers.reduce_sum(fluid.layers.square(v)) for v in vector
]))
model_params = [
p for p in self.model.parameters()
if p.name not in [a.name for a in self.model.arch_parameters()] and
p.trainable
]
for param, grad in zip(model_params, vector):
param_p = param + grad * R
param.value().get_tensor().set(param_p.numpy(), self.place)
loss = self.model._loss(input, target)
loss.backward()
grads_p = [
to_variable(param._grad_ivar().numpy())
for param in self.model.arch_parameters()
]
for param, grad in zip(model_params, vector):
param_n = param - grad * R * 2
param.value().get_tensor().set(param_n.numpy(), self.place)
self.model.clear_gradients()
loss = self.model._loss(input, target)
loss.backward()
grads_n = [
to_variable(param._grad_ivar().numpy())
for param in self.model.arch_parameters()
]
for param, grad in zip(model_params, vector):
param_o = param + grad * R
param.value().get_tensor().set(param_o.numpy(), self.place)
self.model.clear_gradients()
arch_grad = [(p - n) / (2 * R) for p, n in zip(grads_p, grads_n)]
return arch_grad
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__all__ = ['DARTSearch']
import logging
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from ...common import AvgrageMeter, get_logger
from .architect import Architect
logger = get_logger(__name__, level=logging.INFO)
SUPPORTED_METHODS = ["PC-DARTS", "DARTS"]
def count_parameters_in_MB(all_params):
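    """Return the number of trainable parameters in millions, excluding
    auxiliary-head parameters (identified by 'aux' in their names)."""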
parameters_number = 0
for param in all_params:
if param.trainable and 'aux' not in param.name:
parameters_number += np.prod(param.shape)
return parameters_number / 1e6
class DARTSearch(object):
def __init__(self,
model,
train_reader,
valid_reader,
learning_rate=0.025,
batchsize=64,
num_imgs=50000,
arch_learning_rate=3e-4,
                 unrolled=False,
method='DARTS',
num_epochs=50,
use_gpu=True,
use_data_parallel=False,
log_freq=50):
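        """Search controller for DARTS and PC-DARTS in dygraph mode.

        `train_reader` feeds the weight updates and `valid_reader` feeds the
        architecture (alpha) updates; `num_imgs` is the full training-set
        size, of which half is assumed to go to each reader. Set
        `unrolled=True` for the second-order DARTS approximation.
        """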
self.model = model
self.train_reader = train_reader
self.valid_reader = valid_reader
self.learning_rate = learning_rate
self.batchsize = batchsize
self.num_imgs = num_imgs
self.arch_learning_rate = arch_learning_rate
self.unrolled = unrolled
self.method = method
        assert self.method in SUPPORTED_METHODS, \
            "method must be one of {}".format(SUPPORTED_METHODS)
self.num_epochs = num_epochs
self.use_gpu = use_gpu
self.use_data_parallel = use_data_parallel
if not self.use_gpu:
self.place = fluid.CPUPlace()
elif not self.use_data_parallel:
self.place = fluid.CUDAPlace(0)
else:
self.place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
self.log_freq = log_freq
def train_one_epoch(self, train_loader, valid_loader, architect, optimizer,
epoch):
objs = AvgrageMeter()
top1 = AvgrageMeter()
top5 = AvgrageMeter()
self.model.train()
for step_id, (
train_data,
valid_data) in enumerate(zip(train_loader(), valid_loader())):
train_image, train_label = train_data
valid_image, valid_label = valid_data
train_image = to_variable(train_image)
train_label = to_variable(train_label)
train_label.stop_gradient = True
valid_image = to_variable(valid_image)
valid_label = to_variable(valid_label)
valid_label.stop_gradient = True
n = train_image.shape[0]
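            # PC-DARTS warm-up: skip architecture updates for the first 15
            # epochs so the supernet weights stabilize before alpha training
            # begins, as in the PC-DARTS paper.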
if not (self.method == "PC-DARTS" and epoch < 15):
architect.step(train_image, train_label, valid_image,
valid_label)
logits = self.model(train_image)
prec1 = fluid.layers.accuracy(input=logits, label=train_label, k=1)
prec5 = fluid.layers.accuracy(input=logits, label=train_label, k=5)
loss = fluid.layers.reduce_mean(
fluid.layers.softmax_with_cross_entropy(logits, train_label))
if self.use_data_parallel:
loss = self.model.scale_loss(loss)
loss.backward()
self.model.apply_collective_grads()
else:
loss.backward()
grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(5)
            optimizer.minimize(loss, grad_clip=grad_clip)
self.model.clear_gradients()
objs.update(loss.numpy(), n)
top1.update(prec1.numpy(), n)
top5.update(prec5.numpy(), n)
if step_id % self.log_freq == 0:
logger.info(
"Train Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[
0]))
return top1.avg[0]
def valid_one_epoch(self, valid_loader, epoch):
objs = AvgrageMeter()
top1 = AvgrageMeter()
top5 = AvgrageMeter()
self.model.eval()
        for step_id, (image, label) in enumerate(valid_loader()):
image = to_variable(image)
label = to_variable(label)
n = image.shape[0]
logits = self.model(image)
prec1 = fluid.layers.accuracy(input=logits, label=label, k=1)
prec5 = fluid.layers.accuracy(input=logits, label=label, k=5)
loss = fluid.layers.reduce_mean(
fluid.layers.softmax_with_cross_entropy(logits, label))
objs.update(loss.numpy(), n)
top1.update(prec1.numpy(), n)
top5.update(prec5.numpy(), n)
if step_id % self.log_freq == 0:
logger.info(
"Valid Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[
0]))
return top1.avg[0]
def train(self):
if self.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
model_parameters = [
p for p in self.model.parameters()
if p.name not in [a.name for a in self.model.arch_parameters()]
]
logger.info("param size = {:.6f}MB".format(
count_parameters_in_MB(model_parameters)))
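        # Half of the training images update the weights and half update the
        # architecture, hence num_imgs * 0.5 weight-update steps per epoch.
        # The unrolled method runs an extra virtual optimizer step per batch
        # through the same CosineDecay schedule, so the step count is doubled
        # to keep the decay aligned.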
step_per_epoch = int(self.num_imgs * 0.5 / self.batchsize)
if self.unrolled:
step_per_epoch *= 2
learning_rate = fluid.dygraph.CosineDecay(
self.learning_rate, step_per_epoch, self.num_epochs)
optimizer = fluid.optimizer.MomentumOptimizer(
learning_rate,
0.9,
regularization=fluid.regularizer.L2DecayRegularizer(3e-4),
parameter_list=model_parameters)
if self.use_data_parallel:
self.model = fluid.dygraph.parallel.DataParallel(self.model,
strategy)
self.train_reader = fluid.contrib.reader.distributed_batch_reader(
self.train_reader)
self.valid_reader = fluid.contrib.reader.distributed_batch_reader(
self.valid_reader)
train_loader = fluid.io.DataLoader.from_generator(
capacity=64,
use_double_buffer=True,
iterable=True,
return_list=True)
valid_loader = fluid.io.DataLoader.from_generator(
capacity=64,
use_double_buffer=True,
iterable=True,
return_list=True)
train_loader.set_batch_generator(self.train_reader, places=self.place)
valid_loader.set_batch_generator(self.valid_reader, places=self.place)
architect = Architect(self.model, learning_rate,
self.arch_learning_rate, self.place,
self.unrolled)
save_parameters = (not self.use_data_parallel) or (
self.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
for epoch in range(self.num_epochs):
logger.info('Epoch {}, lr {:.6f}'.format(
epoch, optimizer.current_step_lr()))
genotype = self.model.genotype()
logger.info('genotype = %s', genotype)
train_top1 = self.train_one_epoch(train_loader, valid_loader,
architect, optimizer, epoch)
logger.info("Epoch {}, train_acc {:.6f}".format(epoch, train_top1))
if epoch == self.num_epochs - 1:
valid_top1 = self.valid_one_epoch(valid_loader, epoch)
logger.info("Epoch {}, valid_acc {:.6f}".format(epoch,
valid_top1))
if save_parameters:
fluid.save_dygraph(self.model.state_dict(), "./weights")
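# Minimal usage sketch (model and readers are placeholders; the supernet is
# assumed to expose arch_parameters(), _loss(), genotype() and, for the
# unrolled variant, new(), as used above):
#   with fluid.dygraph.guard():
#       model = Network()  # e.g. the search supernet from model_search.py
#       searcher = DARTSearch(model, train_reader, valid_reader,
#                             method='PC-DARTS', num_epochs=50)
#       searcher.train()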