提交 e92a1c9c 编写于 作者: L Lu Yufei 提交者: qingqing01

add 6 models for image classification found by Baidu BDL HiNAS project (#1218)

* add HNAS models for image classification
* change name to HiNAS
* update license for HiNAS_models
上级 1f760a08
# Image Classification Models
This directory contains six image classification models, which are models automatically discovered by Baidu Big Data Lab (BDL) Hierarchical Neural Architecture Search project (HiNAS), achieving 96.1% accuracy on CIFAR-10 dataset. These models are divided into two categories. The first three have no skip link, named HiNAS 0-2, and the last three networks contain skip links, which are similar to the shortcut connections in Resnet, named HiNAS 3-5.
---
## Table of Contents
- [Installation](#installation)
- [Data preparation](#data-preparation)
- [Training a model](#training-a-model)
- [Model performances](#model-performances)
## Installation
Running the trainer in current directory requires:
- PadddlePaddle Fluid >= v0.15.0
- CuDNN >=6.0
If PaddlePaddle and CuDNN in your runtime environment do not meet the requirements, please follow the instructions in [installation document](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html) and make an update.
## Data preparation
When you run the sample code for the first time, the trainer will automatically download the cifar-10 dataset. Please make sure your environment has an internet connection.
The dataset will be downloaded to `dataset/cifar/cifar-10-python.tar.gz` in the same directory as the Trainer. If automatic download fails, you can go to https://www.cs.toronto.edu/~kriz/cifar.html and download cifar-10-python.tar.gz to the location mentioned above.
## Training a model
After the environment is ready, you can train the model. There are two entrances: `train_hinas.py` and `train_hinas_res.py`. The former is used to train Model 0-2 (without skip link), and the latter is used to train Model 3-5 (contains skip link).
Train Model 0~2 (without skip link):
```
python train_hinas.py --model=m_id # m_id can be 0, 1 or 2.
```
Train Model 3~5 (with skip link):
```
python train_hinas_res.py --model=m_id # m_id can be 0, 1 or 2.
```
In addition, both `train_hinas.py` and `train_hinas_res.py` support the following parameters:
- **random_flip_left_right**: Random flip image horizontally. (Default: True)
- **random_flip_up_down**: Randomly flip image vertically. (Default: False)
- **cutout**: Add cutout action to image. (Default: True)
- **standardize_image**: Image standardize. (Default: True)
- **pad_and_cut_image**: Random padding image and then crop back to the original size. (Default: True)
- **shuffle_image**: Shuffle the order of the input images during training. (Default: True)
- **lr_max**: Learning rate at the begin of training. (Default: 0.1)
- **lr_min**: Learning rate at the end of training. (Default: 0.0001)
- **batch_size**: Training batch size (Default: 128)
- **num_epochs**: Total training epoch (Default: 200)
- **weight_decay**: L2 Regularization value (Default: 0.0004)
- **momentum**: The momentum parameter in momentum optimizer (Default: 0.9)
- **dropout_rate**: Dropout rate of the dropout layer (Default: 0.5)
- **bn_decay**: The decay/momentum parameter (or called moving average decay) in batch norm layer (Default: 0.9)
## Model performances
Train all six models using same hyperparameters:
- learning rate: 0.1 -> 0.0001 with cosine annealing
- total epoch: 200
- batch size: 128
- L2 decay: 0.000400
- optimizer: momentum optimizer with m=0.9 and use nesterov
- preprocess: random horizontal flip + image standardization + cutout
And below is the accuracy on CIFAR-10 dataset:
| model | round 1 | round 2 | round 3 | max | avg |
|----------|---------|---------|---------|--------|--------|
| HiNAS-0 | 0.9548 | 0.9520 | 0.9513 | 0.9548 | 0.9527 |
| HiNAS-1 | 0.9452 | 0.9462 | 0.9420 | 0.9462 | 0.9445 |
| HiNAS-2 | 0.9508 | 0.9506 | 0.9483 | 0.9508 | 0.9499 |
| HiNAS-3 | 0.9607 | 0.9623 | 0.9601 | 0.9623 | 0.9611 |
| HiNAS-4 | 0.9611 | 0.9584 | 0.9586 | 0.9611 | 0.9594 |
| HiNAS-5 | 0.9578 | 0.9588 | 0.9594 | 0.9594 | 0.9586 |
# Image Classification Models
本目录下包含6个图像分类模型,都是百度大数据实验室 Hierarchical Neural Architecture Search (HiNAS) 项目通过机器自动发现的模型,在CIFAR-10数据集上达到96.1%的准确率。这6个模型分为两类,前3个没有skip link,分别命名为 HiNAS 0-2号,后三个网络带有skip link,功能类似于Resnet中的shortcut connection,分别命名 HiNAS 3-5号。
---
## Table of Contents
- [Installation](#installation)
- [Data preparation](#data-preparation)
- [Training a model](#training-a-model)
- [Model performances](#model-performances)
## Installation
最低环境要求:
- PadddlePaddle Fluid >= v0.15.0
- Cudnn >=6.0
如果您的运行环境无法满足要求,可以参考此文档升级PaddlePaddle:[installation document](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)
## Data preparation
第一次训练模型的时候,Trainer会自动下载CIFAR-10数据集,请确保您的环境有互联网连接。
数据集会被下载到Trainer同目录下的`dataset/cifar/cifar-10-python.tar.gz`,如果自动下载失败,您可以自行从 https://www.cs.toronto.edu/~kriz/cifar.html 下载cifar-10-python.tar.gz,然后放到上述位置。
## Training a model
准备好环境后,可以训练模型,训练有2个入口,`train_hinas.py``train_hinas_res.py`,前者用来训练0-2号不含skip link的模型,后者用来训练3-5号包含skip link的模型。
训练0~2号不含skip link的模型:
```
python train_hinas.py --model=m_id # m_id can be 0, 1 or 2.
```
训练3~5号包含skip link的模型:
```
python train_hinas_res.py --model=m_id # m_id can be 0, 1 or 2.
```
此外,`train_hinas.py``train_hinas_res.py` 都支持以下参数:
初始化部分:
- random_flip_left_right:图片随机水平翻转(Default:True)
- random_flip_up_down:图片随机垂直翻转(Default:False)
- cutout:图片随机遮挡(Default:True)
- standardize_image:对图片每个像素做 standardize(Default:True)
- pad_and_cut_image:图片随机padding,并裁剪回原大小(Default:True)
- shuffle_image:训练时对输入图片的顺序做shuffle(Default:True)
- lr_max:训练开始时的learning rate(Default:0.1)
- lr_min:训练结束时的learning rate(Default:0.0001)
- batch_size:训练的batch size(Default:128)
- num_epochs:训练总的epoch(Default:200)
- weight_decay:训练时L2 Regularization大小(Default:0.0004)
- momentum:momentum优化器中的momentum系数(Default:0.9)
- dropout_rate:dropout层的dropout_rate(Default:0.5)
- bn_decay:batch norm层的decay/momentum系数(即moving average decay)大小(Default:0.9)
## Model performances
6个模型使用相同的参数训练:
- learning rate: 0.1 -> 0.0001 with cosine annealing
- total epoch: 200
- batch size: 128
- L2 decay: 0.000400
- optimizer: momentum optimizer with m=0.9 and use nesterov
- preprocess: random horizontal flip + image standardization + cutout
以下是6个模型在CIFAR-10数据集上的准确率:
| model | round 1 | round 2 | round 3 | max | avg |
|----------|---------|---------|---------|--------|--------|
| HiNAS-0 | 0.9548 | 0.9520 | 0.9513 | 0.9548 | 0.9527 |
| HiNAS-1 | 0.9452 | 0.9462 | 0.9420 | 0.9462 | 0.9445 |
| HiNAS-2 | 0.9508 | 0.9506 | 0.9483 | 0.9508 | 0.9499 |
| HiNAS-3 | 0.9607 | 0.9623 | 0.9601 | 0.9623 | 0.9611 |
| HiNAS-4 | 0.9611 | 0.9584 | 0.9586 | 0.9611 | 0.9594 |
| HiNAS-5 | 0.9578 | 0.9588 | 0.9594 | 0.9594 | 0.9586 |
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import operator
import numpy as np
import paddle.fluid as fluid
from absl import flags
FLAGS = flags.FLAGS
flags.DEFINE_float("bn_decay", 0.9, "batch norm decay")
flags.DEFINE_float("dropout_rate", 0.5, "dropout rate")
def calc_padding(img_width, stride, dilation, filter_width):
""" calculate pixels to padding in order to keep input/output size same. """
filter_width = dilation * (filter_width - 1) + 1
if img_width % stride == 0:
pad_along_width = max(filter_width - stride, 0)
else:
pad_along_width = max(filter_width - (img_width % stride), 0)
return pad_along_width // 2, pad_along_width - pad_along_width // 2
def conv(inputs,
filters,
kernel,
strides=(1, 1),
dilation=(1, 1),
num_groups=1,
conv_param=None):
""" normal conv layer """
if isinstance(kernel, (tuple, list)):
n = operator.mul(*kernel) * inputs.shape[1]
else:
n = kernel * kernel * inputs.shape[1]
# pad input
padding = (0, 0, 0, 0) \
+ calc_padding(inputs.shape[2], strides[0], dilation[0], kernel[0]) \
+ calc_padding(inputs.shape[3], strides[1], dilation[1], kernel[1])
if sum(padding) > 0:
inputs = fluid.layers.pad(inputs, padding, 0)
param_attr = fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
0.0, scale=np.sqrt(2.0 / n)),
regularizer=fluid.regularizer.L2Decay(FLAGS.weight_decay))
bias_attr = fluid.param_attr.ParamAttr(
regularizer=fluid.regularizer.L2Decay(0.))
return fluid.layers.conv2d(
inputs,
filters,
kernel,
stride=strides,
padding=0,
dilation=dilation,
groups=num_groups,
param_attr=param_attr if conv_param is None else conv_param,
use_cudnn=False if num_groups == inputs.shape[1] == filters else True,
bias_attr=bias_attr,
act=None)
def sep(inputs, filters, kernel, strides=(1, 1), dilation=(1, 1)):
""" Separable convolution layer """
if isinstance(kernel, (tuple, list)):
n_depth = operator.mul(*kernel)
else:
n_depth = kernel * kernel
n_point = inputs.shape[1]
if isinstance(strides, (tuple, list)):
multiplier = strides[0]
else:
multiplier = strides
depthwise_param = fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
0.0, scale=np.sqrt(2.0 / n_depth)),
regularizer=fluid.regularizer.L2Decay(FLAGS.weight_decay))
pointwise_param = fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
0.0, scale=np.sqrt(2.0 / n_point)),
regularizer=fluid.regularizer.L2Decay(FLAGS.weight_decay))
depthwise_conv = conv(
inputs=inputs,
kernel=kernel,
filters=int(filters * multiplier),
strides=strides,
dilation=dilation,
num_groups=int(filters * multiplier),
conv_param=depthwise_param)
return conv(
inputs=depthwise_conv,
kernel=(1, 1),
filters=int(filters * multiplier),
strides=(1, 1),
dilation=dilation,
conv_param=pointwise_param)
def maxpool(inputs, kernel, strides=(1, 1)):
padding = (0, 0, 0, 0) \
+ calc_padding(inputs.shape[2], strides[0], 1, kernel[0]) \
+ calc_padding(inputs.shape[3], strides[1], 1, kernel[1])
if sum(padding) > 0:
inputs = fluid.layers.pad(inputs, padding, 0)
return fluid.layers.pool2d(
inputs, kernel, 'max', strides, pool_padding=0, ceil_mode=False)
def avgpool(inputs, kernel, strides=(1, 1)):
padding_pixel = (0, 0, 0, 0)
padding_pixel += calc_padding(inputs.shape[2], strides[0], 1, kernel[0])
padding_pixel += calc_padding(inputs.shape[3], strides[1], 1, kernel[1])
if padding_pixel[4] == padding_pixel[5] and padding_pixel[
6] == padding_pixel[7]:
# same padding pixel num on all sides.
return fluid.layers.pool2d(
inputs,
kernel,
'avg',
strides,
pool_padding=(padding_pixel[4], padding_pixel[6]),
ceil_mode=False)
elif padding_pixel[4] + 1 == padding_pixel[5] and padding_pixel[6] + 1 == padding_pixel[7] \
and strides == (1, 1):
# different padding size: first pad then crop.
x = fluid.layers.pool2d(
inputs,
kernel,
'avg',
strides,
pool_padding=(padding_pixel[5], padding_pixel[7]),
ceil_mode=False)
x_shape = x.shape
return fluid.layers.crop(
x,
shape=(-1, x_shape[1], x_shape[2] - 1, x_shape[3] - 1),
offsets=(0, 0, 1, 1))
else:
# not support. use padding-zero and pool2d.
print("Warning: use zero-padding in avgpool")
outputs = fluid.layers.pad(inputs, padding_pixel, 0)
return fluid.layers.pool2d(
outputs, kernel, 'avg', strides, pool_padding=0, ceil_mode=False)
def global_avgpool(inputs):
return fluid.layers.pool2d(
inputs,
1,
'avg',
1,
pool_padding=0,
global_pooling=True,
ceil_mode=True)
def fully_connected(inputs, units):
n = inputs.shape[1]
param_attr = fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
0.0, scale=np.sqrt(2.0 / n)),
regularizer=fluid.regularizer.L2Decay(FLAGS.weight_decay))
bias_attr = fluid.param_attr.ParamAttr(
regularizer=fluid.regularizer.L2Decay(0.))
return fluid.layers.fc(inputs,
units,
param_attr=param_attr,
bias_attr=bias_attr)
def bn_relu(inputs):
""" batch norm + rely layer """
output = fluid.layers.batch_norm(
inputs, momentum=FLAGS.bn_decay, epsilon=0.001, data_layout="NCHW")
return fluid.layers.relu(output)
def dropout(inputs):
""" dropout layer """
return fluid.layers.dropout(inputs, dropout_prob=FLAGS.dropout_rate)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import build.layers as layers
def conv_1x1(inputs, downsample=False):
return conv_base(inputs, (1, 1), downsample=downsample)
def conv_2x2(inputs, downsample=False):
return conv_base(inputs, (2, 2), downsample=downsample)
def conv_3x3(inputs, downsample=False):
return conv_base(inputs, (3, 3), downsample=downsample)
def dilated_2x2(inputs, downsample=False):
return conv_base(inputs, (2, 2), (2, 2), downsample)
def conv_1x2_2x1(inputs, downsample=False):
return pair_base(inputs, 2, downsample)
def conv_1x3_3x1(inputs, downsample=False):
return pair_base(inputs, 3, downsample)
def sep_2x2(inputs, downsample=False):
return sep_base(inputs, (2, 2), downsample=downsample)
def sep_3x3(inputs, downsample=False):
return sep_base(inputs, (3, 3), downsample=downsample)
def maxpool_2x2(inputs, downsample=False):
return maxpool_base(inputs, (2, 2), downsample)
def maxpool_3x3(inputs, downsample=False):
return maxpool_base(inputs, (3, 3), downsample)
def avgpool_2x2(inputs, downsample=False):
return avgpool_base(inputs, (2, 2), downsample)
def avgpool_3x3(inputs, downsample=False):
return avgpool_base(inputs, (3, 3), downsample)
def conv_base(inputs, kernel, dilation=(1, 1), downsample=False):
filters = inputs.shape[1]
if downsample:
output = layers.conv(inputs, filters * 2, kernel, (2, 2))
else:
output = layers.conv(inputs, filters, kernel, dilation=dilation)
return output
def pair_base(inputs, kernel, downsample=False):
filters = inputs.shape[1]
if downsample:
output = layers.conv(inputs, filters, (1, kernel), (1, 2))
output = layers.conv(output, filters, (kernel, 1), (2, 1))
output = layers.conv(output, filters * 2, (1, 1))
else:
output = layers.conv(inputs, filters, (1, kernel))
output = layers.conv(output, filters, (kernel, 1))
return output
def sep_base(inputs, kernel, dilation=(1, 1), downsample=False):
filters = inputs.shape[1]
if downsample:
output = layers.sep(inputs, filters * 2, kernel, (2, 2))
else:
output = layers.sep(inputs, filters, kernel, dilation=dilation)
return output
def maxpool_base(inputs, kernel, downsample=False):
if downsample:
filters = inputs.shape[1]
output = layers.maxpool(inputs, kernel, (2, 2))
output = layers.conv(output, filters * 2, (1, 1))
else:
output = layers.maxpool(inputs, kernel)
return output
def avgpool_base(inputs, kernel, downsample=False):
if downsample:
filters = inputs.shape[1]
output = layers.avgpool(inputs, kernel, (2, 2))
output = layers.conv(output, filters * 2, (1, 1))
else:
output = layers.avgpool(inputs, kernel)
return output
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from absl import flags
import build.layers as layers
import build.ops as _ops
FLAGS = flags.FLAGS
flags.DEFINE_integer("num_stages", 3, "number of stages")
flags.DEFINE_integer("num_blocks", 5, "number of blocks per stage")
flags.DEFINE_integer("num_ops", 2, "number of operations per block")
flags.DEFINE_integer("width", 64, "network width")
flags.DEFINE_string("downsample", "pool", "conv or pool")
num_classes = 10
ops = [
_ops.conv_1x1,
_ops.conv_2x2,
_ops.conv_3x3,
_ops.dilated_2x2,
_ops.conv_1x2_2x1,
_ops.conv_1x3_3x1,
_ops.sep_2x2,
_ops.sep_3x3,
_ops.maxpool_2x2,
_ops.maxpool_3x3,
_ops.avgpool_2x2,
_ops.avgpool_3x3,
]
def net(inputs, tokens):
""" build network with skip links """
x = layers.conv(inputs, FLAGS.width, (3, 3))
num_ops = FLAGS.num_blocks * FLAGS.num_ops
x = stage(x, tokens[:num_ops], pre_activation=True)
for i in range(1, FLAGS.num_stages):
x = stage(x, tokens[i * num_ops:(i + 1) * num_ops], downsample=True)
x = layers.bn_relu(x)
x = layers.global_avgpool(x)
x = layers.dropout(x)
logits = layers.fully_connected(x, num_classes)
return fluid.layers.softmax(logits)
def stage(x, tokens, pre_activation=False, downsample=False):
""" build network's stage. Stage consists of blocks """
x = block(x, tokens[:FLAGS.num_ops], pre_activation, downsample)
for i in range(1, FLAGS.num_blocks):
print("-" * 12)
x = block(x, tokens[i * FLAGS.num_ops:(i + 1) * FLAGS.num_ops])
print("=" * 12)
return x
def block(x, tokens, pre_activation=False, downsample=False):
""" build block. """
if pre_activation:
x = layers.bn_relu(x)
res = x
else:
res = x
x = layers.bn_relu(x)
x = ops[tokens[0]](x, downsample)
print("%s \t-> shape %s" % (ops[0].__name__, x.shape))
for token in tokens[1:]:
x = layers.bn_relu(x)
x = ops[token](x)
print("%s \t-> shape %s" % (ops[token].__name__, x.shape))
if downsample:
filters = res.shape[1]
if FLAGS.downsample == "conv":
res = layers.conv(res, filters * 2, (1, 1), (2, 2))
elif FLAGS.downsample == "pool":
res = layers.avgpool(res, (2, 2), (2, 2))
res = fluid.layers.pad(res, (0, 0, filters // 2, filters // 2, 0, 0,
0, 0))
else:
raise NotImplementedError
return x + res
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from absl import flags
import build.layers as layers
import build.ops as _ops
FLAGS = flags.FLAGS
flags.DEFINE_integer("num_stages", 5, "number of stages")
flags.DEFINE_integer("width", 64, "network width")
num_classes = 10
ops = [
_ops.conv_1x1, #0
_ops.conv_2x2, #1
_ops.conv_3x3, #2
_ops.dilated_2x2, #3
_ops.conv_1x2_2x1, #4
_ops.conv_1x3_3x1, #5
_ops.sep_2x2, #6
_ops.sep_3x3, #7
_ops.maxpool_2x2, #8
_ops.maxpool_3x3,
_ops.avgpool_2x2, #10
_ops.avgpool_3x3,
]
def net(inputs, tokens):
depth = len(tokens)
q, r = divmod(depth + 1, FLAGS.num_stages)
downsample_steps = [
i * q + max(0, i + r - FLAGS.num_stages + 1) - 2
for i in range(1, FLAGS.num_stages)
]
x = layers.conv(inputs, FLAGS.width, (3, 3))
x = layers.bn_relu(x)
for i, token in enumerate(tokens):
downsample = i in downsample_steps
x = ops[token](x, downsample)
print("%s \t-> shape %s" % (ops[token].__name__, x.shape))
if downsample:
print("=" * 12)
x = layers.bn_relu(x)
x = layers.global_avgpool(x)
x = layers.dropout(x)
logits = layers.fully_connected(x, num_classes)
return fluid.layers.softmax(logits)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
import reader
from absl import flags
# import preprocess
FLAGS = flags.FLAGS
flags.DEFINE_float("lr_max", 0.1, "initial learning rate")
flags.DEFINE_float("lr_min", 0.0001, "limiting learning rate")
flags.DEFINE_integer("batch_size", 128, "batch size")
flags.DEFINE_integer("num_epochs", 200, "total epochs to train")
flags.DEFINE_float("weight_decay", 0.0004, "weight decay")
flags.DEFINE_float("momentum", 0.9, "momentum")
flags.DEFINE_boolean("shuffle_image", True, "shuffle input images on training")
dataset_train_size = 50000
class Model(object):
def __init__(self, build_fn, tokens):
print("learning rate: %f -> %f, cosine annealing" %
(FLAGS.lr_max, FLAGS.lr_min))
print("epoch: %d" % FLAGS.num_epochs)
print("batch size: %d" % FLAGS.batch_size)
print("L2 decay: %f" % FLAGS.weight_decay)
self.max_step = dataset_train_size * FLAGS.num_epochs // FLAGS.batch_size
self.build_fn = build_fn
self.tokens = tokens
print("Token is %s" % ",".join(map(str, tokens)))
def cosine_annealing(self):
step = _decay_step_counter()
lr = FLAGS.lr_min + (FLAGS.lr_max - FLAGS.lr_min) / 2 \
* (1.0 + fluid.layers.ops.cos(step / self.max_step * math.pi))
return lr
def optimizer_program(self):
return fluid.optimizer.Momentum(
learning_rate=self.cosine_annealing(),
momentum=FLAGS.momentum,
use_nesterov=True,
regularization=fluid.regularizer.L2DecayRegularizer(
FLAGS.weight_decay))
def inference_network(self):
images = fluid.layers.data(
name='pixel', shape=[3, 32, 32], dtype='float32')
return self.build_fn(images, self.tokens)
def train_network(self):
predict = self.inference_network()
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
accuracy = fluid.layers.accuracy(input=predict, label=label)
# self.parameters = fluid.parameters.create(avg_cost)
return [avg_cost, accuracy]
def run(self):
train_files = reader.train10()
test_files = reader.test10()
if FLAGS.shuffle_image:
train_reader = paddle.batch(
paddle.reader.shuffle(train_files, dataset_train_size),
batch_size=FLAGS.batch_size)
else:
train_reader = paddle.batch(
train_files, batch_size=FLAGS.batch_size)
test_reader = paddle.batch(test_files, batch_size=FLAGS.batch_size)
costs = []
accs = []
def event_handler(event):
if isinstance(event, fluid.EndStepEvent):
costs.append(event.metrics[0])
accs.append(event.metrics[1])
if event.step % 20 == 0:
print("Epoch %d, Step %d, Loss %f, Acc %f" % (
event.epoch, event.step, np.mean(costs), np.mean(accs)))
del costs[:]
del accs[:]
if isinstance(event, fluid.EndEpochEvent):
if event.epoch % 3 == 0 or event.epoch == FLAGS.num_epochs - 1:
avg_cost, accuracy = trainer.test(
reader=test_reader, feed_order=['pixel', 'label'])
event_handler.best_acc = max(event_handler.best_acc,
accuracy)
print("Test with epoch %d, Loss %f, Acc %f" %
(event.epoch, avg_cost, accuracy))
print("Best acc %f" % event_handler.best_acc)
event_handler.best_acc = 0.0
place = fluid.CUDAPlace(0)
trainer = fluid.Trainer(
train_func=self.train_network,
optimizer_func=self.optimizer_program,
place=place)
trainer.train(
reader=train_reader,
num_epochs=FLAGS.num_epochs,
event_handler=event_handler,
feed_order=['pixel', 'label'])
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CIFAR-10 dataset.
This module will download dataset from
https://www.cs.toronto.edu/~kriz/cifar.html and parse train/test set into
paddle reader creators.
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes,
with 6000 images per class. There are 50000 training images and 10000 test images.
"""
from PIL import Image
from PIL import ImageOps
import numpy as np
import cPickle
import itertools
import paddle.dataset.common
import tarfile
from absl import flags
FLAGS = flags.FLAGS
flags.DEFINE_boolean("random_flip_left_right", True,
"random flip left and right")
flags.DEFINE_boolean("random_flip_up_down", False, "random flip up and down")
flags.DEFINE_boolean("cutout", True, "cutout")
flags.DEFINE_boolean("standardize_image", True, "standardize input images")
flags.DEFINE_boolean("pad_and_cut_image", True, "pad and cut input images")
__all__ = ['train10', 'test10', 'convert']
URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
paddle.dataset.common.DATA_HOME = "dataset/"
image_size = 32
image_depth = 3
half_length = 8
def preprocess(sample, is_training):
image_array = sample.reshape(3, image_size, image_size)
rgb_array = np.transpose(image_array, (1, 2, 0))
img = Image.fromarray(rgb_array, 'RGB')
if is_training:
if FLAGS.pad_and_cut_image:
# pad and ramdom crop
img = ImageOps.expand(
img, (2, 2, 2, 2), fill=0) # pad to 36 * 36 * 3
left_top = np.random.randint(5, size=2) # rand 0 - 4
img = img.crop((left_top[0], left_top[1], left_top[0] + image_size,
left_top[1] + image_size))
if FLAGS.random_flip_left_right and np.random.randint(2):
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if FLAGS.random_flip_up_down and np.random.randint(2):
img = img.transpose(Image.FLIP_TOP_BOTTOM)
img = np.array(img).astype(np.float32)
if FLAGS.standardize_image:
# per_image_standardization
img_float = img / 255.0
mean = np.mean(img_float)
std = max(np.std(img_float), 1.0 / np.sqrt(3 * image_size * image_size))
img = (img_float - mean) / std
if is_training and FLAGS.cutout:
center = np.random.randint(image_size, size=2)
offset_width = max(0, center[0] - half_length)
offset_height = max(0, center[1] - half_length)
target_width = min(center[0] + half_length, image_size)
target_height = min(center[1] + half_length, image_size)
for i in range(offset_height, target_height):
for j in range(offset_width, target_width):
img[i][j][:] = 0.0
img = np.transpose(img, (2, 0, 1))
return img.reshape(3 * image_size * image_size)
def reader_creator(filename, sub_name, is_training):
def read_batch(batch):
data = batch['data']
labels = batch.get('labels', batch.get('fine_labels', None))
assert labels is not None
for sample, label in itertools.izip(data, labels):
yield preprocess(sample, is_training), int(label)
def reader():
with tarfile.open(filename, mode='r') as f:
names = [
each_item.name for each_item in f if sub_name in each_item.name
]
names.sort()
for name in names:
print("Reading file " + name)
batch = cPickle.load(f.extractfile(name))
for item in read_batch(batch):
yield item
return reader
def train10():
"""
CIFAR-10 training set creator.
It returns a reader creator, each sample in the reader is image pixels in
[0, 1] and label in [0, 9].
:return: Training reader creator
:rtype: callable
"""
return reader_creator(
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'data_batch', True)
def test10():
"""
CIFAR-10 test set creator.
It returns a reader creator, each sample in the reader is image pixels in
[0, 1] and label in [0, 9].
:return: Test reader creator.
:rtype: callable
"""
return reader_creator(
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'test_batch', False)
def fetch():
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5)
def convert(path):
"""
Converts dataset to recordio format
"""
paddle.dataset.common.convert(path, train10(), 1000, "cifar_train10")
paddle.dataset.common.convert(path, test10(), 1000, "cifar_test10")
cnumpy.core.multiarray
_reconstruct
p0
(cnumpy
ndarray
p1
(I0
tp2
S'b'
p3
tp4
Rp5
(I1
(I21
tp6
cnumpy
dtype
p7
(S'i4'
p8
I0
I1
tp9
Rp10
(I3
S'<'
p11
NNNI-1
I-1
I0
tp12
bI00
S'\x05\x00\x00\x00\x07\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x05\x00\x00\x00\x02\x00\x00\x00\x08\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x01\x00\x00\x00\n\x00\x00\x00\t\x00\x00\x00\x03\x00\x00\x00\x08\x00\x00\x00\x0b\x00\x00\x00\x03\x00\x00\x00\t\x00\x00\x00\x02\x00\x00\x00\x06\x00\x00\x00\x01\x00\x00\x00\x06\x00\x00\x00'
p13
tp14
b.
\ No newline at end of file
cnumpy.core.multiarray
_reconstruct
p0
(cnumpy
ndarray
p1
(I0
tp2
S'b'
p3
tp4
Rp5
(I1
(I21
tp6
cnumpy
dtype
p7
(S'i4'
p8
I0
I1
tp9
Rp10
(I3
S'<'
p11
NNNI-1
I-1
I0
tp12
bI00
S'\x07\x00\x00\x00\x07\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x02\x00\x00\x00\x02\x00\x00\x00\x08\x00\x00\x00\x08\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x02\x00\x00\x00\n\x00\x00\x00\x08\x00\x00\x00\x02\x00\x00\x00\t\x00\x00\x00\x0b\x00\x00\x00\t\x00\x00\x00\x06\x00\x00\x00\x04\x00\x00\x00\x04\x00\x00\x00\n\x00\x00\x00'
p13
tp14
b.
\ No newline at end of file
cnumpy.core.multiarray
_reconstruct
p0
(cnumpy
ndarray
p1
(I0
tp2
S'b'
p3
tp4
Rp5
(I1
(I21
tp6
cnumpy
dtype
p7
(S'i4'
p8
I0
I1
tp9
Rp10
(I3
S'<'
p11
NNNI-1
I-1
I0
tp12
bI00
S'\x07\x00\x00\x00\x05\x00\x00\x00\x08\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\n\x00\x00\x00\t\x00\x00\x00\x02\x00\x00\x00\x02\x00\x00\x00\x02\x00\x00\x00\x08\x00\x00\x00\x08\x00\x00\x00\x08\x00\x00\x00\x02\x00\x00\x00\t\x00\x00\x00\x04\x00\x00\x00\t\x00\x00\x00\x0b\x00\x00\x00\x07\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00'
p13
tp14
b.
\ No newline at end of file
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
from absl import app
from absl import flags
import nn_paddle as nn
from build import vgg_base
FLAGS = flags.FLAGS
flags.DEFINE_string("tokdir", "tokens/", "token directory")
flags.DEFINE_integer("model", 0, "model")
mid = [17925, 18089, 15383]
def main(_):
f = os.path.join(FLAGS.tokdir, str(mid[FLAGS.model]) + ".pkl")
tokens = pickle.load(open(f, "rb"))
model = nn.Model(vgg_base.net, tokens)
model.run()
if __name__ == "__main__":
app.run(main)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
from absl import app
from absl import flags
import nn_paddle as nn
from build import resnet_base
FLAGS = flags.FLAGS
flags.DEFINE_string("tokdir", "tokens/", "token directory")
flags.DEFINE_integer("model", 0, "model")
mid = [17754, 15113, 15613]
def main(_):
f = os.path.join(FLAGS.tokdir, str(mid[FLAGS.model]) + ".pkl")
tokens = pickle.load(open(f, "rb"))
model = nn.Model(resnet_base.net, tokens)
model.run()
if __name__ == "__main__":
app.run(main)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册