未验证 提交 5bbdb309 编写于 作者: Y Yibing Liu 提交者: GitHub

Code clean for release (#2650)

上级 594d6bef
# Image Classification Models
This directory contains six image classification models, which are models automatically discovered by Baidu Big Data Lab (BDL) Hierarchical Neural Architecture Search project (HiNAS), achieving 96.1% accuracy on CIFAR-10 dataset. These models are divided into two categories. The first three have no skip link, named HiNAS 0-2, and the last three networks contain skip links, which are similar to the shortcut connections in Resnet, named HiNAS 3-5.
---
## Table of Contents
- [Installation](#installation)
- [Data preparation](#data-preparation)
- [Training a model](#training-a-model)
- [Model performances](#model-performances)
## Installation
Running the trainer in current directory requires:
- PadddlePaddle Fluid >= v0.15.0
- CuDNN >=6.0
If PaddlePaddle and CuDNN in your runtime environment do not meet the requirements, please follow the instructions in [installation document](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html) and make an update.
## Data preparation
When you run the sample code for the first time, the trainer will automatically download the cifar-10 dataset. Please make sure your environment has an internet connection.
The dataset will be downloaded to `dataset/cifar/cifar-10-python.tar.gz` in the same directory as the Trainer. If automatic download fails, you can go to https://www.cs.toronto.edu/~kriz/cifar.html and download cifar-10-python.tar.gz to the location mentioned above.
## Training a model
After the environment is ready, you can train the model. There are two entrances: `train_hinas.py` and `train_hinas_res.py`. The former is used to train Model 0-2 (without skip link), and the latter is used to train Model 3-5 (contains skip link).
Train Model 0~2 (without skip link):
```
python train_hinas.py --model=m_id # m_id can be 0, 1 or 2.
```
Train Model 3~5 (with skip link):
```
python train_hinas_res.py --model=m_id # m_id can be 0, 1 or 2.
```
In addition, both `train_hinas.py` and `train_hinas_res.py` support the following parameters:
- **random_flip_left_right**: Random flip image horizontally. (Default: True)
- **random_flip_up_down**: Randomly flip image vertically. (Default: False)
- **cutout**: Add cutout action to image. (Default: True)
- **standardize_image**: Image standardize. (Default: True)
- **pad_and_cut_image**: Random padding image and then crop back to the original size. (Default: True)
- **shuffle_image**: Shuffle the order of the input images during training. (Default: True)
- **lr_max**: Learning rate at the begin of training. (Default: 0.1)
- **lr_min**: Learning rate at the end of training. (Default: 0.0001)
- **batch_size**: Training batch size (Default: 128)
- **num_epochs**: Total training epoch (Default: 200)
- **weight_decay**: L2 Regularization value (Default: 0.0004)
- **momentum**: The momentum parameter in momentum optimizer (Default: 0.9)
- **dropout_rate**: Dropout rate of the dropout layer (Default: 0.5)
- **bn_decay**: The decay/momentum parameter (or called moving average decay) in batch norm layer (Default: 0.9)
## Model performances
Train all six models using same hyperparameters:
- learning rate: 0.1 -> 0.0001 with cosine annealing
- total epoch: 200
- batch size: 128
- L2 decay: 0.000400
- optimizer: momentum optimizer with m=0.9 and use nesterov
- preprocess: random horizontal flip + image standardization + cutout
And below is the accuracy on CIFAR-10 dataset:
| model | round 1 | round 2 | round 3 | max | avg |
|----------|---------|---------|---------|--------|--------|
| HiNAS-0 | 0.9548 | 0.9520 | 0.9513 | 0.9548 | 0.9527 |
| HiNAS-1 | 0.9452 | 0.9462 | 0.9420 | 0.9462 | 0.9445 |
| HiNAS-2 | 0.9508 | 0.9506 | 0.9483 | 0.9508 | 0.9499 |
| HiNAS-3 | 0.9607 | 0.9623 | 0.9601 | 0.9623 | 0.9611 |
| HiNAS-4 | 0.9611 | 0.9584 | 0.9586 | 0.9611 | 0.9594 |
| HiNAS-5 | 0.9578 | 0.9588 | 0.9594 | 0.9594 | 0.9586 |
# Image Classification Models
本目录下包含6个图像分类模型,都是百度大数据实验室 Hierarchical Neural Architecture Search (HiNAS) 项目通过机器自动发现的模型,在CIFAR-10数据集上达到96.1%的准确率。这6个模型分为两类,前3个没有skip link,分别命名为 HiNAS 0-2号,后三个网络带有skip link,功能类似于Resnet中的shortcut connection,分别命名 HiNAS 3-5号。
---
## Table of Contents
- [Installation](#installation)
- [Data preparation](#data-preparation)
- [Training a model](#training-a-model)
- [Model performances](#model-performances)
## Installation
最低环境要求:
- PadddlePaddle Fluid >= v0.15.0
- Cudnn >=6.0
如果您的运行环境无法满足要求,可以参考此文档升级PaddlePaddle:[installation document](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)
## Data preparation
第一次训练模型的时候,Trainer会自动下载CIFAR-10数据集,请确保您的环境有互联网连接。
数据集会被下载到Trainer同目录下的`dataset/cifar/cifar-10-python.tar.gz`,如果自动下载失败,您可以自行从 https://www.cs.toronto.edu/~kriz/cifar.html 下载cifar-10-python.tar.gz,然后放到上述位置。
## Training a model
准备好环境后,可以训练模型,训练有2个入口,`train_hinas.py``train_hinas_res.py`,前者用来训练0-2号不含skip link的模型,后者用来训练3-5号包含skip link的模型。
训练0~2号不含skip link的模型:
```
python train_hinas.py --model=m_id # m_id can be 0, 1 or 2.
```
训练3~5号包含skip link的模型:
```
python train_hinas_res.py --model=m_id # m_id can be 0, 1 or 2.
```
此外,`train_hinas.py``train_hinas_res.py` 都支持以下参数:
初始化部分:
- random_flip_left_right:图片随机水平翻转(Default:True)
- random_flip_up_down:图片随机垂直翻转(Default:False)
- cutout:图片随机遮挡(Default:True)
- standardize_image:对图片每个像素做 standardize(Default:True)
- pad_and_cut_image:图片随机padding,并裁剪回原大小(Default:True)
- shuffle_image:训练时对输入图片的顺序做shuffle(Default:True)
- lr_max:训练开始时的learning rate(Default:0.1)
- lr_min:训练结束时的learning rate(Default:0.0001)
- batch_size:训练的batch size(Default:128)
- num_epochs:训练总的epoch(Default:200)
- weight_decay:训练时L2 Regularization大小(Default:0.0004)
- momentum:momentum优化器中的momentum系数(Default:0.9)
- dropout_rate:dropout层的dropout_rate(Default:0.5)
- bn_decay:batch norm层的decay/momentum系数(即moving average decay)大小(Default:0.9)
## Model performances
6个模型使用相同的参数训练:
- learning rate: 0.1 -> 0.0001 with cosine annealing
- total epoch: 200
- batch size: 128
- L2 decay: 0.000400
- optimizer: momentum optimizer with m=0.9 and use nesterov
- preprocess: random horizontal flip + image standardization + cutout
以下是6个模型在CIFAR-10数据集上的准确率:
| model | round 1 | round 2 | round 3 | max | avg |
|----------|---------|---------|---------|--------|--------|
| HiNAS-0 | 0.9548 | 0.9520 | 0.9513 | 0.9548 | 0.9527 |
| HiNAS-1 | 0.9452 | 0.9462 | 0.9420 | 0.9462 | 0.9445 |
| HiNAS-2 | 0.9508 | 0.9506 | 0.9483 | 0.9508 | 0.9499 |
| HiNAS-3 | 0.9607 | 0.9623 | 0.9601 | 0.9623 | 0.9611 |
| HiNAS-4 | 0.9611 | 0.9584 | 0.9586 | 0.9611 | 0.9594 |
| HiNAS-5 | 0.9578 | 0.9588 | 0.9594 | 0.9594 | 0.9586 |
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import operator
import numpy as np
import paddle.fluid as fluid
from absl import flags
FLAGS = flags.FLAGS
flags.DEFINE_float("bn_decay", 0.9, "batch norm decay")
flags.DEFINE_float("dropout_rate", 0.5, "dropout rate")
def calc_padding(img_width, stride, dilation, filter_width):
""" calculate pixels to padding in order to keep input/output size same. """
filter_width = dilation * (filter_width - 1) + 1
if img_width % stride == 0:
pad_along_width = max(filter_width - stride, 0)
else:
pad_along_width = max(filter_width - (img_width % stride), 0)
return pad_along_width // 2, pad_along_width - pad_along_width // 2
def conv(inputs,
filters,
kernel,
strides=(1, 1),
dilation=(1, 1),
num_groups=1,
conv_param=None):
""" normal conv layer """
if isinstance(kernel, (tuple, list)):
n = operator.mul(*kernel) * inputs.shape[1]
else:
n = kernel * kernel * inputs.shape[1]
# pad input
padding = (0, 0, 0, 0) \
+ calc_padding(inputs.shape[2], strides[0], dilation[0], kernel[0]) \
+ calc_padding(inputs.shape[3], strides[1], dilation[1], kernel[1])
if sum(padding) > 0:
inputs = fluid.layers.pad(inputs, padding, 0)
param_attr = fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
0.0, scale=np.sqrt(2.0 / n)),
regularizer=fluid.regularizer.L2Decay(FLAGS.weight_decay))
bias_attr = fluid.param_attr.ParamAttr(
regularizer=fluid.regularizer.L2Decay(0.))
return fluid.layers.conv2d(
inputs,
filters,
kernel,
stride=strides,
padding=0,
dilation=dilation,
groups=num_groups,
param_attr=param_attr if conv_param is None else conv_param,
use_cudnn=False if num_groups == inputs.shape[1] == filters else True,
bias_attr=bias_attr,
act=None)
def sep(inputs, filters, kernel, strides=(1, 1), dilation=(1, 1)):
""" Separable convolution layer """
if isinstance(kernel, (tuple, list)):
n_depth = operator.mul(*kernel)
else:
n_depth = kernel * kernel
n_point = inputs.shape[1]
if isinstance(strides, (tuple, list)):
multiplier = strides[0]
else:
multiplier = strides
depthwise_param = fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
0.0, scale=np.sqrt(2.0 / n_depth)),
regularizer=fluid.regularizer.L2Decay(FLAGS.weight_decay))
pointwise_param = fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
0.0, scale=np.sqrt(2.0 / n_point)),
regularizer=fluid.regularizer.L2Decay(FLAGS.weight_decay))
depthwise_conv = conv(
inputs=inputs,
kernel=kernel,
filters=int(filters * multiplier),
strides=strides,
dilation=dilation,
num_groups=int(filters * multiplier),
conv_param=depthwise_param)
return conv(
inputs=depthwise_conv,
kernel=(1, 1),
filters=int(filters * multiplier),
strides=(1, 1),
dilation=dilation,
conv_param=pointwise_param)
def maxpool(inputs, kernel, strides=(1, 1)):
padding = (0, 0, 0, 0) \
+ calc_padding(inputs.shape[2], strides[0], 1, kernel[0]) \
+ calc_padding(inputs.shape[3], strides[1], 1, kernel[1])
if sum(padding) > 0:
inputs = fluid.layers.pad(inputs, padding, 0)
return fluid.layers.pool2d(
inputs, kernel, 'max', strides, pool_padding=0, ceil_mode=False)
def avgpool(inputs, kernel, strides=(1, 1)):
padding_pixel = (0, 0, 0, 0)
padding_pixel += calc_padding(inputs.shape[2], strides[0], 1, kernel[0])
padding_pixel += calc_padding(inputs.shape[3], strides[1], 1, kernel[1])
if padding_pixel[4] == padding_pixel[5] and padding_pixel[
6] == padding_pixel[7]:
# same padding pixel num on all sides.
return fluid.layers.pool2d(
inputs,
kernel,
'avg',
strides,
pool_padding=(padding_pixel[4], padding_pixel[6]),
ceil_mode=False)
elif padding_pixel[4] + 1 == padding_pixel[5] and padding_pixel[6] + 1 == padding_pixel[7] \
and strides == (1, 1):
# different padding size: first pad then crop.
x = fluid.layers.pool2d(
inputs,
kernel,
'avg',
strides,
pool_padding=(padding_pixel[5], padding_pixel[7]),
ceil_mode=False)
x_shape = x.shape
return fluid.layers.crop(
x,
shape=(-1, x_shape[1], x_shape[2] - 1, x_shape[3] - 1),
offsets=(0, 0, 1, 1))
else:
# not support. use padding-zero and pool2d.
print("Warning: use zero-padding in avgpool")
outputs = fluid.layers.pad(inputs, padding_pixel, 0)
return fluid.layers.pool2d(
outputs, kernel, 'avg', strides, pool_padding=0, ceil_mode=False)
def global_avgpool(inputs):
return fluid.layers.pool2d(
inputs,
1,
'avg',
1,
pool_padding=0,
global_pooling=True,
ceil_mode=True)
def fully_connected(inputs, units):
n = inputs.shape[1]
param_attr = fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
0.0, scale=np.sqrt(2.0 / n)),
regularizer=fluid.regularizer.L2Decay(FLAGS.weight_decay))
bias_attr = fluid.param_attr.ParamAttr(
regularizer=fluid.regularizer.L2Decay(0.))
return fluid.layers.fc(inputs,
units,
param_attr=param_attr,
bias_attr=bias_attr)
def bn_relu(inputs):
""" batch norm + rely layer """
output = fluid.layers.batch_norm(
inputs, momentum=FLAGS.bn_decay, epsilon=0.001, data_layout="NCHW")
return fluid.layers.relu(output)
def dropout(inputs):
""" dropout layer """
return fluid.layers.dropout(inputs, dropout_prob=FLAGS.dropout_rate)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import build.layers as layers
def conv_1x1(inputs, downsample=False):
return conv_base(inputs, (1, 1), downsample=downsample)
def conv_2x2(inputs, downsample=False):
return conv_base(inputs, (2, 2), downsample=downsample)
def conv_3x3(inputs, downsample=False):
return conv_base(inputs, (3, 3), downsample=downsample)
def dilated_2x2(inputs, downsample=False):
return conv_base(inputs, (2, 2), (2, 2), downsample)
def conv_1x2_2x1(inputs, downsample=False):
return pair_base(inputs, 2, downsample)
def conv_1x3_3x1(inputs, downsample=False):
return pair_base(inputs, 3, downsample)
def sep_2x2(inputs, downsample=False):
return sep_base(inputs, (2, 2), downsample=downsample)
def sep_3x3(inputs, downsample=False):
return sep_base(inputs, (3, 3), downsample=downsample)
def maxpool_2x2(inputs, downsample=False):
return maxpool_base(inputs, (2, 2), downsample)
def maxpool_3x3(inputs, downsample=False):
return maxpool_base(inputs, (3, 3), downsample)
def avgpool_2x2(inputs, downsample=False):
return avgpool_base(inputs, (2, 2), downsample)
def avgpool_3x3(inputs, downsample=False):
return avgpool_base(inputs, (3, 3), downsample)
def conv_base(inputs, kernel, dilation=(1, 1), downsample=False):
filters = inputs.shape[1]
if downsample:
output = layers.conv(inputs, filters * 2, kernel, (2, 2))
else:
output = layers.conv(inputs, filters, kernel, dilation=dilation)
return output
def pair_base(inputs, kernel, downsample=False):
filters = inputs.shape[1]
if downsample:
output = layers.conv(inputs, filters, (1, kernel), (1, 2))
output = layers.conv(output, filters, (kernel, 1), (2, 1))
output = layers.conv(output, filters * 2, (1, 1))
else:
output = layers.conv(inputs, filters, (1, kernel))
output = layers.conv(output, filters, (kernel, 1))
return output
def sep_base(inputs, kernel, dilation=(1, 1), downsample=False):
filters = inputs.shape[1]
if downsample:
output = layers.sep(inputs, filters * 2, kernel, (2, 2))
else:
output = layers.sep(inputs, filters, kernel, dilation=dilation)
return output
def maxpool_base(inputs, kernel, downsample=False):
if downsample:
filters = inputs.shape[1]
output = layers.maxpool(inputs, kernel, (2, 2))
output = layers.conv(output, filters * 2, (1, 1))
else:
output = layers.maxpool(inputs, kernel)
return output
def avgpool_base(inputs, kernel, downsample=False):
if downsample:
filters = inputs.shape[1]
output = layers.avgpool(inputs, kernel, (2, 2))
output = layers.conv(output, filters * 2, (1, 1))
else:
output = layers.avgpool(inputs, kernel)
return output
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from absl import flags
import build.layers as layers
import build.ops as _ops
FLAGS = flags.FLAGS
flags.DEFINE_integer("num_stages", 3, "number of stages")
flags.DEFINE_integer("num_blocks", 5, "number of blocks per stage")
flags.DEFINE_integer("num_ops", 2, "number of operations per block")
flags.DEFINE_integer("width", 64, "network width")
flags.DEFINE_string("downsample", "pool", "conv or pool")
num_classes = 10
ops = [
_ops.conv_1x1,
_ops.conv_2x2,
_ops.conv_3x3,
_ops.dilated_2x2,
_ops.conv_1x2_2x1,
_ops.conv_1x3_3x1,
_ops.sep_2x2,
_ops.sep_3x3,
_ops.maxpool_2x2,
_ops.maxpool_3x3,
_ops.avgpool_2x2,
_ops.avgpool_3x3,
]
def net(inputs, tokens):
""" build network with skip links """
x = layers.conv(inputs, FLAGS.width, (3, 3))
num_ops = FLAGS.num_blocks * FLAGS.num_ops
x = stage(x, tokens[:num_ops], pre_activation=True)
for i in range(1, FLAGS.num_stages):
x = stage(x, tokens[i * num_ops:(i + 1) * num_ops], downsample=True)
x = layers.bn_relu(x)
x = layers.global_avgpool(x)
x = layers.dropout(x)
logits = layers.fully_connected(x, num_classes)
return fluid.layers.softmax(logits)
def stage(x, tokens, pre_activation=False, downsample=False):
""" build network's stage. Stage consists of blocks """
x = block(x, tokens[:FLAGS.num_ops], pre_activation, downsample)
for i in range(1, FLAGS.num_blocks):
print("-" * 12)
x = block(x, tokens[i * FLAGS.num_ops:(i + 1) * FLAGS.num_ops])
print("=" * 12)
return x
def block(x, tokens, pre_activation=False, downsample=False):
""" build block. """
if pre_activation:
x = layers.bn_relu(x)
res = x
else:
res = x
x = layers.bn_relu(x)
x = ops[tokens[0]](x, downsample)
print("%s \t-> shape %s" % (ops[0].__name__, x.shape))
for token in tokens[1:]:
x = layers.bn_relu(x)
x = ops[token](x)
print("%s \t-> shape %s" % (ops[token].__name__, x.shape))
if downsample:
filters = res.shape[1]
if FLAGS.downsample == "conv":
res = layers.conv(res, filters * 2, (1, 1), (2, 2))
elif FLAGS.downsample == "pool":
res = layers.avgpool(res, (2, 2), (2, 2))
res = fluid.layers.pad(res, (0, 0, filters // 2, filters // 2, 0, 0,
0, 0))
else:
raise NotImplementedError
return x + res
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from absl import flags
import build.layers as layers
import build.ops as _ops
FLAGS = flags.FLAGS
flags.DEFINE_integer("num_stages", 5, "number of stages")
flags.DEFINE_integer("width", 64, "network width")
num_classes = 10
ops = [
_ops.conv_1x1, #0
_ops.conv_2x2, #1
_ops.conv_3x3, #2
_ops.dilated_2x2, #3
_ops.conv_1x2_2x1, #4
_ops.conv_1x3_3x1, #5
_ops.sep_2x2, #6
_ops.sep_3x3, #7
_ops.maxpool_2x2, #8
_ops.maxpool_3x3,
_ops.avgpool_2x2, #10
_ops.avgpool_3x3,
]
def net(inputs, tokens):
depth = len(tokens)
q, r = divmod(depth + 1, FLAGS.num_stages)
downsample_steps = [
i * q + max(0, i + r - FLAGS.num_stages + 1) - 2
for i in range(1, FLAGS.num_stages)
]
x = layers.conv(inputs, FLAGS.width, (3, 3))
x = layers.bn_relu(x)
for i, token in enumerate(tokens):
downsample = i in downsample_steps
x = ops[token](x, downsample)
print("%s \t-> shape %s" % (ops[token].__name__, x.shape))
if downsample:
print("=" * 12)
x = layers.bn_relu(x)
x = layers.global_avgpool(x)
x = layers.dropout(x)
logits = layers.fully_connected(x, num_classes)
return fluid.layers.softmax(logits)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.trainer import *
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
import reader
from absl import flags
# import preprocess
FLAGS = flags.FLAGS
flags.DEFINE_float("lr_max", 0.1, "initial learning rate")
flags.DEFINE_float("lr_min", 0.0001, "limiting learning rate")
flags.DEFINE_integer("batch_size", 128, "batch size")
flags.DEFINE_integer("num_epochs", 200, "total epochs to train")
flags.DEFINE_float("weight_decay", 0.0004, "weight decay")
flags.DEFINE_float("momentum", 0.9, "momentum")
flags.DEFINE_boolean("shuffle_image", True, "shuffle input images on training")
dataset_train_size = 50000
class Model(object):
def __init__(self, build_fn, tokens):
print("learning rate: %f -> %f, cosine annealing" %
(FLAGS.lr_max, FLAGS.lr_min))
print("epoch: %d" % FLAGS.num_epochs)
print("batch size: %d" % FLAGS.batch_size)
print("L2 decay: %f" % FLAGS.weight_decay)
self.max_step = dataset_train_size * FLAGS.num_epochs // FLAGS.batch_size
self.build_fn = build_fn
self.tokens = tokens
print("Token is %s" % ",".join(map(str, tokens)))
def cosine_annealing(self):
step = _decay_step_counter()
lr = FLAGS.lr_min + (FLAGS.lr_max - FLAGS.lr_min) / 2 \
* (1.0 + fluid.layers.ops.cos(step / self.max_step * math.pi))
return lr
def optimizer_program(self):
return fluid.optimizer.Momentum(
learning_rate=self.cosine_annealing(),
momentum=FLAGS.momentum,
use_nesterov=True,
regularization=fluid.regularizer.L2DecayRegularizer(
FLAGS.weight_decay))
def inference_network(self):
images = fluid.layers.data(
name='pixel', shape=[3, 32, 32], dtype='float32')
return self.build_fn(images, self.tokens)
def train_network(self):
predict = self.inference_network()
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
accuracy = fluid.layers.accuracy(input=predict, label=label)
# self.parameters = fluid.parameters.create(avg_cost)
return [avg_cost, accuracy]
def run(self):
train_files = reader.train10()
test_files = reader.test10()
if FLAGS.shuffle_image:
train_reader = paddle.batch(
paddle.reader.shuffle(train_files, dataset_train_size),
batch_size=FLAGS.batch_size)
else:
train_reader = paddle.batch(
train_files, batch_size=FLAGS.batch_size)
test_reader = paddle.batch(test_files, batch_size=FLAGS.batch_size)
costs = []
accs = []
def event_handler(event):
if isinstance(event, EndStepEvent):
costs.append(event.metrics[0])
accs.append(event.metrics[1])
if event.step % 20 == 0:
print("Epoch %d, Step %d, Loss %f, Acc %f" % (
event.epoch, event.step, np.mean(costs), np.mean(accs)))
del costs[:]
del accs[:]
if isinstance(event, EndEpochEvent):
if event.epoch % 3 == 0 or event.epoch == FLAGS.num_epochs - 1:
avg_cost, accuracy = trainer.test(
reader=test_reader, feed_order=['pixel', 'label'])
event_handler.best_acc = max(event_handler.best_acc,
accuracy)
print("Test with epoch %d, Loss %f, Acc %f" %
(event.epoch, avg_cost, accuracy))
print("Best acc %f" % event_handler.best_acc)
event_handler.best_acc = 0.0
place = fluid.CUDAPlace(0)
trainer = Trainer(
train_func=self.train_network,
optimizer_func=self.optimizer_program,
place=place)
trainer.train(
reader=train_reader,
num_epochs=FLAGS.num_epochs,
event_handler=event_handler,
feed_order=['pixel', 'label'])
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CIFAR-10 dataset.
This module will download dataset from
https://www.cs.toronto.edu/~kriz/cifar.html and parse train/test set into
paddle reader creators.
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes,
with 6000 images per class. There are 50000 training images and 10000 test images.
"""
from PIL import Image
from PIL import ImageOps
import numpy as np
import cPickle
import itertools
import paddle.dataset.common
import tarfile
from absl import flags
FLAGS = flags.FLAGS
flags.DEFINE_boolean("random_flip_left_right", True,
"random flip left and right")
flags.DEFINE_boolean("random_flip_up_down", False, "random flip up and down")
flags.DEFINE_boolean("cutout", True, "cutout")
flags.DEFINE_boolean("standardize_image", True, "standardize input images")
flags.DEFINE_boolean("pad_and_cut_image", True, "pad and cut input images")
__all__ = ['train10', 'test10', 'convert']
URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
paddle.dataset.common.DATA_HOME = "dataset/"
image_size = 32
image_depth = 3
half_length = 8
def preprocess(sample, is_training):
image_array = sample.reshape(3, image_size, image_size)
rgb_array = np.transpose(image_array, (1, 2, 0))
img = Image.fromarray(rgb_array, 'RGB')
if is_training:
if FLAGS.pad_and_cut_image:
# pad and ramdom crop
img = ImageOps.expand(
img, (2, 2, 2, 2), fill=0) # pad to 36 * 36 * 3
left_top = np.random.randint(5, size=2) # rand 0 - 4
img = img.crop((left_top[0], left_top[1], left_top[0] + image_size,
left_top[1] + image_size))
if FLAGS.random_flip_left_right and np.random.randint(2):
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if FLAGS.random_flip_up_down and np.random.randint(2):
img = img.transpose(Image.FLIP_TOP_BOTTOM)
img = np.array(img).astype(np.float32)
if FLAGS.standardize_image:
# per_image_standardization
img_float = img / 255.0
mean = np.mean(img_float)
std = max(np.std(img_float), 1.0 / np.sqrt(3 * image_size * image_size))
img = (img_float - mean) / std
if is_training and FLAGS.cutout:
center = np.random.randint(image_size, size=2)
offset_width = max(0, center[0] - half_length)
offset_height = max(0, center[1] - half_length)
target_width = min(center[0] + half_length, image_size)
target_height = min(center[1] + half_length, image_size)
for i in range(offset_height, target_height):
for j in range(offset_width, target_width):
img[i][j][:] = 0.0
img = np.transpose(img, (2, 0, 1))
return img.reshape(3 * image_size * image_size)
def reader_creator(filename, sub_name, is_training):
def read_batch(batch):
data = batch['data']
labels = batch.get('labels', batch.get('fine_labels', None))
assert labels is not None
for sample, label in itertools.izip(data, labels):
yield preprocess(sample, is_training), int(label)
def reader():
with tarfile.open(filename, mode='r') as f:
names = [
each_item.name for each_item in f if sub_name in each_item.name
]
names.sort()
for name in names:
print("Reading file " + name)
batch = cPickle.load(f.extractfile(name))
for item in read_batch(batch):
yield item
return reader
def train10():
"""
CIFAR-10 training set creator.
It returns a reader creator, each sample in the reader is image pixels in
[0, 1] and label in [0, 9].
:return: Training reader creator
:rtype: callable
"""
return reader_creator(
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'data_batch', True)
def test10():
"""
CIFAR-10 test set creator.
It returns a reader creator, each sample in the reader is image pixels in
[0, 1] and label in [0, 9].
:return: Test reader creator.
:rtype: callable
"""
return reader_creator(
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'test_batch', False)
def fetch():
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5)
def convert(path):
"""
Converts dataset to recordio format
"""
paddle.dataset.common.convert(path, train10(), 1000, "cifar_train10")
paddle.dataset.common.convert(path, test10(), 1000, "cifar_test10")
cnumpy.core.multiarray
_reconstruct
p0
(cnumpy
ndarray
p1
(I0
tp2
S'b'
p3
tp4
Rp5
(I1
(I21
tp6
cnumpy
dtype
p7
(S'i4'
p8
I0
I1
tp9
Rp10
(I3
S'<'
p11
NNNI-1
I-1
I0
tp12
bI00
S'\x05\x00\x00\x00\x07\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x05\x00\x00\x00\x02\x00\x00\x00\x08\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x01\x00\x00\x00\n\x00\x00\x00\t\x00\x00\x00\x03\x00\x00\x00\x08\x00\x00\x00\x0b\x00\x00\x00\x03\x00\x00\x00\t\x00\x00\x00\x02\x00\x00\x00\x06\x00\x00\x00\x01\x00\x00\x00\x06\x00\x00\x00'
p13
tp14
b.
\ No newline at end of file
cnumpy.core.multiarray
_reconstruct
p0
(cnumpy
ndarray
p1
(I0
tp2
S'b'
p3
tp4
Rp5
(I1
(I21
tp6
cnumpy
dtype
p7
(S'i4'
p8
I0
I1
tp9
Rp10
(I3
S'<'
p11
NNNI-1
I-1
I0
tp12
bI00
S'\x07\x00\x00\x00\x07\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x02\x00\x00\x00\x02\x00\x00\x00\x08\x00\x00\x00\x08\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x02\x00\x00\x00\n\x00\x00\x00\x08\x00\x00\x00\x02\x00\x00\x00\t\x00\x00\x00\x0b\x00\x00\x00\t\x00\x00\x00\x06\x00\x00\x00\x04\x00\x00\x00\x04\x00\x00\x00\n\x00\x00\x00'
p13
tp14
b.
\ No newline at end of file
cnumpy.core.multiarray
_reconstruct
p0
(cnumpy
ndarray
p1
(I0
tp2
S'b'
p3
tp4
Rp5
(I1
(I21
tp6
cnumpy
dtype
p7
(S'i4'
p8
I0
I1
tp9
Rp10
(I3
S'<'
p11
NNNI-1
I-1
I0
tp12
bI00
S'\x07\x00\x00\x00\x05\x00\x00\x00\x08\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\n\x00\x00\x00\t\x00\x00\x00\x02\x00\x00\x00\x02\x00\x00\x00\x02\x00\x00\x00\x08\x00\x00\x00\x08\x00\x00\x00\x08\x00\x00\x00\x02\x00\x00\x00\t\x00\x00\x00\x04\x00\x00\x00\t\x00\x00\x00\x0b\x00\x00\x00\x07\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00'
p13
tp14
b.
\ No newline at end of file
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
from absl import app
from absl import flags
import nn_paddle as nn
from build import vgg_base
FLAGS = flags.FLAGS
flags.DEFINE_string("tokdir", "tokens/", "token directory")
flags.DEFINE_integer("model", 0, "model")
mid = [17925, 18089, 15383]
def main(_):
f = os.path.join(FLAGS.tokdir, str(mid[FLAGS.model]) + ".pkl")
tokens = pickle.load(open(f, "rb"))
model = nn.Model(vgg_base.net, tokens)
model.run()
if __name__ == "__main__":
app.run(main)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
from absl import app
from absl import flags
import nn_paddle as nn
from build import resnet_base
FLAGS = flags.FLAGS
flags.DEFINE_string("tokdir", "tokens/", "token directory")
flags.DEFINE_integer("model", 0, "model")
mid = [17754, 15113, 15613]
def main(_):
f = os.path.join(FLAGS.tokdir, str(mid[FLAGS.model]) + ".pkl")
tokens = pickle.load(open(f, "rb"))
model = nn.Model(resnet_base.net, tokens)
model.run()
if __name__ == "__main__":
app.run(main)
# LRC Local Rademachar Complexity Regularization
Regularization of Deep Neural Networks(DNNs) for the sake of improving their generalization capability is important and chllenging. This directory contains image classification model based on a novel regularizer rooted in Local Rademacher Complexity (LRC). We appreciate the contribution by [DARTS](https://arxiv.org/abs/1806.09055) for our research. The regularization by LRC and DARTS are combined in this model on CIFAR-10 dataset. Code accompanying the paper
> [An Empirical Study on Regularization of Deep Neural Networks by Local Rademacher Complexity](https://arxiv.org/abs/1902.00873)\
> Yingzhen Yang, Xingjian Li, Jun Huan.\
> _arXiv:1902.00873_.
---
# Table of Contents
- [Installation](#installation)
- [Data preparation](#data-preparation)
- [Training](#training)
## Installation
Running sample code in this directory requires PaddelPaddle Fluid v.1.2.0 and later. If the PaddlePaddle on your device is lower than this version, please follow the instructions in [installation document](http://www.paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html#paddlepaddle) and make an update.
## Data preparation
When you want to use the cifar-10 dataset for the first time, you can download the dataset as:
sh ./dataset/download.sh
Please make sure your environment has an internet connection.
The dataset will be downloaded to `dataset/cifar/cifar-10-batches-py` in the same directory as the `train.py`. If automatic download fails, you can download cifar-10-python.tar.gz from https://www.cs.toronto.edu/~kriz/cifar.html and decompress it to the location mentioned above.
## Training
After data preparation, one can start the training step by:
python -u train_mixup.py \
--batch_size=80 \
--auxiliary \
--weight_decay=0.0003 \
--learning_rate=0.025 \
--lrc_loss_lambda=0.7 \
--cutout
- Set ```export CUDA_VISIBLE_DEVICES=0``` to specifiy one GPU to train.
- For more help on arguments:
python train_mixup.py --help
**data reader introduction:**
* Data reader is defined in `reader.py`.
* Reshape the images to 32 * 32.
* In training stage, images are padding to 40 * 40 and cropped randomly to the original size.
* In training stage, images are horizontally random flipped.
* Images are standardized to (0, 1).
* In training stage, cutout images randomly.
* Shuffle the order of the input images during training.
**model configuration:**
* Use auxiliary loss and auxiliary\_weight=0.4.
* Use dropout and drop\_path\_prob=0.2.
* Set lrc\_loss\_lambda=0.7.
**training strategy:**
* Use momentum optimizer with momentum=0.9.
* Weight decay is 0.0003.
* Use cosine decay with init\_lr=0.025.
* Total epoch is 600.
* Use Xaiver initalizer to weight in conv2d, Constant initalizer to weight in batch norm and Normal initalizer to weight in fc.
* Initalize bias in batch norm and fc to zero constant and do not add bias to conv2d.
## Reference
- DARTS: Differentiable Architecture Search [`paper`](https://arxiv.org/abs/1806.09055)
- Differentiable architecture search in PyTorch [`code`](https://github.com/quark0/darts)
# LRC 局部Rademachar复杂度正则化
为了在深度神经网络中提升泛化能力,正则化的选择十分重要也具有挑战性。本目录包括了一种基于局部rademacher复杂度的新型正则(LRC)的图像分类模型。十分感谢[DARTS](https://arxiv.org/abs/1806.09055)模型对本研究提供的帮助。该模型将LRC正则和DARTS网络相结合,在CIFAR-10数据集中得到了很出色的效果。代码和文章一同发布
> [An Empirical Study on Regularization of Deep Neural Networks by Local Rademacher Complexity](https://arxiv.org/abs/1902.00873)\
> Yingzhen Yang, Xingjian Li, Jun Huan.\
> _arXiv:1902.00873_.
---
# 内容
- [安装](#安装)
- [数据准备](#数据准备)
- [模型训练](#模型训练)
## 安装
在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.2.0或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](http://www.paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html#paddlepaddle)中的说明来更新PaddlePaddle。
## 数据准备
第一次使用CIFAR-10数据集时,您可以通过如果命令下载:
sh ./dataset/download.sh
请确保您的环境有互联网连接。数据会下载到`train.py`同目录下的`dataset/cifar/cifar-10-batches-py`。如果下载失败,您可以自行从https://www.cs.toronto.edu/~kriz/cifar.html上下载cifar-10-python.tar.gz并解压到上述位置。
## 模型训练
数据准备好后,可以通过如下命令开始训练:
python -u train_mixup.py \
--batch_size=80 \
--auxiliary \
--weight_decay=0.0003 \
--learning_rate=0.025 \
--lrc_loss_lambda=0.7 \
--cutout
- 通过设置 ```export CUDA_VISIBLE_DEVICES=0```指定单张GPU训练。
- 可选参数见:
python train_mixup.py --help
**数据读取器说明:**
* 数据读取器定义在`reader.py`
* 输入图像尺寸统一变换为32 * 32
* 训练时将图像填充为40 * 40然后随机剪裁为原输入图像大小
* 训练时图像随机水平翻转
* 对图像每个像素做归一化处理
* 训练时对图像做随机遮挡
* 训练时对输入图像做随机洗牌
**模型配置:**
* 使用辅助损失,辅助损失权重为0.4
* 使用dropout,随机丢弃率为0.2
* 设置lrc\_loss\_lambda为0.7
**训练策略:**
* 采用momentum优化算法训练,momentum=0.9
* 权重衰减系数为0.0001
* 采用正弦学习率衰减,初始学习率为0.025
* 总共训练600轮
* 对卷积权重采用Xaiver初始化,对batch norm权重采用固定初始化,对全连接层权重采用高斯初始化
* 对batch norm和全连接层偏差采用固定初始化,不对卷积设置偏差
## 引用
- DARTS: Differentiable Architecture Search [`论文`](https://arxiv.org/abs/1806.09055)
- Differentiable Architecture Search in PyTorch [`代码`](https://github.com/quark0/darts)
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd "$DIR"
mkdir cifar
cd cifar
# Download the data.
echo "Downloading..."
wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
# Extract the data.
echo "Extracting..."
tar zvxf cifar-10-python.tar.gz
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Based on:
# --------------------------------------------------------
# DARTS
# Copyright (c) 2018, Hanxiao Liu.
# Licensed under the Apache License, Version 2.0;
# --------------------------------------------------------
from collections import namedtuple
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
PRIMITIVES = [
'none', 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect', 'sep_conv_3x3',
'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5'
]
NASNet = Genotype(
normal=[
('sep_conv_5x5', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 0),
('sep_conv_3x3', 0),
('avg_pool_3x3', 1),
('skip_connect', 0),
('avg_pool_3x3', 0),
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('skip_connect', 1),
],
normal_concat=[2, 3, 4, 5, 6],
reduce=[
('sep_conv_5x5', 1),
('sep_conv_7x7', 0),
('max_pool_3x3', 1),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('sep_conv_5x5', 0),
('skip_connect', 3),
('avg_pool_3x3', 2),
('sep_conv_3x3', 2),
('max_pool_3x3', 1),
],
reduce_concat=[4, 5, 6], )
AmoebaNet = Genotype(
normal=[
('avg_pool_3x3', 0),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 2),
('sep_conv_3x3', 0),
('avg_pool_3x3', 3),
('sep_conv_3x3', 1),
('skip_connect', 1),
('skip_connect', 0),
('avg_pool_3x3', 1),
],
normal_concat=[4, 5, 6],
reduce=[
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_7x7', 2),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('conv_7x1_1x7', 0),
('sep_conv_3x3', 5),
],
reduce_concat=[3, 4, 6])
DARTS_V1 = Genotype(
normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 0),
('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1),
('sep_conv_3x3', 0), ('skip_connect', 2)],
normal_concat=[2, 3, 4, 5],
reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2),
('max_pool_3x3', 0), ('max_pool_3x3', 0), ('skip_connect', 2),
('skip_connect', 2), ('avg_pool_3x3', 0)],
reduce_concat=[2, 3, 4, 5])
DARTS_V2 = Genotype(
normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0),
('skip_connect', 0), ('dil_conv_3x3', 2)],
normal_concat=[2, 3, 4, 5],
reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2),
('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2),
('skip_connect', 2), ('max_pool_3x3', 1)],
reduce_concat=[2, 3, 4, 5])
MY_DARTS = Genotype(
normal=[('sep_conv_3x3', 0), ('skip_connect', 1), ('skip_connect', 0),
('dil_conv_5x5', 1), ('skip_connect', 0), ('sep_conv_3x3', 1),
('skip_connect', 0), ('sep_conv_3x3', 1)],
normal_concat=range(2, 6),
reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('max_pool_3x3', 0),
('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2),
('skip_connect', 2), ('skip_connect', 3)],
reduce_concat=range(2, 6))
DARTS = MY_DARTS
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Based on:
# --------------------------------------------------------
# DARTS
# Copyright (c) 2018, Hanxiao Liu.
# Licensed under the Apache License, Version 2.0;
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers.ops as ops
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
import math
from paddle.fluid.initializer import init_on_cpu
def cosine_decay(learning_rate, num_epoch, steps_one_epoch):
"""Applies cosine decay to the learning rate.
lr = 0.5 * (math.cos(epoch * (math.pi / 120)) + 1)
"""
global_step = _decay_step_counter()
with init_on_cpu():
decayed_lr = learning_rate * \
(ops.cos((global_step / steps_one_epoch) \
* math.pi / num_epoch) + 1)/2
return decayed_lr
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
#
# Based on:
# --------------------------------------------------------
# DARTS
# Copyright (c) 2018, Hanxiao Liu.
# Licensed under the Apache License, Version 2.0;
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
import time
import functools
import paddle
import paddle.fluid as fluid
from operations import *
class Cell():
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction,
reduction_prev):
print(C_prev_prev, C_prev, C)
if reduction_prev:
self.preprocess0 = functools.partial(FactorizedReduce, C_out=C)
else:
self.preprocess0 = functools.partial(
ReLUConvBN, C_out=C, kernel_size=1, stride=1, padding=0)
self.preprocess1 = functools.partial(
ReLUConvBN, C_out=C, kernel_size=1, stride=1, padding=0)
if reduction:
op_names, indices = zip(*genotype.reduce)
concat = genotype.reduce_concat
else:
op_names, indices = zip(*genotype.normal)
concat = genotype.normal_concat
print(op_names, indices, concat, reduction)
self._compile(C, op_names, indices, concat, reduction)
def _compile(self, C, op_names, indices, concat, reduction):
assert len(op_names) == len(indices)
self._steps = len(op_names) // 2
self._concat = concat
self.multiplier = len(concat)
self._ops = []
for name, index in zip(op_names, indices):
stride = 2 if reduction and index < 2 else 1
op = functools.partial(OPS[name], C=C, stride=stride, affine=True)
self._ops += [op]
self._indices = indices
def forward(self, s0, s1, drop_prob, is_train, name):
self.training = is_train
preprocess0_name = name + 'preprocess0.'
preprocess1_name = name + 'preprocess1.'
s0 = self.preprocess0(s0, name=preprocess0_name)
s1 = self.preprocess1(s1, name=preprocess1_name)
out = [s0, s1]
for i in range(self._steps):
h1 = out[self._indices[2 * i]]
h2 = out[self._indices[2 * i + 1]]
op1 = self._ops[2 * i]
op2 = self._ops[2 * i + 1]
h3 = op1(h1, name=name + '_ops.' + str(2 * i) + '.')
h4 = op2(h2, name=name + '_ops.' + str(2 * i + 1) + '.')
if self.training and drop_prob > 0.:
if h3 != h1:
h3 = fluid.layers.dropout(
h3,
drop_prob,
dropout_implementation='upscale_in_train')
if h4 != h2:
h4 = fluid.layers.dropout(
h4,
drop_prob,
dropout_implementation='upscale_in_train')
s = h3 + h4
out += [s]
return fluid.layers.concat([out[i] for i in self._concat], axis=1)
def AuxiliaryHeadCIFAR(input, num_classes, aux_name='auxiliary_head'):
relu_a = fluid.layers.relu(input)
pool_a = fluid.layers.pool2d(relu_a, 5, 'avg', 3)
conv2d_a = fluid.layers.conv2d(
pool_a,
128,
1,
name=aux_name + '.features.2',
param_attr=ParamAttr(
initializer=Xavier(
uniform=False, fan_in=0),
name=aux_name + '.features.2.weight'),
bias_attr=False)
bn_a_name = aux_name + '.features.3'
bn_a = fluid.layers.batch_norm(
conv2d_a,
act='relu',
name=bn_a_name,
param_attr=ParamAttr(
initializer=Constant(1.), name=bn_a_name + '.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.), name=bn_a_name + '.bias'),
moving_mean_name=bn_a_name + '.running_mean',
moving_variance_name=bn_a_name + '.running_var')
conv2d_b = fluid.layers.conv2d(
bn_a,
768,
2,
name=aux_name + '.features.5',
param_attr=ParamAttr(
initializer=Xavier(
uniform=False, fan_in=0),
name=aux_name + '.features.5.weight'),
bias_attr=False)
bn_b_name = aux_name + '.features.6'
bn_b = fluid.layers.batch_norm(
conv2d_b,
act='relu',
name=bn_b_name,
param_attr=ParamAttr(
initializer=Constant(1.), name=bn_b_name + '.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.), name=bn_b_name + '.bias'),
moving_mean_name=bn_b_name + '.running_mean',
moving_variance_name=bn_b_name + '.running_var')
fc_name = aux_name + '.classifier'
fc = fluid.layers.fc(bn_b,
num_classes,
name=fc_name,
param_attr=ParamAttr(
initializer=Normal(scale=1e-3),
name=fc_name + '.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.), name=fc_name + '.bias'))
return fc
def StemConv(input, C_out, kernel_size, padding):
conv_a = fluid.layers.conv2d(
input,
C_out,
kernel_size,
padding=padding,
param_attr=ParamAttr(
initializer=Xavier(
uniform=False, fan_in=0), name='stem.0.weight'),
bias_attr=False)
bn_a = fluid.layers.batch_norm(
conv_a,
param_attr=ParamAttr(
initializer=Constant(1.), name='stem.1.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.), name='stem.1.bias'),
moving_mean_name='stem.1.running_mean',
moving_variance_name='stem.1.running_var')
return bn_a
class NetworkCIFAR(object):
def __init__(self, C, class_num, layers, auxiliary, genotype):
self.class_num = class_num
self._layers = layers
self._auxiliary = auxiliary
stem_multiplier = 3
self.drop_path_prob = 0
C_curr = stem_multiplier * C
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
self.cells = []
reduction_prev = False
for i in range(layers):
if i in [layers // 3, 2 * layers // 3]:
C_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction,
reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
if i == 2 * layers // 3:
C_to_auxiliary = C_prev
def forward(self, init_channel, is_train):
self.training = is_train
self.logits_aux = None
num_channel = init_channel * 3
s0 = StemConv(self.image, num_channel, kernel_size=3, padding=1)
s1 = s0
for i, cell in enumerate(self.cells):
name = 'cells.' + str(i) + '.'
s0, s1 = s1, cell.forward(s0, s1, self.drop_path_prob, is_train,
name)
if i == int(2 * self._layers // 3):
if self._auxiliary and self.training:
self.logits_aux = AuxiliaryHeadCIFAR(s1, self.class_num)
out = fluid.layers.adaptive_pool2d(s1, (1, 1), "avg")
self.logits = fluid.layers.fc(out,
size=self.class_num,
param_attr=ParamAttr(
initializer=Normal(scale=1e-3),
name='classifier.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.),
name='classifier.bias'))
return self.logits, self.logits_aux
def build_input(self, image_shape, batch_size, is_train):
if is_train:
py_reader = fluid.layers.py_reader(
capacity=64,
shapes=[[-1] + image_shape, [-1, 1], [-1, 1], [-1, 1], [-1, 1],
[-1, 1], [-1, batch_size, self.class_num - 1]],
lod_levels=[0, 0, 0, 0, 0, 0, 0],
dtypes=[
"float32", "int64", "int64", "float32", "int32", "int32",
"float32"
],
use_double_buffer=True,
name='train_reader')
else:
py_reader = fluid.layers.py_reader(
capacity=64,
shapes=[[-1] + image_shape, [-1, 1]],
lod_levels=[0, 0],
dtypes=["float32", "int64"],
use_double_buffer=True,
name='test_reader')
return py_reader
def train_model(self, py_reader, init_channels, aux, aux_w, batch_size,
loss_lambda):
self.image, self.ya, self.yb, self.lam, self.label_reshape,\
self.non_label_reshape, self.rad_var = fluid.layers.read_file(py_reader)
self.logits, self.logits_aux = self.forward(init_channels, True)
self.mixup_loss = self.mixup_loss(aux, aux_w)
self.lrc_loss = self.lrc_loss(batch_size)
return self.mixup_loss + loss_lambda * self.lrc_loss
def test_model(self, py_reader, init_channels):
self.image, self.ya = fluid.layers.read_file(py_reader)
self.logits, _ = self.forward(init_channels, False)
prob = fluid.layers.softmax(self.logits, use_cudnn=False)
loss = fluid.layers.cross_entropy(prob, self.ya)
acc_1 = fluid.layers.accuracy(self.logits, self.ya, k=1)
acc_5 = fluid.layers.accuracy(self.logits, self.ya, k=5)
return loss, acc_1, acc_5
def mixup_loss(self, auxiliary, auxiliary_weight):
prob = fluid.layers.softmax(self.logits, use_cudnn=False)
loss_a = fluid.layers.cross_entropy(prob, self.ya)
loss_b = fluid.layers.cross_entropy(prob, self.yb)
loss_a_mean = fluid.layers.reduce_mean(loss_a)
loss_b_mean = fluid.layers.reduce_mean(loss_b)
loss = self.lam * loss_a_mean + (1 - self.lam) * loss_b_mean
if auxiliary:
prob_aux = fluid.layers.softmax(self.logits_aux, use_cudnn=False)
loss_a_aux = fluid.layers.cross_entropy(prob_aux, self.ya)
loss_b_aux = fluid.layers.cross_entropy(prob_aux, self.yb)
loss_a_aux_mean = fluid.layers.reduce_mean(loss_a_aux)
loss_b_aux_mean = fluid.layers.reduce_mean(loss_b_aux)
loss_aux = self.lam * loss_a_aux_mean + (1 - self.lam
) * loss_b_aux_mean
return loss + auxiliary_weight * loss_aux
def lrc_loss(self, batch_size):
y_diff_reshape = fluid.layers.reshape(self.logits, shape=(-1, 1))
label_reshape = fluid.layers.squeeze(self.label_reshape, axes=[1])
non_label_reshape = fluid.layers.squeeze(
self.non_label_reshape, axes=[1])
label_reshape.stop_gradient = True
non_label_reshape.stop_graident = True
y_diff_label_reshape = fluid.layers.gather(y_diff_reshape,
label_reshape)
y_diff_non_label_reshape = fluid.layers.gather(y_diff_reshape,
non_label_reshape)
y_diff_label = fluid.layers.reshape(
y_diff_label_reshape, shape=(-1, batch_size, 1))
y_diff_non_label = fluid.layers.reshape(
y_diff_non_label_reshape,
shape=(-1, batch_size, self.class_num - 1))
y_diff_ = y_diff_non_label - y_diff_label
y_diff_ = fluid.layers.transpose(y_diff_, perm=[1, 2, 0])
rad_var_trans = fluid.layers.transpose(self.rad_var, perm=[1, 2, 0])
rad_y_diff_trans = rad_var_trans * y_diff_
lrc_loss_sum = fluid.layers.reduce_sum(rad_y_diff_trans, dim=[0, 1])
lrc_loss_ = fluid.layers.abs(lrc_loss_sum) / (batch_size *
(self.class_num - 1))
lrc_loss_mean = fluid.layers.reduce_mean(lrc_loss_)
return lrc_loss_mean
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
#
# Based on:
# --------------------------------------------------------
# DARTS
# Copyright (c) 2018, Hanxiao Liu.
# Licensed under the Apache License, Version 2.0;
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Xavier
from paddle.fluid.initializer import Normal
from paddle.fluid.initializer import Constant
OPS = {
'none' : lambda input, C, stride, name, affine: Zero(input, stride, name),
'avg_pool_3x3' : lambda input, C, stride, name, affine: fluid.layers.pool2d(input, 3, 'avg', pool_stride=stride, pool_padding=1, name=name),
'max_pool_3x3' : lambda input, C, stride, name, affine: fluid.layers.pool2d(input, 3, 'max', pool_stride=stride, pool_padding=1, name=name),
'skip_connect' : lambda input,C, stride, name, affine: Identity(input, name) if stride == 1 else FactorizedReduce(input, C, name=name, affine=affine),
'sep_conv_3x3' : lambda input,C, stride, name, affine: SepConv(input, C, C, 3, stride, 1, name=name, affine=affine),
'sep_conv_5x5' : lambda input,C, stride, name, affine: SepConv(input, C, C, 5, stride, 2, name=name, affine=affine),
'sep_conv_7x7' : lambda input,C, stride, name, affine: SepConv(input, C, C, 7, stride, 3, name=name, affine=affine),
'dil_conv_3x3' : lambda input,C, stride, name, affine: DilConv(input, C, C, 3, stride, 2, 2, name=name, affine=affine),
'dil_conv_5x5' : lambda input,C, stride, name, affine: DilConv(input, C, C, 5, stride, 4, 2, name=name, affine=affine),
'conv_7x1_1x7' : lambda input,C, stride, name, affine: SevenConv(input, C, name=name, affine=affine)
}
def ReLUConvBN(input, C_out, kernel_size, stride, padding, name='',
affine=True):
relu_a = fluid.layers.relu(input)
conv2d_a = fluid.layers.conv2d(
relu_a,
C_out,
kernel_size,
stride,
padding,
param_attr=ParamAttr(
initializer=Xavier(
uniform=False, fan_in=0),
name=name + 'op.1.weight'),
bias_attr=False)
if affine:
reluconvbn_out = fluid.layers.batch_norm(
conv2d_a,
param_attr=ParamAttr(
initializer=Constant(1.), name=name + 'op.2.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.), name=name + 'op.2.bias'),
moving_mean_name=name + 'op.2.running_mean',
moving_variance_name=name + 'op.2.running_var')
else:
reluconvbn_out = fluid.layers.batch_norm(
conv2d_a,
param_attr=ParamAttr(
initializer=Constant(1.),
learning_rate=0.,
name=name + 'op.2.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.),
learning_rate=0.,
name=name + 'op.2.bias'),
moving_mean_name=name + 'op.2.running_mean',
moving_variance_name=name + 'op.2.running_var')
return reluconvbn_out
def DilConv(input,
C_in,
C_out,
kernel_size,
stride,
padding,
dilation,
name='',
affine=True):
relu_a = fluid.layers.relu(input)
conv2d_a = fluid.layers.conv2d(
relu_a,
C_in,
kernel_size,
stride,
padding,
dilation,
groups=C_in,
param_attr=ParamAttr(
initializer=Xavier(
uniform=False, fan_in=0),
name=name + 'op.1.weight'),
bias_attr=False,
use_cudnn=False)
conv2d_b = fluid.layers.conv2d(
conv2d_a,
C_out,
1,
param_attr=ParamAttr(
initializer=Xavier(
uniform=False, fan_in=0),
name=name + 'op.2.weight'),
bias_attr=False)
if affine:
dilconv_out = fluid.layers.batch_norm(
conv2d_b,
param_attr=ParamAttr(
initializer=Constant(1.), name=name + 'op.3.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.), name=name + 'op.3.bias'),
moving_mean_name=name + 'op.3.running_mean',
moving_variance_name=name + 'op.3.running_var')
else:
dilconv_out = fluid.layers.batch_norm(
conv2d_b,
param_attr=ParamAttr(
initializer=Constant(1.),
learning_rate=0.,
name=name + 'op.3.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.),
learning_rate=0.,
name=name + 'op.3.bias'),
moving_mean_name=name + 'op.3.running_mean',
moving_variance_name=name + 'op.3.running_var')
return dilconv_out
def SepConv(input,
C_in,
C_out,
kernel_size,
stride,
padding,
name='',
affine=True):
relu_a = fluid.layers.relu(input)
conv2d_a = fluid.layers.conv2d(
relu_a,
C_in,
kernel_size,
stride,
padding,
groups=C_in,
param_attr=ParamAttr(
initializer=Xavier(
uniform=False, fan_in=0),
name=name + 'op.1.weight'),
bias_attr=False,
use_cudnn=False)
conv2d_b = fluid.layers.conv2d(
conv2d_a,
C_in,
1,
param_attr=ParamAttr(
initializer=Xavier(
uniform=False, fan_in=0),
name=name + 'op.2.weight'),
bias_attr=False)
if affine:
bn_a = fluid.layers.batch_norm(
conv2d_b,
param_attr=ParamAttr(
initializer=Constant(1.), name=name + 'op.3.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.), name=name + 'op.3.bias'),
moving_mean_name=name + 'op.3.running_mean',
moving_variance_name=name + 'op.3.running_var')
else:
bn_a = fluid.layers.batch_norm(
conv2d_b,
param_attr=ParamAttr(
initializer=Constant(1.),
learning_rate=0.,
name=name + 'op.3.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.),
learning_rate=0.,
name=name + 'op.3.bias'),
moving_mean_name=name + 'op.3.running_mean',
moving_variance_name=name + 'op.3.running_var')
relu_b = fluid.layers.relu(bn_a)
conv2d_d = fluid.layers.conv2d(
relu_b,
C_in,
kernel_size,
1,
padding,
groups=C_in,
param_attr=ParamAttr(
initializer=Xavier(
uniform=False, fan_in=0),
name=name + 'op.5.weight'),
bias_attr=False,
use_cudnn=False)
conv2d_e = fluid.layers.conv2d(
conv2d_d,
C_out,
1,
param_attr=ParamAttr(
initializer=Xavier(
uniform=False, fan_in=0),
name=name + 'op.6.weight'),
bias_attr=False)
if affine:
sepconv_out = fluid.layers.batch_norm(
conv2d_e,
param_attr=ParamAttr(
initializer=Constant(1.), name=name + 'op.7.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.), name=name + 'op.7.bias'),
moving_mean_name=name + 'op.7.running_mean',
moving_variance_name=name + 'op.7.running_var')
else:
sepconv_out = fluid.layers.batch_norm(
conv2d_e,
param_attr=ParamAttr(
initializer=Constant(1.),
learning_rate=0.,
name=name + 'op.7.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.),
learning_rate=0.,
name=name + 'op.7.bias'),
moving_mean_name=name + 'op.7.running_mean',
moving_variance_name=name + 'op.7.running_var')
return sepconv_out
def SevenConv(input, C_out, stride, name='', affine=True):
relu_a = fluid.layers.relu(input)
conv2d_a = fluid.layers.conv2d(
relu_a,
C_out, (1, 7), (1, stride), (0, 3),
param_attr=ParamAttr(
initializer=Xavier(
uniform=False, fan_in=0),
name=name + 'op.1.weight'),
bias_attr=False)
conv2d_b = fluid.layers.conv2d(
conv2d_a,
C_out, (7, 1), (stride, 1), (3, 0),
param_attr=ParamAttr(
initializer=Xavier(
uniform=False, fan_in=0),
name=name + 'op.2.weight'),
bias_attr=False)
if affine:
out = fluid.layers.batch_norm(
conv2d_b,
param_attr=ParamAttr(
initializer=Constant(1.), name=name + 'op.3.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.), name=name + 'op.3.bias'),
moving_mean_name=name + 'op.3.running_mean',
moving_variance_name=name + 'op.3.running_var')
else:
out = fluid.layers.batch_norm(
conv2d_b,
param_attr=ParamAttr(
initializer=Constant(1.),
learning_rate=0.,
name=name + 'op.3.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.),
learning_rate=0.,
name=name + 'op.3.bias'),
moving_mean_name=name + 'op.3.running_mean',
moving_variance_name=name + 'op.3.running_var')
def Identity(input, name=''):
return input
def Zero(input, stride, name=''):
ones = np.ones(input.shape[-2:])
ones[::stride, ::stride] = 0
ones = fluid.layers.assign(ones)
return input * ones
def FactorizedReduce(input, C_out, name='', affine=True):
relu_a = fluid.layers.relu(input)
conv2d_a = fluid.layers.conv2d(
relu_a,
C_out // 2,
1,
2,
param_attr=ParamAttr(
initializer=Xavier(
uniform=False, fan_in=0),
name=name + 'conv_1.weight'),
bias_attr=False)
h_end = relu_a.shape[2]
w_end = relu_a.shape[3]
slice_a = fluid.layers.slice(relu_a, [2, 3], [1, 1], [h_end, w_end])
conv2d_b = fluid.layers.conv2d(
slice_a,
C_out // 2,
1,
2,
param_attr=ParamAttr(
initializer=Xavier(
uniform=False, fan_in=0),
name=name + 'conv_2.weight'),
bias_attr=False)
out = fluid.layers.concat([conv2d_a, conv2d_b], axis=1)
if affine:
out = fluid.layers.batch_norm(
out,
param_attr=ParamAttr(
initializer=Constant(1.), name=name + 'bn.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.), name=name + 'bn.bias'),
moving_mean_name=name + 'bn.running_mean',
moving_variance_name=name + 'bn.running_var')
else:
out = fluid.layers.batch_norm(
out,
param_attr=ParamAttr(
initializer=Constant(1.),
learning_rate=0.,
name=name + 'bn.weight'),
bias_attr=ParamAttr(
initializer=Constant(0.),
learning_rate=0.,
name=name + 'bn.bias'),
moving_mean_name=name + 'bn.running_mean',
moving_variance_name=name + 'bn.running_var')
return out
# Copyright (c) 2019 PaddlePaddle Authors. All Rig hts Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Based on:
# --------------------------------------------------------
# DARTS
# Copyright (c) 2018, Hanxiao Liu.
# Licensed under the Apache License, Version 2.0;
# --------------------------------------------------------
"""
CIFAR-10 dataset.
This module will download dataset from
https://www.cs.toronto.edu/~kriz/cifar.html and parse train/test set into
paddle reader creators.
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes,
with 6000 images per class. There are 50000 training images and 10000 test images.
"""
from PIL import Image
from PIL import ImageOps
import numpy as np
import cPickle
import random
import utils
import paddle.fluid as fluid
import time
import os
import functools
import paddle.reader
__all__ = ['train10', 'test10']
image_size = 32
image_depth = 3
half_length = 8
CIFAR_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
def generate_reshape_label(label, batch_size, CIFAR_CLASSES=10):
reshape_label = np.zeros((batch_size, 1), dtype='int32')
reshape_non_label = np.zeros(
(batch_size * (CIFAR_CLASSES - 1), 1), dtype='int32')
num = 0
for i in range(batch_size):
label_i = label[i]
reshape_label[i] = label_i + i * CIFAR_CLASSES
for j in range(CIFAR_CLASSES):
if label_i != j:
reshape_non_label[num] = \
j + i * CIFAR_CLASSES
num += 1
return reshape_label, reshape_non_label
def generate_bernoulli_number(batch_size, CIFAR_CLASSES=10):
rcc_iters = 50
rad_var = np.zeros((rcc_iters, batch_size, CIFAR_CLASSES - 1))
for i in range(rcc_iters):
bernoulli_num = np.random.binomial(size=batch_size, n=1, p=0.5)
bernoulli_map = np.array([])
ones = np.ones((CIFAR_CLASSES - 1, 1))
for batch_id in range(batch_size):
num = bernoulli_num[batch_id]
var_id = 2 * ones * num - 1
bernoulli_map = np.append(bernoulli_map, var_id)
rad_var[i] = bernoulli_map.reshape((batch_size, CIFAR_CLASSES - 1))
return rad_var.astype('float32')
def preprocess(sample, is_training, args):
image_array = sample.reshape(3, image_size, image_size)
rgb_array = np.transpose(image_array, (1, 2, 0))
img = Image.fromarray(rgb_array, 'RGB')
if is_training:
# pad and ramdom crop
img = ImageOps.expand(img, (4, 4, 4, 4), fill=0) # pad to 40 * 40 * 3
left_top = np.random.randint(9, size=2) # rand 0 - 8
img = img.crop((left_top[0], left_top[1], left_top[0] + image_size,
left_top[1] + image_size))
if np.random.randint(2):
img = img.transpose(Image.FLIP_LEFT_RIGHT)
img = np.array(img).astype(np.float32)
# per_image_standardization
img_float = img / 255.0
img = (img_float - CIFAR_MEAN) / CIFAR_STD
if is_training and args.cutout:
center = np.random.randint(image_size, size=2)
offset_width = max(0, center[0] - half_length)
offset_height = max(0, center[1] - half_length)
target_width = min(center[0] + half_length, image_size)
target_height = min(center[1] + half_length, image_size)
for i in range(offset_height, target_height):
for j in range(offset_width, target_width):
img[i][j][:] = 0.0
img = np.transpose(img, (2, 0, 1))
return img
def reader_creator_filepath(filename, sub_name, is_training, args):
files = os.listdir(filename)
names = [each_item for each_item in files if sub_name in each_item]
names.sort()
datasets = []
for name in names:
print("Reading file " + name)
batch = cPickle.load(open(filename + name, 'rb'))
data = batch['data']
labels = batch.get('labels', batch.get('fine_labels', None))
assert labels is not None
dataset = zip(data, labels)
datasets.extend(dataset)
random.shuffle(datasets)
def read_batch(datasets, args):
for sample, label in datasets:
im = preprocess(sample, is_training, args)
yield im, [int(label)]
def reader():
batch_data = []
batch_label = []
for data, label in read_batch(datasets, args):
batch_data.append(data)
batch_label.append(label)
if len(batch_data) == args.batch_size:
batch_data = np.array(batch_data, dtype='float32')
batch_label = np.array(batch_label, dtype='int64')
if is_training:
flatten_label, flatten_non_label = \
generate_reshape_label(batch_label, args.batch_size)
rad_var = generate_bernoulli_number(args.batch_size)
mixed_x, y_a, y_b, lam = utils.mixup_data(
batch_data, batch_label, args.batch_size,
args.mix_alpha)
batch_out = [[mixed_x, y_a, y_b, lam, flatten_label, \
flatten_non_label, rad_var]]
yield batch_out
else:
batch_out = [[batch_data, batch_label]]
yield batch_out
batch_data = []
batch_label = []
return reader
def train10(args):
"""
CIFAR-10 training set creator.
It returns a reader creator, each sample in the reader is image pixels in
[0, 1] and label in [0, 9].
:return: Training reader creator
:rtype: callable
"""
return reader_creator_filepath(args.data, 'data_batch', True, args)
def test10(args):
"""
CIFAR-10 test set creator.
It returns a reader creator, each sample in the reader is image pixels in
[0, 1] and label in [0, 9].
:return: Test reader creator.
:rtype: callable
"""
return reader_creator_filepath(args.data, 'test_batch', False, args)
CUDA_VISIBLE_DEVICES=0 python -u train_mixup.py \
--batch_size=80 \
--auxiliary \
--weight_decay=0.0003 \
--learning_rate=0.025 \
--lrc_loss_lambda=0.7 \
--cutout
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
#
# Based on:
# --------------------------------------------------------
# DARTS
# Copyright (c) 2018, Hanxiao Liu.
# Licensed under the Apache License, Version 2.0;
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from learning_rate import cosine_decay
import numpy as np
import argparse
from model import NetworkCIFAR as Network
import reader
import sys
import os
import time
import logging
import genotypes
import paddle.fluid as fluid
import shutil
import utils
import cPickle as cp
parser = argparse.ArgumentParser("cifar")
parser.add_argument(
'--data',
type=str,
default='./dataset/cifar/cifar-10-batches-py/',
help='location of the data corpus')
parser.add_argument('--batch_size', type=int, default=96, help='batch size')
parser.add_argument(
'--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument(
'--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument(
'--report_freq', type=float, default=50, help='report frequency')
parser.add_argument(
'--epochs', type=int, default=600, help='num of training epochs')
parser.add_argument(
'--init_channels', type=int, default=36, help='num of init channels')
parser.add_argument(
'--layers', type=int, default=20, help='total number of layers')
parser.add_argument(
'--model_path',
type=str,
default='saved_models',
help='path to save the model')
parser.add_argument(
'--auxiliary',
action='store_true',
default=False,
help='use auxiliary tower')
parser.add_argument(
'--auxiliary_weight',
type=float,
default=0.4,
help='weight for auxiliary loss')
parser.add_argument(
'--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument(
'--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument(
'--drop_path_prob', type=float, default=0.2, help='drop path probability')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument(
'--arch', type=str, default='DARTS', help='which architecture to use')
parser.add_argument(
'--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument(
'--lr_exp_decay',
action='store_true',
default=False,
help='use exponential_decay learning_rate')
parser.add_argument('--mix_alpha', type=float, default=0.5, help='mixup alpha')
parser.add_argument(
'--lrc_loss_lambda', default=0, type=float, help='lrc_loss_lambda')
parser.add_argument(
'--loss_type',
default=1,
type=float,
help='loss_type 0: cross entropy 1: multi margin loss 2: max margin loss')
args = parser.parse_args()
CIFAR_CLASSES = 10
dataset_train_size = 50000
image_size = 32
def main():
image_shape = [3, image_size, image_size]
devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
devices_num = len(devices.split(","))
logging.info("args = %s", args)
genotype = eval("genotypes.%s" % args.arch)
model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
args.auxiliary, genotype)
steps_one_epoch = dataset_train_size / (devices_num * args.batch_size)
train(model, args, image_shape, steps_one_epoch)
def build_program(main_prog, startup_prog, args, is_train, model, im_shape,
steps_one_epoch):
out = []
with fluid.program_guard(main_prog, startup_prog):
py_reader = model.build_input(im_shape, args.batch_size, is_train)
if is_train:
with fluid.unique_name.guard():
loss = model.train_model(py_reader, args.init_channels,
args.auxiliary, args.auxiliary_weight,
args.batch_size, args.lrc_loss_lambda)
optimizer = fluid.optimizer.Momentum(
learning_rate=cosine_decay(args.learning_rate, \
args.epochs, steps_one_epoch),
regularization=fluid.regularizer.L2Decay(\
args.weight_decay),
momentum=args.momentum)
optimizer.minimize(loss)
out = [py_reader, loss]
else:
with fluid.unique_name.guard():
loss, acc_1, acc_5 = model.test_model(py_reader,
args.init_channels)
out = [py_reader, loss, acc_1, acc_5]
return out
def train(model, args, im_shape, steps_one_epoch):
train_startup_prog = fluid.Program()
test_startup_prog = fluid.Program()
train_prog = fluid.Program()
test_prog = fluid.Program()
train_py_reader, loss_train = build_program(train_prog, train_startup_prog,
args, True, model, im_shape,
steps_one_epoch)
test_py_reader, loss_test, acc_1, acc_5 = build_program(
test_prog, test_startup_prog, args, False, model, im_shape,
steps_one_epoch)
test_prog = test_prog.clone(for_test=True)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(train_startup_prog)
exe.run(test_startup_prog)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 1
train_exe = fluid.ParallelExecutor(
main_program=train_prog,
use_cuda=True,
loss_name=loss_train.name,
exec_strategy=exec_strategy)
train_reader = reader.train10(args)
test_reader = reader.test10(args)
train_py_reader.decorate_paddle_reader(train_reader)
test_py_reader.decorate_paddle_reader(test_reader)
fluid.clip.set_gradient_clip(fluid.clip.GradientClipByNorm(args.grad_clip))
fluid.memory_optimize(fluid.default_main_program())
def save_model(postfix, main_prog):
model_path = os.path.join(args.model_path, postfix)
if os.path.isdir(model_path):
shutil.rmtree(model_path)
fluid.io.save_persistables(exe, model_path, main_program=main_prog)
def test(epoch_id):
test_fetch_list = [loss_test, acc_1, acc_5]
objs = utils.AvgrageMeter()
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()
test_py_reader.start()
test_start_time = time.time()
step_id = 0
try:
while True:
prev_test_start_time = test_start_time
test_start_time = time.time()
loss_test_v, acc_1_v, acc_5_v = exe.run(
test_prog, fetch_list=test_fetch_list)
objs.update(np.array(loss_test_v), args.batch_size)
top1.update(np.array(acc_1_v), args.batch_size)
top5.update(np.array(acc_5_v), args.batch_size)
if step_id % args.report_freq == 0:
print("Epoch {}, Step {}, acc_1 {}, acc_5 {}, time {}".
format(epoch_id, step_id,
np.array(acc_1_v),
np.array(acc_5_v), test_start_time -
prev_test_start_time))
step_id += 1
except fluid.core.EOFException:
test_py_reader.reset()
print("Epoch {0}, top1 {1}, top5 {2}".format(epoch_id, top1.avg,
top5.avg))
train_fetch_list = [loss_train]
epoch_start_time = time.time()
for epoch_id in range(args.epochs):
model.drop_path_prob = args.drop_path_prob * epoch_id / args.epochs
train_py_reader.start()
epoch_end_time = time.time()
if epoch_id > 0:
print("Epoch {}, total time {}".format(epoch_id - 1, epoch_end_time
- epoch_start_time))
epoch_start_time = epoch_end_time
epoch_end_time
start_time = time.time()
step_id = 0
try:
while True:
prev_start_time = start_time
start_time = time.time()
loss_v, = train_exe.run(
fetch_list=[v.name for v in train_fetch_list])
print("Epoch {}, Step {}, loss {}, time {}".format(epoch_id, step_id, \
np.array(loss_v).mean(), start_time-prev_start_time))
step_id += 1
sys.stdout.flush()
except fluid.core.EOFException:
train_py_reader.reset()
if epoch_id % 50 == 0 or epoch_id == args.epochs - 1:
save_model(str(epoch_id), train_prog)
test(epoch_id)
if __name__ == '__main__':
main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Based on:
# --------------------------------------------------------
# DARTS
# Copyright (c) 2018, Hanxiao Liu.
# Licensed under the Apache License, Version 2.0;
# --------------------------------------------------------
import os
import sys
import time
import math
import numpy as np
def mixup_data(x, y, batch_size, alpha=1.0):
'''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda'''
if alpha > 0.:
lam = np.random.beta(alpha, alpha)
else:
lam = 1.
index = np.random.permutation(batch_size)
mixed_x = lam * x + (1 - lam) * x[index, :]
y_a, y_b = y, y[index]
return mixed_x.astype('float32'), y_a.astype('int64'),\
y_b.astype('int64'), np.array(lam, dtype='float32')
class AvgrageMeter(object):
def __init__(self):
self.reset()
def reset(self):
self.avg = 0
self.sum = 0
self.cnt = 0
def update(self, val, n=1):
self.sum += val * n
self.cnt += n
self.avg = self.sum / self.cnt
#-*- coding: utf-8 -*-
import math
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from tqdm import tqdm
class DQNModel(object):
def __init__(self, state_dim, action_dim, gamma, hist_len, use_cuda=False):
self.img_height = state_dim[0]
self.img_width = state_dim[1]
self.action_dim = action_dim
self.gamma = gamma
self.exploration = 1.1
self.update_target_steps = 10000 // 4
self.hist_len = hist_len
self.use_cuda = use_cuda
self.global_step = 0
self._build_net()
def _get_inputs(self):
return fluid.layers.data(
name='state',
shape=[self.hist_len, self.img_height, self.img_width],
dtype='float32'), \
fluid.layers.data(
name='action', shape=[1], dtype='int32'), \
fluid.layers.data(
name='reward', shape=[], dtype='float32'), \
fluid.layers.data(
name='next_s',
shape=[self.hist_len, self.img_height, self.img_width],
dtype='float32'), \
fluid.layers.data(
name='isOver', shape=[], dtype='bool')
def _build_net(self):
self.predict_program = fluid.Program()
self.train_program = fluid.Program()
self._sync_program = fluid.Program()
with fluid.program_guard(self.predict_program):
state, action, reward, next_s, isOver = self._get_inputs()
self.pred_value = self.get_DQN_prediction(state)
with fluid.program_guard(self.train_program):
state, action, reward, next_s, isOver = self._get_inputs()
pred_value = self.get_DQN_prediction(state)
reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
action_onehot = fluid.layers.one_hot(action, self.action_dim)
action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
pred_action_value = fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)
targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
best_v.stop_gradient = True
target = reward + (1.0 - fluid.layers.cast(
isOver, dtype='float32')) * self.gamma * best_v
cost = fluid.layers.square_error_cost(pred_action_value, target)
cost = fluid.layers.reduce_mean(cost)
optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
optimizer.minimize(cost)
vars = list(self.train_program.list_vars())
target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars_name = [
x.name.replace('target', 'policy') for x in target_vars]
policy_vars = list(filter(
lambda x: x.name in policy_vars_name, vars))
policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name)
with fluid.program_guard(self._sync_program):
sync_ops = []
for i, var in enumerate(policy_vars):
sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
sync_ops.append(sync_op)
# fluid exe
place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
self.exe = fluid.Executor(place)
self.exe.run(fluid.default_startup_program())
def get_DQN_prediction(self, image, target=False):
image = image / 255.0
variable_field = 'target' if target else 'policy'
conv1 = fluid.layers.conv2d(
input=image,
num_filters=32,
filter_size=5,
stride=1,
padding=2,
act='relu',
param_attr=ParamAttr(name='{}_conv1'.format(variable_field)),
bias_attr=ParamAttr(name='{}_conv1_b'.format(variable_field)))
max_pool1 = fluid.layers.pool2d(
input=conv1, pool_size=2, pool_stride=2, pool_type='max')
conv2 = fluid.layers.conv2d(
input=max_pool1,
num_filters=32,
filter_size=5,
stride=1,
padding=2,
act='relu',
param_attr=ParamAttr(name='{}_conv2'.format(variable_field)),
bias_attr=ParamAttr(name='{}_conv2_b'.format(variable_field)))
max_pool2 = fluid.layers.pool2d(
input=conv2, pool_size=2, pool_stride=2, pool_type='max')
conv3 = fluid.layers.conv2d(
input=max_pool2,
num_filters=64,
filter_size=4,
stride=1,
padding=1,
act='relu',
param_attr=ParamAttr(name='{}_conv3'.format(variable_field)),
bias_attr=ParamAttr(name='{}_conv3_b'.format(variable_field)))
max_pool3 = fluid.layers.pool2d(
input=conv3, pool_size=2, pool_stride=2, pool_type='max')
conv4 = fluid.layers.conv2d(
input=max_pool3,
num_filters=64,
filter_size=3,
stride=1,
padding=1,
act='relu',
param_attr=ParamAttr(name='{}_conv4'.format(variable_field)),
bias_attr=ParamAttr(name='{}_conv4_b'.format(variable_field)))
flatten = fluid.layers.flatten(conv4, axis=1)
out = fluid.layers.fc(
input=flatten,
size=self.action_dim,
param_attr=ParamAttr(name='{}_fc1'.format(variable_field)),
bias_attr=ParamAttr(name='{}_fc1_b'.format(variable_field)))
return out
def act(self, state, train_or_test):
sample = np.random.random()
if train_or_test == 'train' and sample < self.exploration:
act = np.random.randint(self.action_dim)
else:
if np.random.random() < 0.01:
act = np.random.randint(self.action_dim)
else:
state = np.expand_dims(state, axis=0)
pred_Q = self.exe.run(self.predict_program,
feed={'state': state.astype('float32')},
fetch_list=[self.pred_value])[0]
pred_Q = np.squeeze(pred_Q, axis=0)
act = np.argmax(pred_Q)
if train_or_test == 'train':
self.exploration = max(0.1, self.exploration - 1e-6)
return act
def train(self, state, action, reward, next_state, isOver):
if self.global_step % self.update_target_steps == 0:
self.sync_target_network()
self.global_step += 1
action = np.expand_dims(action, -1)
self.exe.run(self.train_program,
feed={
'state': state.astype('float32'),
'action': action.astype('int32'),
'reward': reward,
'next_s': next_state.astype('float32'),
'isOver': isOver
})
def sync_target_network(self):
self.exe.run(self._sync_program)
#-*- coding: utf-8 -*-
import math
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from tqdm import tqdm
class DoubleDQNModel(object):
def __init__(self, state_dim, action_dim, gamma, hist_len, use_cuda=False):
self.img_height = state_dim[0]
self.img_width = state_dim[1]
self.action_dim = action_dim
self.gamma = gamma
self.exploration = 1.1
self.update_target_steps = 10000 // 4
self.hist_len = hist_len
self.use_cuda = use_cuda
self.global_step = 0
self._build_net()
def _get_inputs(self):
return fluid.layers.data(
name='state',
shape=[self.hist_len, self.img_height, self.img_width],
dtype='float32'), \
fluid.layers.data(
name='action', shape=[1], dtype='int32'), \
fluid.layers.data(
name='reward', shape=[], dtype='float32'), \
fluid.layers.data(
name='next_s',
shape=[self.hist_len, self.img_height, self.img_width],
dtype='float32'), \
fluid.layers.data(
name='isOver', shape=[], dtype='bool')
def _build_net(self):
self.predict_program = fluid.Program()
self.train_program = fluid.Program()
self._sync_program = fluid.Program()
with fluid.program_guard(self.predict_program):
state, action, reward, next_s, isOver = self._get_inputs()
self.pred_value = self.get_DQN_prediction(state)
with fluid.program_guard(self.train_program):
state, action, reward, next_s, isOver = self._get_inputs()
pred_value = self.get_DQN_prediction(state)
reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
action_onehot = fluid.layers.one_hot(action, self.action_dim)
action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
pred_action_value = fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)
targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
next_s_predcit_value = self.get_DQN_prediction(next_s)
greedy_action = fluid.layers.argmax(next_s_predcit_value, axis=1)
greedy_action = fluid.layers.unsqueeze(greedy_action, axes=[1])
predict_onehot = fluid.layers.one_hot(greedy_action, self.action_dim)
best_v = fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(predict_onehot, targetQ_predict_value),
dim=1)
best_v.stop_gradient = True
target = reward + (1.0 - fluid.layers.cast(
isOver, dtype='float32')) * self.gamma * best_v
cost = fluid.layers.square_error_cost(pred_action_value, target)
cost = fluid.layers.reduce_mean(cost)
optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
optimizer.minimize(cost)
vars = list(self.train_program.list_vars())
target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars_name = [
x.name.replace('target', 'policy') for x in target_vars]
policy_vars = list(filter(
lambda x: x.name in policy_vars_name, vars))
policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name)
with fluid.program_guard(self._sync_program):
sync_ops = []
for i, var in enumerate(policy_vars):
sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
sync_ops.append(sync_op)
# fluid exe
place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
self.exe = fluid.Executor(place)
self.exe.run(fluid.default_startup_program())
def get_DQN_prediction(self, image, target=False):
image = image / 255.0
variable_field = 'target' if target else 'policy'
conv1 = fluid.layers.conv2d(
input=image,
num_filters=32,
filter_size=5,
stride=1,
padding=2,
act='relu',
param_attr=ParamAttr(name='{}_conv1'.format(variable_field)),
bias_attr=ParamAttr(name='{}_conv1_b'.format(variable_field)))
max_pool1 = fluid.layers.pool2d(
input=conv1, pool_size=2, pool_stride=2, pool_type='max')
conv2 = fluid.layers.conv2d(
input=max_pool1,
num_filters=32,
filter_size=5,
stride=1,
padding=2,
act='relu',
param_attr=ParamAttr(name='{}_conv2'.format(variable_field)),
bias_attr=ParamAttr(name='{}_conv2_b'.format(variable_field)))
max_pool2 = fluid.layers.pool2d(
input=conv2, pool_size=2, pool_stride=2, pool_type='max')
conv3 = fluid.layers.conv2d(
input=max_pool2,
num_filters=64,
filter_size=4,
stride=1,
padding=1,
act='relu',
param_attr=ParamAttr(name='{}_conv3'.format(variable_field)),
bias_attr=ParamAttr(name='{}_conv3_b'.format(variable_field)))
max_pool3 = fluid.layers.pool2d(
input=conv3, pool_size=2, pool_stride=2, pool_type='max')
conv4 = fluid.layers.conv2d(
input=max_pool3,
num_filters=64,
filter_size=3,
stride=1,
padding=1,
act='relu',
param_attr=ParamAttr(name='{}_conv4'.format(variable_field)),
bias_attr=ParamAttr(name='{}_conv4_b'.format(variable_field)))
flatten = fluid.layers.flatten(conv4, axis=1)
out = fluid.layers.fc(
input=flatten,
size=self.action_dim,
param_attr=ParamAttr(name='{}_fc1'.format(variable_field)),
bias_attr=ParamAttr(name='{}_fc1_b'.format(variable_field)))
return out
def act(self, state, train_or_test):
sample = np.random.random()
if train_or_test == 'train' and sample < self.exploration:
act = np.random.randint(self.action_dim)
else:
if np.random.random() < 0.01:
act = np.random.randint(self.action_dim)
else:
state = np.expand_dims(state, axis=0)
pred_Q = self.exe.run(self.predict_program,
feed={'state': state.astype('float32')},
fetch_list=[self.pred_value])[0]
pred_Q = np.squeeze(pred_Q, axis=0)
act = np.argmax(pred_Q)
if train_or_test == 'train':
self.exploration = max(0.1, self.exploration - 1e-6)
return act
def train(self, state, action, reward, next_state, isOver):
if self.global_step % self.update_target_steps == 0:
self.sync_target_network()
self.global_step += 1
action = np.expand_dims(action, -1)
self.exe.run(self.train_program,
feed={
'state': state.astype('float32'),
'action': action.astype('int32'),
'reward': reward,
'next_s': next_state.astype('float32'),
'isOver': isOver
})
def sync_target_network(self):
self.exe.run(self._sync_program)
#-*- coding: utf-8 -*-
import math
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from tqdm import tqdm
class DuelingDQNModel(object):
def __init__(self, state_dim, action_dim, gamma, hist_len, use_cuda=False):
self.img_height = state_dim[0]
self.img_width = state_dim[1]
self.action_dim = action_dim
self.gamma = gamma
self.exploration = 1.1
self.update_target_steps = 10000 // 4
self.hist_len = hist_len
self.use_cuda = use_cuda
self.global_step = 0
self._build_net()
def _get_inputs(self):
return fluid.layers.data(
name='state',
shape=[self.hist_len, self.img_height, self.img_width],
dtype='float32'), \
fluid.layers.data(
name='action', shape=[1], dtype='int32'), \
fluid.layers.data(
name='reward', shape=[], dtype='float32'), \
fluid.layers.data(
name='next_s',
shape=[self.hist_len, self.img_height, self.img_width],
dtype='float32'), \
fluid.layers.data(
name='isOver', shape=[], dtype='bool')
def _build_net(self):
self.predict_program = fluid.Program()
self.train_program = fluid.Program()
self._sync_program = fluid.Program()
with fluid.program_guard(self.predict_program):
state, action, reward, next_s, isOver = self._get_inputs()
self.pred_value = self.get_DQN_prediction(state)
with fluid.program_guard(self.train_program):
state, action, reward, next_s, isOver = self._get_inputs()
pred_value = self.get_DQN_prediction(state)
reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
action_onehot = fluid.layers.one_hot(action, self.action_dim)
action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
pred_action_value = fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)
targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
best_v.stop_gradient = True
target = reward + (1.0 - fluid.layers.cast(
isOver, dtype='float32')) * self.gamma * best_v
cost = fluid.layers.square_error_cost(pred_action_value, target)
cost = fluid.layers.reduce_mean(cost)
optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
optimizer.minimize(cost)
vars = list(self.train_program.list_vars())
target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars_name = [
x.name.replace('target', 'policy') for x in target_vars]
policy_vars = list(filter(
lambda x: x.name in policy_vars_name, vars))
policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name)
with fluid.program_guard(self._sync_program):
sync_ops = []
for i, var in enumerate(policy_vars):
sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
sync_ops.append(sync_op)
# fluid exe
place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
self.exe = fluid.Executor(place)
self.exe.run(fluid.default_startup_program())
def get_DQN_prediction(self, image, target=False):
image = image / 255.0
variable_field = 'target' if target else 'policy'
conv1 = fluid.layers.conv2d(
input=image,
num_filters=32,
filter_size=5,
stride=1,
padding=2,
act='relu',
param_attr=ParamAttr(name='{}_conv1'.format(variable_field)),
bias_attr=ParamAttr(name='{}_conv1_b'.format(variable_field)))
max_pool1 = fluid.layers.pool2d(
input=conv1, pool_size=2, pool_stride=2, pool_type='max')
conv2 = fluid.layers.conv2d(
input=max_pool1,
num_filters=32,
filter_size=5,
stride=1,
padding=2,
act='relu',
param_attr=ParamAttr(name='{}_conv2'.format(variable_field)),
bias_attr=ParamAttr(name='{}_conv2_b'.format(variable_field)))
max_pool2 = fluid.layers.pool2d(
input=conv2, pool_size=2, pool_stride=2, pool_type='max')
conv3 = fluid.layers.conv2d(
input=max_pool2,
num_filters=64,
filter_size=4,
stride=1,
padding=1,
act='relu',
param_attr=ParamAttr(name='{}_conv3'.format(variable_field)),
bias_attr=ParamAttr(name='{}_conv3_b'.format(variable_field)))
max_pool3 = fluid.layers.pool2d(
input=conv3, pool_size=2, pool_stride=2, pool_type='max')
conv4 = fluid.layers.conv2d(
input=max_pool3,
num_filters=64,
filter_size=3,
stride=1,
padding=1,
act='relu',
param_attr=ParamAttr(name='{}_conv4'.format(variable_field)),
bias_attr=ParamAttr(name='{}_conv4_b'.format(variable_field)))
flatten = fluid.layers.flatten(conv4, axis=1)
value = fluid.layers.fc(
input=flatten,
size=1,
param_attr=ParamAttr(name='{}_value_fc'.format(variable_field)),
bias_attr=ParamAttr(name='{}_value_fc_b'.format(variable_field)))
advantage = fluid.layers.fc(
input=flatten,
size=self.action_dim,
param_attr=ParamAttr(name='{}_advantage_fc'.format(variable_field)),
bias_attr=ParamAttr(
name='{}_advantage_fc_b'.format(variable_field)))
Q = advantage + (value - fluid.layers.reduce_mean(
advantage, dim=1, keep_dim=True))
return Q
def act(self, state, train_or_test):
sample = np.random.random()
if train_or_test == 'train' and sample < self.exploration:
act = np.random.randint(self.action_dim)
else:
if np.random.random() < 0.01:
act = np.random.randint(self.action_dim)
else:
state = np.expand_dims(state, axis=0)
pred_Q = self.exe.run(self.predict_program,
feed={'state': state.astype('float32')},
fetch_list=[self.pred_value])[0]
pred_Q = np.squeeze(pred_Q, axis=0)
act = np.argmax(pred_Q)
if train_or_test == 'train':
self.exploration = max(0.1, self.exploration - 1e-6)
return act
def train(self, state, action, reward, next_state, isOver):
if self.global_step % self.update_target_steps == 0:
self.sync_target_network()
self.global_step += 1
action = np.expand_dims(action, -1)
self.exe.run(self.train_program,
feed={
'state': state.astype('float32'),
'action': action.astype('int32'),
'reward': reward,
'next_s': next_state.astype('float32'),
'isOver': isOver
})
def sync_target_network(self):
self.exe.run(self._sync_program)
[中文版](README_cn.md)
## Reproduce DQN, DoubleDQN, DuelingDQN model with Fluid version of PaddlePaddle
Based on PaddlePaddle's next-generation API Fluid, the DQN model of deep reinforcement learning is reproduced, and the same level of indicators of the paper is reproduced in the classic Atari game. The model receives the image of the game as input, and uses the end-to-end model to directly predict the next step. The repository contains the following three types of models:
+ DQN in
[Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
+ DoubleDQN in:
[Deep Reinforcement Learning with Double Q-Learning](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/viewPaper/12389)
+ DuelingDQN in:
[Dueling Network Architectures for Deep Reinforcement Learning](http://proceedings.mlr.press/v48/wangf16.html)
## Atari benchmark & performance
### Atari games introduction
Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari game.
### Pong game result
The average game rewards that can be obtained for the three models as the number of training steps changes during the training are as follows(about 3 hours/1 Million steps):
<div align="center">
<img src="assets/dqn.png" width="600" height="300" alt="DQN result"></img>
</div>
## How to use
### Dependencies:
+ python2.7
+ gym
+ tqdm
+ opencv-python
+ paddlepaddle-gpu>=1.0.0
+ ale_python_interface
### Install Dependencies:
+ Install PaddlePaddle:
recommended to compile and install PaddlePaddle from source code
+ Install other dependencies:
```
pip install -r requirement.txt
pip install gym[atari]
```
Install ale_python_interface, please see [here](https://github.com/mgbellemare/Arcade-Learning-Environment).
### Start Training:
```
# To train a model for Pong game with gpu (use DQN model as default)
python train.py --rom ./rom_files/pong.bin --use_cuda
# To train a model for Pong with DoubleDQN
python train.py --rom ./rom_files/pong.bin --use_cuda --alg DoubleDQN
# To train a model for Pong with DuelingDQN
python train.py --rom ./rom_files/pong.bin --use_cuda --alg DuelingDQN
```
To train more games, you can install more rom files from [here](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms).
### Start Testing:
```
# Play the game with saved best model and calculate the average rewards
python play.py --rom ./rom_files/pong.bin --use_cuda --model_path ./saved_model/DQN-pong
# Play the game with visualization
python play.py --rom ./rom_files/pong.bin --use_cuda --model_path ./saved_model/DQN-pong --viz 0.01
```
[Here](https://pan.baidu.com/s/1gIsbNw5V7tMeb74ojx-TMA) is saved models for Pong and Breakout games. You can use it to play the game directly.
## 基于PaddlePaddle的Fluid版本复现DQN, DoubleDQN, DuelingDQN三个模型
基于PaddlePaddle下一代API Fluid复现了深度强化学习领域的DQN模型,在经典的Atari 游戏上复现了论文同等水平的指标,模型接收游戏的图像作为输入,采用端到端的模型直接预测下一步要执行的控制信号,本仓库一共包含以下3类模型:
+ DQN模型:
[Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
+ DoubleDQN模型:
[Deep Reinforcement Learning with Double Q-Learning](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/viewPaper/12389)
+ DuelingDQN模型:
[Dueling Network Architectures for Deep Reinforcement Learning](http://proceedings.mlr.press/v48/wangf16.html)
## 模型效果:Atari游戏表现
### Atari游戏介绍
请点击[这里](https://gym.openai.com/envs/#atari)了解Atari游戏。
### Pong游戏训练结果
三个模型在训练过程中随着训练步数的变化,能得到的平均游戏奖励如下图所示(大概3小时每1百万步):
<div align="center">
<img src="assets/dqn.png" width="600" height="300" alt="DQN result"></img>
</div>
## 使用教程
### 依赖:
+ python2.7
+ gym
+ tqdm
+ opencv-python
+ paddlepaddle-gpu>=1.0.0
+ ale_python_interface
### 下载依赖:
+ 安装PaddlePaddle:
建议通过PaddlePaddle源码进行编译安装
+ 下载其它依赖:
```
pip install -r requirement.txt
pip install gym[atari]
```
安装ale_python_interface可以参考[这里](https://github.com/mgbellemare/Arcade-Learning-Environment)
### 训练模型:
```
# 使用GPU训练Pong游戏(默认使用DQN模型)
python train.py --rom ./rom_files/pong.bin --use_cuda
# 训练DoubleDQN模型
python train.py --rom ./rom_files/pong.bin --use_cuda --alg DoubleDQN
# 训练DuelingDQN模型
python train.py --rom ./rom_files/pong.bin --use_cuda --alg DuelingDQN
```
训练更多游戏,可以从[这里](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms)下载游戏rom
### 测试模型:
```
# Play the game with saved model and calculate the average rewards
# 使用训练过程中保存的最好模型玩游戏,以及计算平均奖励(rewards)
python play.py --rom ./rom_files/pong.bin --use_cuda --model_path ./saved_model/DQN-pong
# 以可视化的形式来玩游戏
python play.py --rom ./rom_files/pong.bin --use_cuda --model_path ./saved_model/DQN-pong --viz 0.01
```
[这里](https://pan.baidu.com/s/1gIsbNw5V7tMeb74ojx-TMA)是Pong和Breakout游戏训练好的模型,可以直接用来测试。
# -*- coding: utf-8 -*-
import numpy as np
import os
import cv2
import threading
import gym
from gym import spaces
from gym.envs.atari.atari_env import ACTION_MEANING
from atari_py import ALEInterface
__all__ = ['AtariPlayer']
ROM_URL = "https://github.com/openai/atari-py/tree/master/atari_py/atari_roms"
_ALE_LOCK = threading.Lock()
"""
The following AtariPlayer are copied or modified from tensorpack/tensorpack:
https://github.com/tensorpack/tensorpack/blob/master/examples/DeepQNetwork/atari.py
"""
class AtariPlayer(gym.Env):
"""
A wrapper for ALE emulator, with configurations to mimic DeepMind DQN settings.
Info:
score: the accumulated reward in the current game
gameOver: True when the current game is Over
"""
def __init__(self,
rom_file,
viz=0,
frame_skip=4,
nullop_start=30,
live_lost_as_eoe=True,
max_num_frames=0):
"""
Args:
rom_file: path to the rom
frame_skip: skip every k frames and repeat the action
viz: visualization to be done.
Set to 0 to disable.
Set to a positive number to be the delay between frames to show.
Set to a string to be a directory to store frames.
nullop_start: start with random number of null ops.
live_losts_as_eoe: consider lost of lives as end of episode. Useful for training.
max_num_frames: maximum number of frames per episode.
"""
super(AtariPlayer, self).__init__()
assert os.path.isfile(rom_file), \
"rom {} not found. Please download at {}".format(rom_file, ROM_URL)
try:
ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
except AttributeError:
print("You're not using latest ALE")
# avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
with _ALE_LOCK:
self.ale = ALEInterface()
self.ale.setInt(b"random_seed", np.random.randint(0, 30000))
self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
self.ale.setBool(b"showinfo", False)
self.ale.setInt(b"frame_skip", 1)
self.ale.setBool(b'color_averaging', False)
# manual.pdf suggests otherwise.
self.ale.setFloat(b'repeat_action_probability', 0.0)
# viz setup
if isinstance(viz, str):
assert os.path.isdir(viz), viz
self.ale.setString(b'record_screen_dir', viz)
viz = 0
if isinstance(viz, int):
viz = float(viz)
self.viz = viz
if self.viz and isinstance(self.viz, float):
self.windowname = os.path.basename(rom_file)
cv2.startWindowThread()
cv2.namedWindow(self.windowname)
self.ale.loadROM(rom_file.encode('utf-8'))
self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet()
self.live_lost_as_eoe = live_lost_as_eoe
self.frame_skip = frame_skip
self.nullop_start = nullop_start
self.action_space = spaces.Discrete(len(self.actions))
self.observation_space = spaces.Box(low=0,
high=255,
shape=(self.height, self.width),
dtype=np.uint8)
self._restart_episode()
def get_action_meanings(self):
return [ACTION_MEANING[i] for i in self.actions]
def _grab_raw_image(self):
"""
:returns: the current 3-channel image
"""
m = self.ale.getScreenRGB()
return m.reshape((self.height, self.width, 3))
def _current_state(self):
"""
returns: a gray-scale (h, w) uint8 image
"""
ret = self._grab_raw_image()
# avoid missing frame issue: max-pooled over the last screen
ret = np.maximum(ret, self.last_raw_screen)
if self.viz:
if isinstance(self.viz, float):
cv2.imshow(self.windowname, ret)
cv2.waitKey(int(self.viz * 1000))
ret = ret.astype('float32')
# 0.299,0.587.0.114. same as rgb2y in torch/image
ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
return ret.astype('uint8') # to save some memory
def _restart_episode(self):
with _ALE_LOCK:
self.ale.reset_game()
# random null-ops start
n = np.random.randint(self.nullop_start)
self.last_raw_screen = self._grab_raw_image()
for k in range(n):
if k == n - 1:
self.last_raw_screen = self._grab_raw_image()
self.ale.act(0)
def reset(self):
if self.ale.game_over():
self._restart_episode()
return self._current_state()
def step(self, act):
oldlives = self.ale.lives()
r = 0
for k in range(self.frame_skip):
if k == self.frame_skip - 1:
self.last_raw_screen = self._grab_raw_image()
r += self.ale.act(self.actions[act])
newlives = self.ale.lives()
if self.ale.game_over() or \
(self.live_lost_as_eoe and newlives < oldlives):
break
isOver = self.ale.game_over()
if self.live_lost_as_eoe:
isOver = isOver or newlives < oldlives
info = {'ale.lives': newlives}
return self._current_state(), r, isOver, info
# -*- coding: utf-8 -*-
import numpy as np
from collections import deque
import gym
from gym import spaces
_v0, _v1 = gym.__version__.split('.')[:2]
assert int(_v0) > 0 or int(_v1) >= 10, gym.__version__
"""
The following wrappers are copied or modified from openai/baselines:
https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
"""
class MapState(gym.ObservationWrapper):
def __init__(self, env, map_func):
gym.ObservationWrapper.__init__(self, env)
self._func = map_func
def observation(self, obs):
return self._func(obs)
class FrameStack(gym.Wrapper):
def __init__(self, env, k):
"""Buffer observations and stack across channels (last axis)."""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
chan = 1 if len(shp) == 2 else shp[2]
self.observation_space = spaces.Box(low=0,
high=255,
shape=(shp[0], shp[1], chan * k),
dtype=np.uint8)
def reset(self):
"""Clear buffer and re-fill by duplicating the first observation."""
ob = self.env.reset()
for _ in range(self.k - 1):
self.frames.append(np.zeros_like(ob))
self.frames.append(ob)
return self.observation()
def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self.observation(), reward, done, info
def observation(self):
assert len(self.frames) == self.k
return np.stack(self.frames, axis=0)
class _FireResetEnv(gym.Wrapper):
def __init__(self, env):
"""Take action on reset for environments that are fixed until firing."""
gym.Wrapper.__init__(self, env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset()
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset()
return obs
def step(self, action):
return self.env.step(action)
def FireResetEnv(env):
if isinstance(env, gym.Wrapper):
baseenv = env.unwrapped
else:
baseenv = env
if 'FIRE' in baseenv.get_action_meanings():
return _FireResetEnv(env)
return env
class LimitLength(gym.Wrapper):
def __init__(self, env, k):
gym.Wrapper.__init__(self, env)
self.k = k
def reset(self):
# This assumes that reset() will really reset the env.
# If the underlying env tries to be smart about reset
# (e.g. end-of-life), the assumption doesn't hold.
ob = self.env.reset()
self.cnt = 0
return ob
def step(self, action):
ob, r, done, info = self.env.step(action)
self.cnt += 1
if self.cnt == self.k:
done = True
return ob, r, done, info
# -*- coding: utf-8 -*-
import numpy as np
import copy
from collections import deque, namedtuple
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])
class ReplayMemory(object):
def __init__(self, max_size, state_shape, context_len):
self.max_size = int(max_size)
self.state_shape = state_shape
self.context_len = int(context_len)
self.state = np.zeros((self.max_size, ) + state_shape, dtype='uint8')
self.action = np.zeros((self.max_size, ), dtype='int32')
self.reward = np.zeros((self.max_size, ), dtype='float32')
self.isOver = np.zeros((self.max_size, ), dtype='bool')
self._curr_size = 0
self._curr_pos = 0
self._context = deque(maxlen=context_len - 1)
def append(self, exp):
"""append a new experience into replay memory
"""
if self._curr_size < self.max_size:
self._assign(self._curr_pos, exp)
self._curr_size += 1
else:
self._assign(self._curr_pos, exp)
self._curr_pos = (self._curr_pos + 1) % self.max_size
if exp.isOver:
self._context.clear()
else:
self._context.append(exp)
def recent_state(self):
""" maintain recent state for training"""
lst = list(self._context)
states = [np.zeros(self.state_shape, dtype='uint8')] * \
(self._context.maxlen - len(lst))
states.extend([k.state for k in lst])
return states
def sample(self, idx):
""" return state, action, reward, isOver,
note that some frames in state may be generated from last episode,
they should be removed from state
"""
state = np.zeros(
(self.context_len + 1, ) + self.state_shape, dtype=np.uint8)
state_idx = np.arange(idx, idx + self.context_len + 1) % self._curr_size
# confirm that no frame was generated from last episode
has_last_episode = False
for k in range(self.context_len - 2, -1, -1):
to_check_idx = state_idx[k]
if self.isOver[to_check_idx]:
has_last_episode = True
state_idx = state_idx[k + 1:]
state[k + 1:] = self.state[state_idx]
break
if not has_last_episode:
state = self.state[state_idx]
real_idx = (idx + self.context_len - 1) % self._curr_size
action = self.action[real_idx]
reward = self.reward[real_idx]
isOver = self.isOver[real_idx]
return state, reward, action, isOver
def __len__(self):
return self._curr_size
def _assign(self, pos, exp):
self.state[pos] = exp.state
self.reward[pos] = exp.reward
self.action[pos] = exp.action
self.isOver[pos] = exp.isOver
def sample_batch(self, batch_size):
"""sample a batch from replay memory for training
"""
batch_idx = np.random.randint(
self._curr_size - self.context_len - 1, size=batch_size)
batch_idx = (self._curr_pos + batch_idx) % self._curr_size
batch_exp = [self.sample(i) for i in batch_idx]
return self._process_batch(batch_exp)
def _process_batch(self, batch_exp):
state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
action = np.asarray([e[2] for e in batch_exp], dtype='int8')
isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
return [state, action, reward, isOver]
#-*- coding: utf-8 -*-
import argparse
import os
import numpy as np
import paddle.fluid as fluid
from train import get_player
from tqdm import tqdm
def predict_action(exe, state, predict_program, feed_names, fetch_targets,
action_dim):
if np.random.random() < 0.01:
act = np.random.randint(action_dim)
else:
state = np.expand_dims(state, axis=0)
pred_Q = exe.run(predict_program,
feed={feed_names[0]: state.astype('float32')},
fetch_list=fetch_targets)[0]
pred_Q = np.squeeze(pred_Q, axis=0)
act = np.argmax(pred_Q)
return act
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--use_cuda', action='store_true', help='if set, use cuda')
parser.add_argument('--rom', type=str, required=True, help='atari rom')
parser.add_argument(
'--model_path', type=str, required=True, help='dirname to load model')
parser.add_argument(
'--viz',
type=float,
default=0,
help='''viz: visualization setting:
Set to 0 to disable;
Set to a positive number to be the delay between frames to show.
''')
args = parser.parse_args()
env = get_player(args.rom, viz=args.viz)
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.Scope()
with fluid.scope_guard(inference_scope):
[predict_program, feed_names,
fetch_targets] = fluid.io.load_inference_model(args.model_path, exe)
episode_reward = []
for _ in tqdm(xrange(30), desc='eval agent'):
state = env.reset()
total_reward = 0
while True:
action = predict_action(exe, state, predict_program, feed_names,
fetch_targets, env.action_space.n)
state, reward, isOver, info = env.step(action)
total_reward += reward
if isOver:
break
episode_reward.append(total_reward)
eval_reward = np.mean(episode_reward)
print('Average reward of 30 epidose: {}'.format(eval_reward))
numpy
gym
tqdm
opencv-python
paddlepaddle-gpu>=1.0.0
#-*- coding: utf-8 -*-
from DQN_agent import DQNModel
from DoubleDQN_agent import DoubleDQNModel
from DuelingDQN_agent import DuelingDQNModel
from atari import AtariPlayer
import paddle.fluid as fluid
import gym
import argparse
import cv2
from tqdm import tqdm
from expreplay import ReplayMemory, Experience
import numpy as np
import os
from datetime import datetime
from atari_wrapper import FrameStack, MapState, FireResetEnv, LimitLength
from collections import deque
UPDATE_FREQ = 4
MEMORY_SIZE = 1e6
MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20
IMAGE_SIZE = (84, 84)
CONTEXT_LEN = 4
ACTION_REPEAT = 4 # aka FRAME_SKIP
UPDATE_FREQ = 4
def run_train_episode(agent, env, exp):
total_reward = 0
state = env.reset()
step = 0
while True:
step += 1
context = exp.recent_state()
context.append(state)
context = np.stack(context, axis=0)
action = agent.act(context, train_or_test='train')
next_state, reward, isOver, _ = env.step(action)
exp.append(Experience(state, action, reward, isOver))
# train model
# start training
if len(exp) > MEMORY_WARMUP_SIZE:
if step % UPDATE_FREQ == 0:
batch_all_state, batch_action, batch_reward, batch_isOver = exp.sample_batch(
args.batch_size)
batch_state = batch_all_state[:, :CONTEXT_LEN, :, :]
batch_next_state = batch_all_state[:, 1:, :, :]
agent.train(batch_state, batch_action, batch_reward,
batch_next_state, batch_isOver)
total_reward += reward
state = next_state
if isOver:
break
return total_reward, step
def get_player(rom, viz=False, train=False):
env = AtariPlayer(
rom,
frame_skip=ACTION_REPEAT,
viz=viz,
live_lost_as_eoe=train,
max_num_frames=60000)
env = FireResetEnv(env)
env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
if not train:
# in training, context is taken care of in expreplay buffer
env = FrameStack(env, CONTEXT_LEN)
return env
def eval_agent(agent, env):
episode_reward = []
for _ in tqdm(range(30), desc='eval agent'):
state = env.reset()
total_reward = 0
step = 0
while True:
step += 1
action = agent.act(state, train_or_test='test')
state, reward, isOver, info = env.step(action)
total_reward += reward
if isOver:
break
episode_reward.append(total_reward)
eval_reward = np.mean(episode_reward)
return eval_reward
def train_agent():
env = get_player(args.rom, train=True)
test_env = get_player(args.rom)
exp = ReplayMemory(args.mem_size, IMAGE_SIZE, CONTEXT_LEN)
action_dim = env.action_space.n
if args.alg == 'DQN':
agent = DQNModel(IMAGE_SIZE, action_dim, args.gamma, CONTEXT_LEN,
args.use_cuda)
elif args.alg == 'DoubleDQN':
agent = DoubleDQNModel(IMAGE_SIZE, action_dim, args.gamma, CONTEXT_LEN,
args.use_cuda)
elif args.alg == 'DuelingDQN':
agent = DuelingDQNModel(IMAGE_SIZE, action_dim, args.gamma, CONTEXT_LEN,
args.use_cuda)
else:
print('Input algorithm name error!')
return
with tqdm(total=MEMORY_WARMUP_SIZE, desc='Memory warmup') as pbar:
while len(exp) < MEMORY_WARMUP_SIZE:
total_reward, step = run_train_episode(agent, env, exp)
pbar.update(step)
# train
test_flag = 0
save_flag = 0
pbar = tqdm(total=1e8)
recent_100_reward = []
total_step = 0
max_reward = None
save_path = os.path.join(args.model_dirname, '{}-{}'.format(
args.alg, os.path.basename(args.rom).split('.')[0]))
while True:
# start epoch
total_reward, step = run_train_episode(agent, env, exp)
total_step += step
pbar.set_description('[train]exploration:{}'.format(agent.exploration))
pbar.update(step)
if total_step // args.test_every_steps == test_flag:
pbar.write("testing")
eval_reward = eval_agent(agent, test_env)
test_flag += 1
print("eval_agent done, (steps, eval_reward): ({}, {})".format(
total_step, eval_reward))
if max_reward is None or eval_reward > max_reward:
max_reward = eval_reward
fluid.io.save_inference_model(save_path, ['state'],
agent.pred_value, agent.exe,
agent.predict_program)
pbar.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--alg',
type=str,
default='DQN',
help='Reinforcement learning algorithm, support: DQN, DoubleDQN, DuelingDQN'
)
parser.add_argument(
'--use_cuda', action='store_true', help='if set, use cuda')
parser.add_argument(
'--gamma',
type=float,
default=0.99,
help='discount factor for accumulated reward computation')
parser.add_argument(
'--mem_size',
type=int,
default=1000000,
help='memory size for experience replay')
parser.add_argument(
'--batch_size', type=int, default=64, help='batch size for training')
parser.add_argument('--rom', help='atari rom', required=True)
parser.add_argument(
'--model_dirname',
type=str,
default='saved_model',
help='dirname to save model')
parser.add_argument(
'--test_every_steps',
type=int,
default=100000,
help='every steps number to run test')
args = parser.parse_args()
train_agent()
PaddleRL
============
强化学习
--------
强化学习是近年来一个愈发重要的机器学习方向,特别是与深度学习相结合而形成的深度强化学习(Deep Reinforcement Learning, DRL),取得了很多令人惊异的成就。人们所熟知的战胜人类顶级围棋职业选手的 AlphaGo 就是 DRL 应用的一个典型例子,除游戏领域外,其它的应用还包括机器人、自然语言处理等。
深度强化学习的开山之作是在Atari视频游戏中的成功应用, 其可直接接受视频帧这种高维输入并根据图像内容端到端地预测下一步的动作,所用到的模型被称为深度Q网络(Deep Q-Network, DQN)。本实例就是利用PaddlePaddle Fluid这个灵活的框架,实现了 DQN 及其变体,并测试了它们在 Atari 游戏中的表现。
- [DeepQNetwork](https://github.com/PaddlePaddle/models/blob/develop/PaddleRL/DeepQNetwork/README_cn.md)
运行本目录下的程序示例需要使用PaddlePaddle的最新develop分枝。如果您的PaddlePaddle安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新PaddlePaddle安装版本。
---
# Policy Gradient RL by PaddlePaddle
本文介绍了如何使用PaddlePaddle通过policy-based的强化学习方法来训练一个player(actor model), 我们希望这个player可以完成简单的走阶梯任务。
内容分为:
- 任务描述
- 模型
- 策略(目标函数)
- 算法(Gradient ascent)
- PaddlePaddle实现
## 1. 任务描述
假设有一个阶梯,连接A、B点,player从A点出发,每一步只能向前走一步或向后走一步,到达B点即为完成任务。我们希望训练一个聪明的player,它知道怎么最快的从A点到达B点。
我们在命令行以下边的形式模拟任务:
```
A - O - - - - - B
```
一个‘-'代表一个阶梯,A点在行头,B点在行末,O代表player当前在的位置。
## 2. Policy Gradient
### 2.1 模型
#### inputyer
模型的输入是player观察到的当前阶梯的状态$S$, 要包含阶梯的长度和player当前的位置信息。
在命令行模拟的情况下,player的位置和阶梯长度连个变量足以表示当前的状态,但是我们为了便于将这个demo推广到更复杂的任务场景,我们这里用一个向量来表示游戏状态$S$.
向量$S$的长度为阶梯的长度,每一维代表一个阶梯,player所在的位置为1,其它位置为0.
下边是一个例子:
```
S = [0, 1, 0, 0] // 阶梯长度为4,player在第二个阶梯上。
```
#### hidden layer
隐藏层采用两个全连接layer `FC_1``FC_2`, 其中`FC_1` 的size为10, `FC_2`的size为2.
#### output layer
我们使用softmax将`FC_2`的output映射为所有可能的动作(前进或后退)的概率分布(Probability of taking the action),即为一个二维向量`act_probs`, 其中,`act_probs[0]` 为后退的概率,`act_probs[1]`为前进的概率。
#### 模型表示
我将我们的player模型(actor)形式化表示如下:
$$a = \pi_\theta(s)$$
其中$\theta$表示模型的参数,$s$是输入状态。
### 2.2 策略(目标函数)
我们怎么评估一个player(模型)的好坏呢?首先我们定义几个术语:
我们让$\pi_\theta(s)$来玩一局游戏,$s_t$表示第$t$时刻的状态,$a_t$表示在状态$s_t$做出的动作,$r_t$表示做过动作$a_t$后得到的奖赏。
一局游戏的过程可以表示如下:
$$\tau = [s_1, a_1, r_1, s_2, a_2, r_2 ... s_T, a_T, r_T] \tag{1}$$
一局游戏的奖励表示如下:
$$R(\tau) = \sum_{t=1}^Tr_t$$
player玩一局游戏,可能会出现多种操作序列$\tau$ ,某个$\tau$出现的概率是依赖于player model的$\theta$, 记做:
$$P(\tau | \theta)$$
那么,给定一个$\theta$(player model), 玩一局游戏,期望得到的奖励是:
$$\overline {R}_\theta = \sum_\tau R(\tau)\sum_\tau R(\tau) P(\tau|\theta)$$
大多数情况,我们无法穷举出所有的$\tau$,所以我们就抽取N个$\tau$来计算近似的期望:
$$\overline {R}_\theta = \sum_\tau R(\tau) P(\tau|\theta) \approx \frac{1}{N} \sum_{n=1}^N R(\tau^n)$$
$\overline {R}_\theta$就是我们需要的目标函数,它表示了一个参数为$\theta$的player玩一局游戏得分的期望,这个期望越大,代表这个player能力越强。
### 2.3 算法(Gradient ascent)
我们的目标函数是$\overline {R}_\theta$, 我们训练的任务就是, 我们训练的任务就是:
$$\theta^* = \arg\max_\theta \overline {R}_\theta$$
为了找到理想的$\theta$,我们使用Gradient ascent方法不断在$\overline {R}_\theta$的梯度方向更新$\theta$,可表示如下:
$$\theta' = \theta + \eta * \bigtriangledown \overline {R}_\theta$$
$$ \bigtriangledown \overline {R}_\theta = \sum_\tau R(\tau) \bigtriangledown P(\tau|\theta)\\
= \sum_\tau R(\tau) P(\tau|\theta) \frac{\bigtriangledown P(\tau|\theta)}{P(\tau|\theta)} \\
=\sum_\tau R(\tau) P(\tau|\theta) {\bigtriangledown \log P(\tau|\theta)} $$
$$P(\tau|\theta) = P(s_1)P(a_1|s_1,\theta)P(s_2, r_1|s_1,a_1)P(a_2|s_2,\theta)P(s_3,r_2|s_2,a_2)...P(a_t|s_t,\theta)P(s_{t+1}, r_t|s_t,a_t)\\
=P(s_1) \sum_{t=1}^T P(a_t|s_t,\theta)P(s_{t+1}, r_t|s_t,a_t)$$
$$\log P(\tau|\theta) = \log P(s_1) + \sum_{t=1}^T [\log P(a_t|s_t,\theta) + \log P(s_{t+1}, r_t|s_t,a_t)]$$
$$ \bigtriangledown \log P(\tau|\theta) = \sum_{t=1}^T \bigtriangledown \log P(a_t|s_t,\theta)$$
$$ \bigtriangledown \overline {R}_\theta = \sum_\tau R(\tau) P(\tau|\theta) {\bigtriangledown \log P(\tau|\theta)} \\
\approx \frac{1}{N} \sum_{n=1}^N R(\tau^n) {\bigtriangledown \log P(\tau|\theta)} \\
= \frac{1}{N} \sum_{n=1}^N R(\tau^n) {\sum_{t=1}^T \bigtriangledown \log P(a_t|s_t,\theta)} \\
= \frac{1}{N} \sum_{n=1}^N \sum_{t=1}^T R(\tau^n) { \bigtriangledown \log P(a_t|s_t,\theta)} \tag{11}$$
#### 2.3.2 导数解释
在使用深度学习框架进行训练求解时,一般用梯度下降方法,所以我们把Gradient ascent转为Gradient
descent, 重写等式$(5)(6)$为:
$$\theta^* = \arg\min_\theta (-\overline {R}_\theta \tag{13}$$
$$\theta' = \theta - \eta * \bigtriangledown (-\overline {R}_\theta)) \tag{14}$$
根据上一节的推导,$ (-\bigtriangledown \overline {R}_\theta) $结果如下:
$$ -\bigtriangledown \overline {R}_\theta
= \frac{1}{N} \sum_{n=1}^N \sum_{t=1}^T R(\tau^n) { \bigtriangledown -\log P(a_t|s_t,\theta)} \tag{15}$$
根据等式(14), 我们的player的模型可以设计为:
<p align="center">
<img src="images/PG_1.svg" width="620" hspace='10'/> <br/>
图 1
</p>
用户的在一局游戏中的一次操作可以用元组$(s_t, a_t)$, 就是在状态$s_t$状态下做了动作$a_t$, 我们通过图(1)中的前向网络计算出来cross entropy cost为$−\log P(a_t|s_t,\theta)$, 恰好是等式(15)中我们需要微分的一项。
图1是我们需要的player模型,我用这个网络的前向计算可以预测任何状态下该做什么动作。但是怎么去训练学习这个网络呢?在等式(15)中还有一项$R(\tau^n)$, 我做反向梯度传播的时候要加上这一项,所以我们需要在图1基础上再加上$R(\tau^n)$, 如 图2 所示:
<p align="center">
<img src="images/PG_2.svg" width="620" hspace='10'/> <br/>
图 2
</p>
图2就是我们最终的网络结构。
#### 2.3.3 直观理解
对于等式(15),我只看游戏中的一步操作,也就是这一项: $R(\tau^n) { \bigtriangledown -\log P(a_t|s_t,\theta)}$, 我们可以简单的认为我们训练的目的是让 $R(\tau^n) {[ -\log P(a_t|s_t,\theta)]}$尽可能的小,也就是$R(\tau^n) \log P(a_t|s_t,\theta)$尽可能的大。
- 如果我们当前游戏局的奖励$R(\tau^n)$为正,那么我们希望当前操作的出现的概率$P(a_t|s_t,\theta)$尽可能大。
- 如果我们当前游戏局的奖励$R(\tau^n)$为负,那么我们希望当前操作的出现的概率$P(a_t|s_t,\theta)$尽可能小。
#### 2.3.4 一个问题
一人犯错,诛连九族。一人得道,鸡犬升天。如果一局游戏得到奖励,我们希望帮助获得奖励的每一次操作都被重视;否则,导致惩罚的操作都要被冷落一次。
是不是很有道理的样子?但是,如果有些游戏场景只有奖励,没有惩罚,怎么办?也就是所有的$R(\tau^n)$都为正。
针对不同的游戏场景,我们有不同的解决方案:
1. 每局游戏得分不一样:将每局的得分减去一个bias,结果就有正有负了。
2. 每局游戏得分一样:把完成一局的时间作为计分因素,并减去一个bias.
我们在第一章描述的游戏场景,需要用第二种 ,player每次到达终点都会收到1分的奖励,我们可以按完成任务所用的步数来定义奖励R.
更进一步,我们认为一局游戏中每步动作对结局的贡献是不同的,有聪明的动作,也有愚蠢的操作。直观的理解,一般是靠前的动作是愚蠢的,靠后的动作是聪明的。既然有了这个价值观,那么我们拿到1分的奖励,就不能平均分给每个动作了。
如图3所示,让所有动作按先后排队,从后往前衰减地给每个动作奖励,然后再每个动作的奖励再减去所有动作奖励的平均值:
<p align="center">
<img src="images/PG_3.svg" width="620" hspace='10'/> <br/>
图 3
</p>
## 3. 训练效果
demo运行训练效果如下,经过1000轮尝试,我们的player就学会了如何有效的完成任务了:
```
---------O epoch: 0; steps: 42
---------O epoch: 1; steps: 77
---------O epoch: 2; steps: 82
---------O epoch: 3; steps: 64
---------O epoch: 4; steps: 79
---------O epoch: 501; steps: 19
---------O epoch: 1001; steps: 9
---------O epoch: 1501; steps: 9
---------O epoch: 2001; steps: 11
---------O epoch: 2501; steps: 9
---------O epoch: 3001; steps: 9
---------O epoch: 3002; steps: 9
---------O epoch: 3003; steps: 9
---------O epoch: 3004; steps: 9
---------O epoch: 3005; steps: 9
---------O epoch: 3006; steps: 9
---------O epoch: 3007; steps: 9
---------O epoch: 3008; steps: 9
---------O epoch: 3009; steps: 9
---------O epoch: 3010; steps: 11
---------O epoch: 3011; steps: 9
---------O epoch: 3012; steps: 9
---------O epoch: 3013; steps: 9
---------O epoch: 3014; steps: 9
```
import numpy as np
import paddle.fluid as fluid
# reproducible
np.random.seed(1)
class PolicyGradient:
def __init__(
self,
n_actions,
n_features,
learning_rate=0.01,
reward_decay=0.95,
output_graph=False, ):
self.n_actions = n_actions
self.n_features = n_features
self.lr = learning_rate
self.gamma = reward_decay
self.ep_obs, self.ep_as, self.ep_rs = [], [], []
self.place = fluid.CPUPlace()
self.exe = fluid.Executor(self.place)
def build_net(self):
obs = fluid.layers.data(
name='obs', shape=[self.n_features], dtype='float32')
acts = fluid.layers.data(name='acts', shape=[1], dtype='int64')
vt = fluid.layers.data(name='vt', shape=[1], dtype='float32')
# fc1
fc1 = fluid.layers.fc(input=obs, size=10, act="tanh") # tanh activation
# fc2
all_act_prob = fluid.layers.fc(input=fc1,
size=self.n_actions,
act="softmax")
self.inferece_program = fluid.defaul_main_program().clone()
# to maximize total reward (log_p * R) is to minimize -(log_p * R)
neg_log_prob = fluid.layers.cross_entropy(
input=self.all_act_prob,
label=acts) # this is negative log of chosen action
neg_log_prob_weight = fluid.layers.elementwise_mul(x=neg_log_prob, y=vt)
loss = fluid.layers.reduce_mean(
neg_log_prob_weight) # reward guided loss
sgd_optimizer = fluid.optimizer.SGD(self.lr)
sgd_optimizer.minimize(loss)
self.exe.run(fluid.default_startup_program())
def choose_action(self, observation):
prob_weights = self.exe.run(self.inferece_program,
feed={"obs": observation[np.newaxis, :]},
fetch_list=[self.all_act_prob])
prob_weights = np.array(prob_weights[0])
# select action w.r.t the actions prob
action = np.random.choice(
range(prob_weights.shape[1]), p=prob_weights.ravel())
return action
def store_transition(self, s, a, r):
self.ep_obs.append(s)
self.ep_as.append(a)
self.ep_rs.append(r)
def learn(self):
# discount and normalize episode reward
discounted_ep_rs_norm = self._discount_and_norm_rewards()
tensor_obs = np.vstack(self.ep_obs).astype("float32")
tensor_as = np.array(self.ep_as).astype("int64")
tensor_as = tensor_as.reshape([tensor_as.shape[0], 1])
tensor_vt = discounted_ep_rs_norm.astype("float32")[:, np.newaxis]
# train on episode
self.exe.run(
fluid.default_main_program(),
feed={
"obs": tensor_obs, # shape=[None, n_obs]
"acts": tensor_as, # shape=[None, ]
"vt": tensor_vt # shape=[None, ]
})
self.ep_obs, self.ep_as, self.ep_rs = [], [], [] # empty episode data
return discounted_ep_rs_norm
def _discount_and_norm_rewards(self):
# discount episode rewards
discounted_ep_rs = np.zeros_like(self.ep_rs)
running_add = 0
for t in reversed(range(0, len(self.ep_rs))):
running_add = running_add * self.gamma + self.ep_rs[t]
discounted_ep_rs[t] = running_add
# normalize episode rewards
discounted_ep_rs -= np.mean(discounted_ep_rs)
discounted_ep_rs /= np.std(discounted_ep_rs)
return discounted_ep_rs
import time
import sys
import numpy as np
class Env():
def __init__(self, stage_len, interval):
self.stage_len = stage_len
self.end = self.stage_len - 1
self.position = 0
self.interval = interval
self.step = 0
self.epoch = -1
self.render = False
def reset(self):
self.end = self.stage_len - 1
self.position = 0
self.epoch += 1
self.step = 0
if self.render:
self.draw(True)
def status(self):
s = np.zeros([self.stage_len]).astype("float32")
s[self.position] = 1
return s
def move(self, action):
self.step += 1
reward = 0.0
done = False
if action == 0:
self.position = max(0, self.position - 1)
else:
self.position = min(self.end, self.position + 1)
if self.render:
self.draw()
if self.position == self.end:
reward = 1.0
done = True
return reward, done, self.status()
def draw(self, new_line=False):
if new_line:
print ""
else:
print "\r",
for i in range(self.stage_len):
if i == self.position:
sys.stdout.write("O")
else:
sys.stdout.write("-")
sys.stdout.write(" epoch: %d; steps: %d" % (self.epoch, self.step))
sys.stdout.flush()
time.sleep(self.interval)
<?xml version="1.0" encoding="utf-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xl="http://www.w3.org/1999/xlink" version="1.1" viewBox="162 59 594 567" width="594pt" height="567pt" xmlns:dc="http://purl.org/dc/elements/1.1/"><metadata> Produced by OmniGraffle 6.0.5 <dc:date>2017-12-01 08:39Z</dc:date></metadata><defs><font-face font-family="STIXGeneral" font-size="20" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-816.5001" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><marker orient="auto" overflow="visible" markerUnits="strokeWidth" id="FilledArrow_Marker" viewBox="-1 -3 6 6" markerWidth="6" markerHeight="6" color="#9a9a9a"><g><path d="M 3.7333333 0 L 0 -1.4 L 0 1.4 Z" fill="currentColor" stroke="currentColor" stroke-width="1"/></g></marker><font-face font-family="Helvetica Neue" font-size="20" panose-1="2 0 8 3 0 0 0 9 0 4" units-per-em="1000" underline-position="-100" underline-thickness="50" slope="0" x-height="517" cap-height="714" ascent="975.0061" descent="-216.99524" font-weight="bold"><font-face-src><font-face-name name="HelveticaNeue-Bold"/></font-face-src></font-face><font-face font-family="Helvetica Neue" font-size="21" panose-1="2 0 5 3 0 0 0 2 0 4" units-per-em="1000" underline-position="-100" underline-thickness="50" slope="0" x-height="517" cap-height="714" ascent="951.99585" descent="-212.99744" font-weight="500"><font-face-src><font-face-name name="HelveticaNeue"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="21" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-777.61913" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="19" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-859.4738" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="19" units-per-em="1000" underline-position="-227" underline-thickness="66" slope="0" x-height="450" cap-height="662" ascent="1055.00214" descent="-455.00092" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Regular"/></font-face-src></font-face></defs><g stroke="none" stroke-opacity="1" stroke-dasharray="none" fill="none" fill-opacity="1"><title>神经网络</title><g><title>Layer 1</title><circle cx="312.32677" cy="437.85433" r="32.500052" fill="#d5c0ff"/><circle cx="312.32677" cy="437.85433" r="32.500052" stroke="#695f7e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><circle cx="312.32677" cy="581.58662" r="32.500052" fill="#bfeaff"/><circle cx="312.32677" cy="581.58662" r="32.500052" stroke="#5f747e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(288.32677 566.58662)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="9.56" y="21" textLength="28.88">x_2</tspan></text><circle cx="312.32677" cy="232.15354" r="32.500052" fill="#ffd6d8"/><circle cx="312.32677" cy="232.15354" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(288.32677 217.15354)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="9.56" y="21" textLength="28.88">y_2</tspan></text><circle cx="207.32677" cy="232.15354" r="32.500052" fill="#ffd6d8"/><circle cx="207.32677" cy="232.15354" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(183.32677 217.15354)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="9.56" y="21" textLength="28.88">y_0</tspan></text><circle cx="207.32677" cy="581.58662" r="32.500052" fill="#bfeaff" fill-opacity=".8"/><circle cx="207.32677" cy="581.58662" r="32.500052" stroke="#5f747e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(183.32677 566.58662)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="9.56" y="21" textLength="28.88">x_1</tspan></text><circle cx="207.32677" cy="437.85433" r="32.500052" fill="#d5c0ff" fill-opacity=".8"/><circle cx="207.32677" cy="437.85433" r="32.500052" stroke="#695f7e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><circle cx="473.78347" cy="581.58662" r="32.500052" fill="#bfeaff"/><circle cx="473.78347" cy="581.58662" r="32.500052" stroke="#5f747e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(449.78347 566.58662)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="9.56" y="21" textLength="28.88">x_n</tspan></text><rect x="180.33071" y="313.22835" width="322.95276" height="44.811023" fill="#ffec8a"/><rect x="180.33071" y="313.22835" width="322.95276" height="44.811023" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(182.33071 320.63386)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="127.24638" y="21" textLength="64.46">Softmax</tspan></text><circle cx="467.02756" cy="232.06693" r="32.500052" fill="#ffd6d8"/><circle cx="467.02756" cy="232.06693" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(443.02756 217.06693)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="7.34" y="21" textLength="33.32">y_m</tspan></text><circle cx="473.32284" cy="437.85433" r="32.500052" fill="#d5c0ff"/><circle cx="473.32284" cy="437.85433" r="32.500052" stroke="#695f7e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><line x1="207.53024" y1="404.85494" x2="207.72086" y2="373.93907" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="207.32676" y1="548.5866" x2="207.32676" y2="486.75435" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="312.32676" y1="548.5866" x2="312.32676" y2="486.75435" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="473.68325" y1="548.58675" x2="473.49546" y2="486.75404" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="473.11448" y1="404.85497" x2="472.91929" y2="373.93906" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="312.1168" y1="404.85498" x2="311.92007" y2="373.93905" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="205.8189" y1="312.03937" x2="206.40393" y2="281.04489" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="311.82677" y1="312.57482" x2="312.02275" y2="281.05262" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="468.52756" y1="312.57482" x2="467.9385" y2="280.9585" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="292.85776" y1="554.9359" x2="236.17501" y2="477.34407" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="444.7321" y1="565.9156" x2="250.37242" y2="461.07324" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="336.94445" y1="559.60872" x2="436.84423" y2="470.4213" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="226.79579" y1="554.9359" x2="283.47854" y2="477.34407" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="236.36684" y1="565.89467" x2="430.29435" y2="461.105" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="449.13498" y1="559.64322" x2="348.85352" y2="470.36736" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><circle cx="310.83071" cy="103.452757" r="32.500052" fill="#c2ffc4"/><circle cx="310.83071" cy="103.452757" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(286.83071 88.452757)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="8.99" y="21" textLength="30.02">a_2</tspan></text><circle cx="205.83071" cy="103.452757" r="32.500052" fill="#c2ffc4"/><circle cx="205.83071" cy="103.452757" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(181.83071 88.452757)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="8.99" y="21" textLength="30.02">a_0</tspan></text><circle cx="465.5315" cy="103.36614" r="32.500052" fill="#c2ffc4"/><circle cx="465.5315" cy="103.36614" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(441.5315 88.36614)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="6.77" y="21" textLength="34.46">a_m</tspan></text><text transform="translate(371.41733 91.03937)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="bold" x="3.2771656" y="20" textLength="27.8">. . .</tspan></text><text transform="translate(372.5 219.59843)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="bold" x="3.2771656" y="20" textLength="27.8">. . .</tspan></text><text transform="translate(376.08662 421.10237)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="bold" x="3.2771656" y="20" textLength="27.8">. . .</tspan></text><text transform="translate(375.87796 569.08662)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="bold" x="3.2771656" y="20" textLength="27.8">. . .</tspan></text><text transform="translate(589.35434 572.08662)" fill="black"><tspan font-family="Helvetica Neue" font-size="21" font-weight="500" x=".1925" y="20" textLength="27.615">s_t</tspan></text><text transform="translate(597.35434 499.73623)" fill="#262626"><tspan font-family="STIXGeneral" font-size="21" font-style="italic" font-weight="500" fill="#262626" x="0" y="22" textLength="10.08">θ</tspan></text><text transform="translate(542 222.3504)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="57.152">y_t = P</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="57.152" y="20" textLength="6.327">(</tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="63.479" y="20" textLength="61.199">a_t | s_t</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="124.678" y="20" textLength="9.5">, </tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="134.178" y="20" textLength="9.12">θ</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="143.298" y="20" textLength="6.327">)</tspan></text><text transform="translate(501.97638 152.90158)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="131.024">-log(y_t) = -logP</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="131.024" y="20" textLength="6.327">(</tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="137.351" y="20" textLength="61.199">a_t | s_t</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="198.55" y="20" textLength="9.5">, </tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="208.05" y="20" textLength="9.12">θ</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="217.17" y="20" textLength="6.327">)</tspan></text><text transform="translate(271.4567 154.73622)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="171.171">CROSS ENTROPY =</tspan></text><path d="M 510.65355 515.65355 L 528.9232 507.65355 L 528.9232 512.30024 L 566.08468 512.30024 L 566.08468 507.65355 L 584.35434 515.65355 L 566.08468 523.65355 L 566.08468 519.00685 L 528.9232 519.00685 L 528.9232 523.65355 Z" fill="white"/><path d="M 510.65355 515.65355 L 528.9232 507.65355 L 528.9232 512.30024 L 566.08468 512.30024 L 566.08468 507.65355 L 584.35434 515.65355 L 566.08468 523.65355 L 566.08468 519.00685 L 528.9232 519.00685 L 528.9232 523.65355 Z" stroke="black" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><path d="M 510.4882 586.26772 L 528.75785 578.26772 L 528.75785 582.9144 L 565.91932 582.9144 L 565.91932 578.26772 L 584.18898 586.26772 L 565.91932 594.26772 L 565.91932 589.62103 L 528.75785 589.62103 L 528.75785 594.26772 Z" fill="white"/><path d="M 510.4882 586.26772 L 528.75785 578.26772 L 528.75785 582.9144 L 565.91932 582.9144 L 565.91932 578.26772 L 584.18898 586.26772 L 565.91932 594.26772 L 565.91932 589.62103 L 528.75785 589.62103 L 528.75785 594.26772 Z" stroke="black" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/></g></g></svg>
<?xml version="1.0" encoding="utf-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xl="http://www.w3.org/1999/xlink" version="1.1" viewBox="80 136 614 415" width="614pt" height="415pt" xmlns:dc="http://purl.org/dc/elements/1.1/"><metadata> Produced by OmniGraffle 6.0.5 <dc:date>2017-12-01 08:39Z</dc:date></metadata><defs><font-face font-family="Helvetica Neue" font-size="20" panose-1="2 11 6 4 2 2 2 2 2 4" units-per-em="1000" underline-position="-75" underline-thickness="50" slope="0" x-height="517" cap-height="714" ascent="975.0061" descent="-216.99524" font-weight="600"><font-face-src><font-face-name name="HelveticaNeue-Medium"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="20" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-816.5001" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><marker orient="auto" overflow="visible" markerUnits="strokeWidth" id="FilledArrow_Marker" viewBox="-1 -3 6 6" markerWidth="6" markerHeight="6" color="#9a9a9a"><g><path d="M 3.7333333 0 L 0 -1.4 L 0 1.4 Z" fill="currentColor" stroke="currentColor" stroke-width="1"/></g></marker><font-face font-family="STIXGeneral" font-size="19" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-859.4738" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="19" units-per-em="1000" underline-position="-227" underline-thickness="66" slope="0" x-height="450" cap-height="662" ascent="1055.00214" descent="-455.00092" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Regular"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="21" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-777.61913" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><font-face font-family="Helvetica Neue" font-size="18" panose-1="2 11 6 4 2 2 2 2 2 4" units-per-em="1000" underline-position="-75" underline-thickness="50" slope="0" x-height="517" cap-height="714" ascent="975.0061" descent="-216.99524" font-weight="600"><font-face-src><font-face-name name="HelveticaNeue-Medium"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="18" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-907.2223" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="12" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-1360.8335" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="14" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-1166.4287" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face></defs><g stroke="none" stroke-opacity="1" stroke-dasharray="none" fill="none" fill-opacity="1"><title>神经网络 2</title><g><title>Layer 1</title><circle cx="170.05906" cy="507.30315" r="32.500052" fill="#bfeaff"/><circle cx="170.05906" cy="507.30315" r="32.500052" stroke="#5f747e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(146.05906 494.80315)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="600" x="10.48" y="20" textLength="27.04">s_t</tspan></text><circle cx="169.59842" cy="238.01181" r="32.500052" fill="#ffd6d8"/><circle cx="169.59842" cy="238.01181" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(145.598425 225.51181)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="600" x="17.52" y="20" textLength="12.96">Y</tspan></text><circle cx="169.59842" cy="367.57087" r="32.500052" fill="#d5c0ff"/><circle cx="169.59842" cy="367.57087" r="32.500052" stroke="#695f7e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(145.598425 352.57087)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" fill="black" x="11.22" y="21" textLength="25.56">FC</tspan></text><line x1="169.95027" y1="474.30332" x2="169.75962" y2="416.47062" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="169.59843" y1="334.57085" x2="169.59843" y2="286.91183" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><circle cx="304.46063" cy="507.30315" r="32.500052" fill="#c2ffc4"/><circle cx="304.46063" cy="507.30315" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(280.46063 494.80315)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="600" x="10.11" y="20" textLength="27.78">a_t</tspan></text><text transform="translate(226.54725 151.48425)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="42.218">-logP</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="42.218" y="20" textLength="6.327">(</tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="48.545" y="20" textLength="61.199">a_t | s_t</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="109.744" y="20" textLength="9.5">, </tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="119.244" y="20" textLength="9.12">θ</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="128.364" y="20" textLength="6.327">)</tspan></text><text transform="translate(94.125985 292.63386)" fill="#262626"><tspan font-family="STIXGeneral" font-size="21" font-style="italic" font-weight="500" fill="#262626" x="0" y="22" textLength="67.683">Softmax</tspan></text><circle cx="430.8189" cy="507.30315" r="32.500052" fill="#c2ffc4"/><circle cx="430.8189" cy="507.30315" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(406.8189 493.80315)" fill="black"><tspan font-family="Helvetica Neue" font-size="18" font-weight="600" x=".672" y="19" textLength="17.676">R(</tspan><tspan font-family="STIXGeneral" font-size="18" font-style="italic" font-weight="500" fill="#262626" x="18.348" y="19" textLength="14.976">τ^</tspan><tspan font-family="STIXGeneral" font-size="18" font-style="italic" font-weight="500" fill="#262626" x="33.324" y="19" textLength="9">n</tspan><tspan font-family="Helvetica Neue" font-size="18" font-weight="600" x="42.324" y="19" textLength="5.004">)</tspan></text><circle cx="300.20866" cy="238.01181" r="32.500052" fill="#c2ffc4"/><circle cx="300.20866" cy="238.01181" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(276.20866 220.01181)" fill="black"><tspan font-family="STIXGeneral" font-size="12" font-style="italic" font-weight="500" x="9.996" y="13" textLength="31.008">Cross </tspan><tspan font-family="STIXGeneral" font-size="12" font-style="italic" font-weight="500" x="4.644" y="31" textLength="38.712">Entropy</tspan></text><line x1="303.93964" y1="474.30723" x2="300.98067" y2="286.90576" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="202.59844" y1="238.01183" x2="251.30865" y2="238.01183" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><circle cx="430.8189" cy="238.01181" r="32.500052" fill="#c2ffc4"/><circle cx="430.8189" cy="238.01181" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(406.8189 223.01181)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="7.89" y="21" textLength="32.22">Mul</tspan></text><line x1="333.20868" y1="238.01182" x2="381.91889" y2="238.01182" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="430.81892" y1="474.30314" x2="430.81892" y2="286.91183" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><text transform="translate(488.3937 223.51181)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="6.327">-</tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="6.327" y="20" textLength="11.609">R</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="17.936" y="20" textLength="6.327">(</tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="24.263" y="20" textLength="15.808">τ^</tspan><tspan font-family="STIXGeneral" font-size="14" font-style="italic" font-weight="500" fill="#262626" x="40.071" y="20" textLength="7">n</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="47.071" y="20" textLength="6.327">)</tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="53.398" y="20" textLength="35.891">logP</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="89.289" y="20" textLength="6.327">(</tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="95.616" y="20" textLength="61.199">a_t | s_t</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="156.815" y="20" textLength="9.5">, </tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="166.315" y="20" textLength="9.12">θ</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="175.435" y="20" textLength="6.327">)</tspan></text><text transform="translate(131.72441 424.19292)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="9.12">θ</tspan></text></g></g></svg>
<?xml version="1.0" encoding="utf-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xl="http://www.w3.org/1999/xlink" version="1.1" viewBox="51 454 689 160" width="689pt" height="160pt" xmlns:dc="http://purl.org/dc/elements/1.1/"><metadata> Produced by OmniGraffle 6.0.5 <dc:date>2017-12-01 09:42Z</dc:date></metadata><defs><font-face font-family="Helvetica Neue" font-size="20" panose-1="2 11 6 4 2 2 2 2 2 4" units-per-em="1000" underline-position="-75" underline-thickness="50" slope="0" x-height="517" cap-height="714" ascent="975.0061" descent="-216.99524" font-weight="600"><font-face-src><font-face-name name="HelveticaNeue-Medium"/></font-face-src></font-face><font-face font-family="Helvetica Neue" font-size="16" panose-1="2 11 6 4 2 2 2 2 2 4" units-per-em="1000" underline-position="-75" underline-thickness="50" slope="0" x-height="517" cap-height="714" ascent="975.0061" descent="-216.99524" font-weight="600"><font-face-src><font-face-name name="HelveticaNeue-Medium"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="19" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-859.4738" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><marker orient="auto" overflow="visible" markerUnits="strokeWidth" id="FilledArrow_Marker" viewBox="-1 -3 6 6" markerWidth="6" markerHeight="6" color="#9a9a9a"><g><path d="M 3.7333333 0 L 0 -1.4 L 0 1.4 Z" fill="currentColor" stroke="currentColor" stroke-width="1"/></g></marker><font-face font-family="Helvetica" font-size="19" units-per-em="1000" underline-position="-75.683594" underline-thickness="49.316406" slope="0" x-height="522.94922" cap-height="717.28516" ascent="770.01953" descent="-229.98047" font-weight="500"><font-face-src><font-face-name name="Helvetica"/></font-face-src></font-face></defs><g stroke="none" stroke-opacity="1" stroke-dasharray="none" fill="none" fill-opacity="1"><title>神经网络 3</title><g><title>Layer 1</title><circle cx="695.8071" cy="498.79922" r="32.500052" fill="#ffd6d8"/><circle cx="695.8071" cy="498.79922" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(671.8071 486.29922)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="600" x="16.96" y="20" textLength="14.08">R</tspan></text><circle cx="232.22048" cy="498.79922" r="32.500052" fill="#c2ffc4"/><circle cx="232.22048" cy="498.79922" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(208.22048 489.29922)" fill="black"><tspan font-family="Helvetica Neue" font-size="16" font-weight="600" x="11.104" y="16" textLength="25.792">a_2</tspan></text><text transform="translate(581.49607 471.46063)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="60.325">= 0.9 * </tspan></text><line x1="652" y1="498" x2="569.0881" y2="498.50272" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="487.18897" y1="498.79923" x2="420.7819" y2="498.79923" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><text transform="translate(280.70866 487.29922)" fill="#262626"><tspan font-family="Helvetica" font-size="19" font-weight="500" fill="#262626" x="0" y="19" textLength="34.836426"></tspan></text><circle cx="94.862205" cy="498.79922" r="32.500052" fill="#c2ffc4"/><circle cx="94.862205" cy="498.79922" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(70.862205 489.29922)" fill="black"><tspan font-family="Helvetica Neue" font-size="16" font-weight="600" x="11.104" y="16" textLength="25.792">a_1</tspan></text><circle cx="371.8819" cy="498.79922" r="32.500052" fill="#c2ffc4"/><circle cx="371.8819" cy="498.79922" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(347.8819 489.29922)" fill="black"><tspan font-family="Helvetica Neue" font-size="16" font-weight="600" x=".88" y="16" textLength="46.24">a_(t-1)</tspan></text><circle cx="520.18898" cy="498.79922" r="32.500052" fill="#c2ffc4"/><circle cx="520.18898" cy="498.79922" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(496.18898 489.29922)" fill="black"><tspan font-family="Helvetica Neue" font-size="16" font-weight="600" x="12.888" y="16" textLength="22.224">a_t</tspan></text><line x1="199.22046" y1="498.79923" x2="143.76222" y2="498.79923" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><text transform="translate(414.6063 469.12993)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="63.593">= 0.9^2 </tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="63.593" y="20" textLength="9.5">*</tspan></text><text transform="translate(129.30709 468.79134)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="59.375">= 0.9^t </tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="59.375" y="20" textLength="9.5">*</tspan></text><text transform="translate(194.84252 576.43308)" fill="#262626"><tspan font-family="Helvetica" font-size="19" font-weight="500" fill="#262626" x="0" y="19" textLength="218.0918"> -= mean(a_1, a_2 … a_t)</tspan></text></g></g></svg>
from brain import PolicyGradient
from env import Env
import numpy as np
n_actions = 2
interval = 0.01
stage_len = 10
epoches = 10000
if __name__ == "__main__":
brain = PolicyGradient(n_actions, stage_len)
e = Env(stage_len, interval)
brain.build_net()
done = False
for epoch in range(epoches):
if (epoch % 500 == 1) or epoch < 5 or epoch > 3000:
e.render = True
else:
e.render = False
e.reset()
while not done:
s = e.status()
action = brain.choose_action(s)
r, done, _ = e.move(action)
brain.store_transition(s, action, r)
done = False
brain.learn()
The minimum PaddlePaddle version needed for the code sample in this directory is the lastest develop branch. If you are on a version of PaddlePaddle earlier than this, [please update your installation](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
## Deep Automatic Speech Recognition
### Introduction
TBD
### Installation
#### Kaldi
The decoder depends on [kaldi](https://github.com/kaldi-asr/kaldi), install it by flowing its instructions. Then
```shell
export KALDI_ROOT=<absolute path to kaldi>
```
#### Decoder
```shell
git clone https://github.com/PaddlePaddle/models.git
cd models/fluid/DeepASR/decoder
sh setup.sh
```
### Data reprocessing
TBD
### Training
TBD
### Inference & Decoding
TBD
### Question and Contribution
TBD
运行本目录下的程序示例需要使用 PaddlePaddle v0.14及以上版本。如果您的 PaddlePaddle 安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新 PaddlePaddle 安装版本。
---
DeepASR (Deep Automatic Speech Recognition) 是一个基于PaddlePaddle FLuid与[Kaldi](http://www.kaldi-asr.org)的语音识别系统。其利用Fluid框架完成语音识别中声学模型的配置和训练,并集成 Kaldi 的解码器。旨在方便已对 Kaldi 的较为熟悉的用户实现中声学模型的快速、大规模训练,并利用kaldi完成复杂的语音数据预处理和最终的解码过程。
### 目录
- [模型概览](#model-overview)
- [安装](#installation)
- [数据预处理](#data-reprocessing)
- [模型训练](#training)
- [训练过程中的时间分析](#perf-profiling)
- [预测和解码](#infer-decoding)
- [评估错误率](#scoring-error-rate)
- [Aishell 实例](#aishell-example)
- [欢迎贡献更多的实例](#how-to-contrib)
### 模型概览
DeepASR的声学模型是一个单卷积层加多层层叠LSTMP 的结构,利用卷积来进行初步的特征提取,并用多层的LSTMP来对时序关系进行建模,所用到的损失函数是交叉熵。[LSTMP](https://arxiv.org/abs/1402.1128)(LSTM with recurrent projection layer)是传统 LSTM 的拓展,在 LSTM 的基础上增加了一个映射层,将隐含层映射到较低的维度并输入下一个时间步,这种结构在大为减小 LSTM 的参数规模和计算复杂度的同时还提升了 LSTM 的性能表现。
<p align="center">
<img src="images/lstmp.png" height=240 width=480 hspace='10'/> <br />
图1 LSTMP 的拓扑结构
</p>
### 安装
#### kaldi的安装与设置
DeepASR解码过程中所用的解码器依赖于[Kaldi的安装](https://github.com/kaldi-asr/kaldi),如环境中无Kaldi, 请`git clone`其源代码,并按给定的命令安装好kaldi,最后设置环境变量`KALDI_ROOT`
```shell
export KALDI_ROOT=<kaldi的安装路径>
```
#### 解码器的安装
进入解码器源码所在的目录
```shell
cd models/fluid/DeepASR/decoder
```
运行安装脚本
```shell
sh setup.sh
```
编译过程完成即成功地安转了解码器。
### 数据预处理
参考[Kaldi的数据准备流程](http://kaldi-asr.org/doc/data_prep.html)完成音频数据的特征提取和标签对齐
### 声学模型的训练
可以选择在CPU或GPU模式下进行声学模型的训练,例如在GPU模式下的训练
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u train.py \
--train_feature_lst train_feature.lst \
--train_label_lst train_label.lst \
--val_feature_lst val_feature.lst \
--val_label_lst val_label.lst \
--mean_var global_mean_var \
--parallel
```
其中`train_feature.lst``train_label.lst`分别是训练数据集的特征列表文件和标注列表文件,类似的,`val_feature.lst``val_label.lst`对应的则是验证集的列表文件。实际训练过程中要正确指定建模单元大小、学习率等重要参数。关于这些参数的说明,请运行
```shell
python train.py --help
```
获取更多信息。
### 训练过程中的时间分析
利用Fluid提供的性能分析工具profiler,可对训练过程进行性能分析,获取网络中operator级别的执行时间
```shell
CUDA_VISIBLE_DEVICES=0 python -u tools/profile.py \
--train_feature_lst train_feature.lst \
--train_label_lst train_label.lst \
--val_feature_lst val_feature.lst \
--val_label_lst val_label.lst \
--mean_var global_mean_var
```
### 预测和解码
在充分训练好声学模型之后,利用训练过程中保存下来的模型checkpoint,可对输入的音频数据进行解码输出,得到声音到文字的识别结果
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u infer_by_ckpt.py \
--batch_size 96 \
--checkpoint deep_asr.pass_1.checkpoint \
--infer_feature_lst test_feature.lst \
--infer_label_lst test_label.lst \
--mean_var global_mean_var \
--parallel
```
### 评估错误率
对语音识别系统的评价常用的指标有词错误率(Word Error Rate, WER)和字错误率(Character Error Rate, CER), 在DeepASR中也实现了相关的度量工具,其运行方式为
```
python score_error_rate.py --error_rate_type cer --ref ref.txt --hyp decoding.txt
```
参数`error_rate_type`表示测量错误率的类型,即 WER 或 CER;`ref.txt``decoding.txt` 分别表示参考文本和实际解码出的文本,它们有着同样的格式:
```
key1 text1
key2 text2
key3 text3
...
```
### Aishell 实例
本节以[Aishell数据集](http://www.aishelltech.com/kysjcp)为例,展示如何完成从数据预处理到解码输出。Aishell是由北京希尔贝克公司所开放的中文普通话语音数据集,时长178小时,包含了400名来自不同口音区域录制者的语音,原始数据可由[openslr](http://www.openslr.org/33)获取。为简化流程,这里提供了已完成预处理的数据集供下载:
```
cd examples/aishell
sh prepare_data.sh
```
其中包括了声学模型的训练数据以及解码过程中所用到的辅助文件等。下载数据完成后,在开始训练之前可对训练过程进行分析
```
sh profile.sh
```
执行训练
```
sh train.sh
```
默认是用4卡GPU进行训练,在实际过程中可根据可用GPU的数目和显存大小对`batch_size`、学习率等参数进行动态调整。训练过程中典型的损失函数和精度的变化趋势如图2所示
<p align="center">
<img src="images/learning_curve.png" height=480 width=640 hspace='10'/> <br />
图2 在Aishell数据集上训练声学模型的学习曲线
</p>
完成模型训练后,即可执行预测识别测试集语音中的文字:
```
sh infer_by_ckpt.sh
```
其中包括了声学模型的预测和解码器的解码输出两个重要的过程。以下是解码输出的样例:
```
...
BAC009S0764W0239 十一 五 期间 我 国 累计 境外 投资 七千亿 美元
BAC009S0765W0140 在 了解 送 方 的 资产 情况 与 需求 之后
BAC009S0915W0291 这 对 苹果 来说 不 是 件 容易 的 事 儿
BAC009S0769W0159 今年 土地 收入 预计 近 四万亿 元
BAC009S0907W0451 由 浦东 商店 作为 掩护
BAC009S0768W0128 土地 交易 可能 随着 供应 淡季 的 到来 而 降温
...
```
每行对应一个输出,均以音频样本的关键字开头,随后是按词分隔的解码出的中文文本。解码完成后运行脚本评估字错误率(CER)
```
sh score_cer.sh
```
其输出类似于如下所示
```
Error rate[cer] = 0.101971 (10683/104765),
total 7176 sentences in hyp, 0 not presented in ref.
```
利用经过20轮左右训练的声学模型,可以在Aishell的测试集上得到CER约10%的识别结果。
### 欢迎贡献更多的实例
DeepASR目前只开放了Aishell实例,我们欢迎用户在更多的数据集上测试完整的训练流程并贡献到这个项目中。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
16.2845556399 11.6891798673
17.21509949 12.3788567902
18.1143704548 14.9912618017
19.2335963752 18.5419556172
19.9266772451 21.2768220522
19.8245737202 21.2347210705
19.5432940972 20.2784036567
19.4631271754 20.2934452329
19.3929919324 20.457971868
19.2924788362 20.3626439234
18.9207244502 19.9196569759
18.7202605641 19.5920276899
18.4844279398 19.2068349019
18.2670948624 18.8716893824
18.0929628855 18.5439666541
17.8428896026 18.0255891747
17.6646850635 17.473764296
17.4955705896 16.8966859471
17.3706720293 16.4294027467
17.2530867792 16.0514717623
17.1304341172 15.7234699057
17.0038353287 15.4344471514
16.902550309 15.1603287337
16.8375590047 14.9304337826
16.816287853 14.9119310513
16.828838265 15.0930023024
16.8602209498 15.3771992423
16.9101763812 15.6897991789
16.9466065143 15.9364556489
16.9486061956 16.0699417826
16.9041374104 16.0796970272
16.8410093699 16.0111444599
16.7045718836 15.7991985601
16.51128489 15.5208920129
16.3253910608 15.2603181921
16.1297317333 14.9499965958
15.903428372 14.5958280409
15.6131718105 14.2709618
15.1395035533 13.9993939893
14.4298229999 13.3841189151
0.0034970565424 0.246184766149
0.00501284154705 0.238484972472
0.00605942680019 0.269064381708
0.00687266156243 0.319479238011
0.00734065019253 0.371947383205
0.00718807218417 0.384426479694
0.00652195540212 0.384676838281
0.00660416525951 0.395543910317
0.00680202057642 0.400803979681
0.00659144183007 0.393228973031
0.00605294530423 0.385021118038
0.00590452969394 0.361763039625
0.00612315374687 0.346777773373
0.00582354093973 0.335802403976
0.00574556002554 0.320733728218
0.00612254485891 0.310153103033
0.00626733043219 0.299854747445
0.00567398408041 0.293353685493
0.00519236700706 0.287668810947
0.00529581474367 0.281479660772
0.00479019484082 0.27451415777
0.00486381039428 0.266294391154
0.00491126372868 0.258105116126
0.00452105305011 0.252926328298
0.00531483334271 0.250910887373
0.00546572110469 0.253302256977
0.00479544857908 0.258484183394
0.00422106426297 0.264582900173
0.00401824135188 0.268467945623
0.0041705465252 0.269699480291
0.00405239564143 0.270406162975
0.0040059737566 0.270407601782
0.00406426729317 0.267951582656
0.00416613791013 0.264543833042
0.00427847607653 0.26247798891
0.00428050903034 0.259635263243
0.00454842971786 0.255829377617
0.00393747552387 0.253802307025
0.00374143688909 0.251011478787
0.00335475310258 0.236543650856
0.000373194755312 0.0419494800709
0.000230909648678 0.0394102370205
0.000150840015851 0.0414956922398
8.44401840771e-05 0.0460502231327
-6.24759314572e-06 0.0528049937739
-8.82957758148e-05 0.055711244886
1.16795791952e-05 0.0563188428833
-1.68716267856e-05 0.0575232763711
-0.000112625308645 0.057979929947
-0.000122619090002 0.0564126233493
1.73569637319e-05 0.05522573909
6.49872782342e-05 0.0507353361334
4.17746389178e-05 0.0479568131253
5.13884475653e-05 0.0461253238047
1.8860115143e-05 0.0436860476919
-5.64317701105e-05 0.042516381059
-0.000136859948115 0.0413574820205
-7.00847019726e-05 0.0409516370727
-5.39392223336e-05 0.040441504085
-9.24897162815e-05 0.0397800398173
4.7104970622e-05 0.039046286243
6.24805896165e-06 0.0380185986602
-2.35272813418e-05 0.036851063786
5.88344154127e-05 0.0361640489242
-8.39162076993e-05 0.0357639427311
-0.000108702805776 0.0358774639538
3.22013961834e-06 0.0363644530435
9.43501518394e-05 0.0370309934774
0.000134406229423 0.0374972993343
3.84007008533e-05 0.037676222515
3.05989328157e-05 0.0379111939182
9.52201629091e-05 0.0380927209106
0.000102126083729 0.0379925358499
6.98628072264e-05 0.0377276252241
4.55782256339e-05 0.0375165468654
4.76370987786e-05 0.0371482526345
-2.24128832709e-05 0.0366810742947
0.000125621306953 0.036628355271
0.000134568666093 0.0364860461759
0.000159858844464 0.0345583593149
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import unittest
import numpy as np
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.augmentor.trans_delay as trans_delay
class TestTransMeanVarianceNorm(unittest.TestCase):
"""unit test for TransMeanVarianceNorm
"""
def setUp(self):
self._file_path = "./data_utils/augmentor/tests/data/" \
"global_mean_var_search26kHr"
def test(self):
feature = np.zeros((2, 120), dtype="float32")
feature.fill(1)
trans = trans_mean_variance_norm.TransMeanVarianceNorm(self._file_path)
(feature1, label1, name) = trans.perform_trans((feature, None, None))
(mean, var) = trans.get_mean_var()
feature_flat1 = feature1.flatten()
feature_flat = feature.flatten()
one = np.ones((1), dtype="float32")
for idx, val in enumerate(feature_flat1):
cur_idx = idx % 120
self.assertAlmostEqual(val, (one[0] - mean[cur_idx]) * var[cur_idx])
class TestTransAddDelta(unittest.TestCase):
"""unit test TestTransAddDelta
"""
def test_regress(self):
"""test regress
"""
feature = np.zeros((14, 120), dtype="float32")
feature[0:5, 0:40].fill(1)
feature[0 + 5, 0:40].fill(1)
feature[1 + 5, 0:40].fill(2)
feature[2 + 5, 0:40].fill(3)
feature[3 + 5, 0:40].fill(4)
feature[8:14, 0:40].fill(4)
trans = trans_add_delta.TransAddDelta()
feature = feature.reshape((14 * 120))
trans._regress(feature, 5 * 120, feature, 5 * 120 + 40, 40, 4, 120)
trans._regress(feature, 5 * 120 + 40, feature, 5 * 120 + 80, 40, 4, 120)
feature = feature.reshape((14, 120))
tmp_feature = feature[5:5 + 4, :]
self.assertAlmostEqual(1.0, tmp_feature[0][0])
self.assertAlmostEqual(0.24, tmp_feature[0][119])
self.assertAlmostEqual(2.0, tmp_feature[1][0])
self.assertAlmostEqual(0.13, tmp_feature[1][119])
self.assertAlmostEqual(3.0, tmp_feature[2][0])
self.assertAlmostEqual(-0.13, tmp_feature[2][119])
self.assertAlmostEqual(4.0, tmp_feature[3][0])
self.assertAlmostEqual(-0.24, tmp_feature[3][119])
def test_perform(self):
"""test perform
"""
feature = np.zeros((4, 40), dtype="float32")
feature[0, 0:40].fill(1)
feature[1, 0:40].fill(2)
feature[2, 0:40].fill(3)
feature[3, 0:40].fill(4)
trans = trans_add_delta.TransAddDelta()
(feature, label, name) = trans.perform_trans((feature, None, None))
self.assertAlmostEqual(feature.shape[0], 4)
self.assertAlmostEqual(feature.shape[1], 120)
self.assertAlmostEqual(1.0, feature[0][0])
self.assertAlmostEqual(0.24, feature[0][119])
self.assertAlmostEqual(2.0, feature[1][0])
self.assertAlmostEqual(0.13, feature[1][119])
self.assertAlmostEqual(3.0, feature[2][0])
self.assertAlmostEqual(-0.13, feature[2][119])
self.assertAlmostEqual(4.0, feature[3][0])
self.assertAlmostEqual(-0.24, feature[3][119])
class TestTransSplict(unittest.TestCase):
"""unit test Test TransSplict
"""
def test_perfrom(self):
feature = np.zeros((8, 10), dtype="float32")
for i in xrange(feature.shape[0]):
feature[i, :].fill(i)
trans = trans_splice.TransSplice()
(feature, label, name) = trans.perform_trans((feature, None, None))
self.assertEqual(feature.shape[1], 110)
for i in xrange(8):
nzero_num = 5 - i
cur_val = 0.0
if nzero_num < 0:
cur_val = i - 5 - 1
for j in xrange(11):
if j <= nzero_num:
for k in xrange(10):
self.assertAlmostEqual(feature[i][j * 10 + k], cur_val)
else:
if cur_val < 7:
cur_val += 1.0
for k in xrange(10):
self.assertAlmostEqual(feature[i][j * 10 + k], cur_val)
class TestTransDelay(unittest.TestCase):
"""unittest TransDelay
"""
def test_perform(self):
label = np.zeros((10, 1), dtype="int64")
for i in xrange(10):
label[i][0] = i
trans = trans_delay.TransDelay(5)
(_, label, _) = trans.perform_trans((None, label, None))
for i in xrange(5):
self.assertAlmostEqual(label[i + 5][0], i)
for i in xrange(5):
self.assertAlmostEqual(label[i][0], 0)
if __name__ == '__main__':
unittest.main()
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import math
import copy
class TransAddDelta(object):
""" add delta of feature data
trans feature for shape(a, b) to shape(a, b * 3)
Attributes:
_norder(int):
_window(int):
"""
def __init__(self, norder=2, nwindow=2):
""" init construction
Args:
norder: default 2
nwindow: default 2
"""
self._norder = norder
self._nwindow = nwindow
def perform_trans(self, sample):
""" add delta for feature
trans feature shape from (a,b) to (a, b * 3)
Args:
sample(object,tuple): contain feature numpy and label numpy
Returns:
(feature, label, name)
"""
(feature, label, name) = sample
frame_dim = feature.shape[1]
d_frame_dim = frame_dim * 3
head_filled = 5
tail_filled = 5
mat = np.zeros(
(feature.shape[0] + head_filled + tail_filled, d_frame_dim),
dtype="float32")
#copy first frame
for i in xrange(head_filled):
np.copyto(mat[i, 0:frame_dim], feature[0, :])
np.copyto(mat[head_filled:head_filled + feature.shape[0], 0:frame_dim],
feature[:, :])
# copy last frame
for i in xrange(head_filled + feature.shape[0], mat.shape[0], 1):
np.copyto(mat[i, 0:frame_dim], feature[feature.shape[0] - 1, :])
nframe = feature.shape[0]
start = head_filled
tmp_shape = mat.shape
mat = mat.reshape((tmp_shape[0] * tmp_shape[1]))
self._regress(mat, start * d_frame_dim, mat,
start * d_frame_dim + frame_dim, frame_dim, nframe,
d_frame_dim)
self._regress(mat, start * d_frame_dim + frame_dim, mat,
start * d_frame_dim + 2 * frame_dim, frame_dim, nframe,
d_frame_dim)
mat.shape = tmp_shape
return (mat[head_filled:mat.shape[0] - tail_filled, :], label, name)
def _regress(self, data_in, start_in, data_out, start_out, size, n, step):
""" regress
Args:
data_in: in data
start_in: start index of data_in
data_out: out data
start_out: start index of data_out
size: frame dimentional
n: frame num
step: 3 * (frame num)
Returns:
None
"""
sigma_t2 = 0.0
delta_window = self._nwindow
for t in xrange(1, delta_window + 1):
sigma_t2 += t * t
sigma_t2 *= 2.0
for i in xrange(n):
fp1 = start_in
fp2 = start_out
for j in xrange(size):
back = fp1
forw = fp1
sum = 0.0
for t in xrange(1, delta_window + 1):
back -= step
forw += step
sum += t * (data_in[forw] - data_in[back])
data_out[fp2] = sum / sigma_t2
fp1 += 1
fp2 += 1
start_in += step
start_out += step
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import math
class TransDelay(object):
""" Delay label, and copy first label value in the front.
Attributes:
_delay_time : the delay frame num of label
"""
def __init__(self, delay_time):
"""init construction
Args:
delay_time : the delay frame num of label
"""
self._delay_time = delay_time
def perform_trans(self, sample):
"""
Args:
sample(object):input sample, contain feature numpy and label numpy, sample name list
Returns:
(feature, label, name)
"""
(feature, label, name) = sample
shape = label.shape
assert len(shape) == 2
label[self._delay_time:shape[0]] = label[0:shape[0] - self._delay_time]
for i in xrange(self._delay_time):
label[i][0] = label[self._delay_time][0]
return (feature, label, name)
此差异已折叠。
ThreadPool
build
post_latgen_faster_mapped.so
pybind11
此差异已折叠。
此差异已折叠。
set -e
if [ ! -d pybind11 ]; then
git clone https://github.com/pybind/pybind11.git
fi
if [ ! -d ThreadPool ]; then
git clone https://github.com/progschj/ThreadPool.git
echo -e "\n"
fi
python setup.py build_ext -i
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册