Commit 73d5f419 authored by jerrywgz

Merge branch 'master' of https://github.com/PaddlePaddle/AutoDL into add_more_config_for_lrc

*.DS_Store
*.vs
build/
build_doc/
*.user
.vscode
.idea
.project
.cproject
.pydevproject
.settings/
CMakeSettings.json
Makefile
.test_env/
third_party/
*~
bazel-*
build_*
# clion workspace
cmake-build-*
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
files: (?!.*paddle)^.*$
- id: end-of-file-fixer
files: \.md$
- id: trailing-whitespace
files: \.md$
- repo: https://github.com/Lucas-C/pre-commit-hooks
sha: v1.0.1
hooks:
- id: forbid-crlf
files: \.md$
- id: remove-crlf
files: \.md$
- id: forbid-tabs
files: \.md$
- id: remove-tabs
files: \.md$
[style]
based_on_style = pep8
column_limit = 80
#!/bin/bash
function abort(){
echo "Your commit does not fit PaddlePaddle code style" 1>&2
echo "Please use pre-commit scripts to auto-format your code" 1>&2
exit 1
}
trap 'abort' 0
set -e
cd `dirname $0`
cd ..
export PATH=/usr/bin:$PATH
pre-commit install
if ! pre-commit run -a ; then
ls -lh
git diff --exit-code
exit 1
fi
trap : 0
# Introduction to AutoDL Design
## Contents
- [Installation](#installation)
- [Introduction](#introduction)
- [Data Preparation](#data-preparation)
- [Model Training](#model-training)
## Installation
Running the sample code in this directory requires PaddlePaddle Fluid v1.3.0 or above. If the PaddlePaddle version in your environment is lower than this, please update it following the instructions in the installation documentation.
* Install Python 2.7
* Training depends on the [PARL](https://github.com/PaddlePaddle/PARL) framework and the [absl-py](https://github.com/abseil/abseil-py/tree/master/absl) library, which can be installed with the following commands:
```
pip install parl
pip install absl-py
```
## Introduction
[AutoDL](http://www.paddlepaddle.org/paddle/ModelAutoDL) is an efficient method for automatically searching for the best network structure, obtaining customized, high-quality models through reinforcement learning over repeated training. The system consists of two parts: an encoder of the network structure and an evaluator of the network structure. The encoder typically encodes the network structure with an RNN; the evaluator trains and evaluates the encoded architecture, obtains metrics such as accuracy and model size, and feeds them back to the encoder; the encoder then updates itself and encodes again, and so on. After a number of iterations, a well-designed model is obtained. The AutoDL Design open-sourced here is an implementation of the AutoDL technique based on the PaddlePaddle framework. Section 2 describes the steps for using AutoDL Design; Section 3 describes the implementation principle of AutoDL Design with examples.
## Data Preparation
* Clone [PaddlePaddle/AutoDL](https://github.com/PaddlePaddle/AutoDL.git) to the test machine and enter the AutoDL Design directory.
* Download the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz) training data, place it under the AutoDL Design/cifar directory and extract it, then run `dataset_maker.py` with the commands below to generate a small pickle training set with 100 images for each of the 10 classes (a loading sketch follows the commands).
```
tar zxf cifar-10-python.tar.gz
python dataset_maker.py
```
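Each generated `data_batch_i.pkl` is a numpy object array of `(pixels, label)` pairs dumped by `dataset_maker.py`. The following is only a small inspection sketch, assuming the default output directory `./pickle-cifar-10` and the default `chunk_size` of 100:
```
import cPickle as pickle

# Load one of the ten generated chunks (100 shuffled samples per chunk).
chunk = pickle.load(open("./pickle-cifar-10/data_batch_0.pkl", "rb"))
print(chunk.shape)          # (100, 2): each row is (raw pixel vector, class label)
sample, label = chunk[0]
print(sample.shape, label)  # (3072,) uint8 pixels and an integer label in [0, 9]
```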
## Model Training
During training, AutoDL Design uses the Agent's policy network to generate the Tokens and the adjacency matrix Adj for each round; the Trainer then builds and trains a CNN from these Tokens and Adj. After training for 20 epochs, the resulting accuracy is returned to the Agent as the Reward, and the Agent updates its policy accordingly. Through continuous iteration, a well-performing deep neural network can eventually be searched automatically (a small sketch of how the adjacency vector is unpacked follows the figure).
![image](./img/cnn_net.png)
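With the default `num_nodes=10` and `num_tokens=10`, the Agent emits 20 Tokens (an operation index for each node of the normal and reduce cells) and a 90-element adjacency vector; `inception.py` splits both in half and unpacks each adjacency half into an upper-triangular connection matrix. A small sketch of that unpacking, mirroring the `slice` helper in `inception.net()`:
```
import numpy as np

num_nodes = 10
# One sampled adjacency vector: num_nodes * (num_nodes - 1) = 90 entries in {0, 1}.
adjvec = np.random.randint(0, 2, size=num_nodes * (num_nodes - 1))
normal_ad, reduce_ad = np.split(adjvec, 2)

def unpack(vec):
    """Rebuild the upper-triangular adjacency matrix used to wire one cell."""
    mat = np.zeros([num_nodes, num_nodes])
    pos = lambda x: x * (x - 1) // 2
    for i in range(1, num_nodes):
        mat[0:i, i] = vec[pos(i):pos(i + 1)]
    return mat

print(unpack(normal_ad))  # column i lists which earlier nodes feed node i
```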
The following two tests are provided.
### Convergence test on the number of generated tokens
Since each full CNN training run takes a long time, the number of generated tokens is used here as a simulated Reward in place of CNN training in order to verify the correctness of the overall Agent framework; the automatically searched network should then generate more and more tokens. The total length of the token vector is set to 20. Run the following commands:
```
export FLAGS_fraction_of_gpu_memory_to_use=0.98
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fast_eager_deletion_mode=1
CUDA_VISIBLE_DEVICES=0 python -u simple_main.py
```
Expected result:
`average rewards` in the log increases and gradually converges toward 20
```
Simple run target is 20
mid=0, average rewards=2.500
...
mid=450, average rewards=17.100
mid=460, average rewards=17.000
```
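In this test no CNN is trained: the Reward of each sampled token sequence is simply how many of its 20 entries equal 1 (see `simple_run()` in `autodl.py`). A sketch of the simulated reward, with a randomly faked batch of sampled tokens standing in for the Agent's output:
```
import numpy as np

# Fake batch of sampled tokens: batch_size=10 sequences of length 20.
actions_to = np.random.randint(0, 10, size=(10, 20))
rewards = np.count_nonzero(actions_to == 1, axis=1).astype("int32")
print("average rewards=%.3f" % np.mean(rewards))
```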
### AutoDL automatic network search training
Run the automatic network search strategy on the small CIFAR-10 dataset: each round first runs the Agent's policy network to generate a new strategy, then the Trainer trains a model according to that strategy and feeds back the Reward (i.e. the accuracy), so that network structures with higher and higher accuracy are found through continuous iteration. Run the following commands:
```
export FLAGS_fraction_of_gpu_memory_to_use=0.98
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fast_eager_deletion_mode=1
CUDA_VISIBLE_DEVICES=0 python -u main.py
```
__Note:__ Training here requires two GPUs. The Agent uses `CUDA_VISIBLE_DEVICES=0` (set in the `main.py` launch command above); the Trainer uses `CUDA_VISIBLE_DEVICES=1` (set in [autodl.py](https://github.com/PaddlePaddle/AutoDL/blob/master/AutoDL%20Design/autodl.py#L124)).
Expected result:
`average accuracy` in the log gradually increases
```
step = 0, average accuracy = 0.633
step = 1, average accuracy = 0.688
step = 2, average accuracy = 0.626
step = 3, average accuracy = 0.682
......
step = 842, average accuracy = 0.823
step = 843, average accuracy = 0.825
step = 844, average accuracy = 0.808
......
```
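Internally, `run()` in `autodl.py` does not feed raw accuracies to the Agent: each batch of accuracies is compared against an exponential moving-average baseline (`bl_decay`, default 0.9), and the cubed advantage is what the policy gradient learns from. A condensed sketch of that bookkeeping, with made-up accuracy values:
```
import numpy as np

bl_decay = 0.9
baseline = 0.0
batch_acc = np.array([0.63, 0.69, 0.62, 0.68], dtype="float32")  # made-up accuracies
if baseline == 0:
    baseline = batch_acc.mean()
else:
    baseline = baseline * bl_decay + (1 - bl_decay) * batch_acc.mean()
advantage = batch_acc - baseline
print(advantage ** 3)  # autodl.py passes advantage ** 3 to the Agent's learn()
```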
### Results
![image](./img/search_result.png)
The horizontal axis is the search iteration step and the vertical axis is the training accuracy of the model. As the figure shows, the automatically constructed models keep improving as the search iterates.
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
AutoDL definition
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import argparse
import numpy as np
import subprocess
import paddle.fluid as fluid
from reinforce_policy_gradient import ReinforcePolicyGradient
from policy_model import PolicyModel
from autodl_agent import AutoDLAgent
import utils
import collections
class AutoDL(object):
"""
AutoDL class
"""
def __init__(self):
"""
init
"""
self.parse_args = self._init_parser()
self.bl_decay = self.parse_args.bl_decay
self.log_dir = self.parse_args.log_dir
self.early_stop = self.parse_args.early_stop
self.data_path = self.parse_args.data_path
self.num_models = self.parse_args.num_models
self.batch_size = self.parse_args.batch_size
self.chunk_size = self.parse_args.chunk_size
self._init_dir_path()
self.model = PolicyModel(self.parse_args)
algo_hyperparas = {'lr': self.parse_args.learning_rate}
self.algorithm = ReinforcePolicyGradient(self.model,
hyperparas=algo_hyperparas)
self.autodl_agent = AutoDLAgent(self.algorithm, self.parse_args)
self.total_reward = 0
def _init_dir_path(self):
"""
init dir path
"""
utils.prepare(self.log_dir)
utils.prepare(self.log_dir, "actions")
utils.prepare(self.log_dir, "rewards")
utils.prepare(self.log_dir, "checkpoints")
def _init_parser(self):
"""
init parser
"""
parser = argparse.ArgumentParser(description='AutoDL Parser',
prog='AutoDL')
parser.add_argument('-v', '--version', action='version',
version='%(prog)s 0.1')
parser.add_argument('--num_nodes', dest="num_nodes", nargs="?",
type=int, const=10, default=10,
help="number of nodes")
parser.add_argument('--num_tokens', dest="num_tokens", nargs="?",
type=int, const=10, default=10,
help="number of tokens")
parser.add_argument('--learning_rate', dest="learning_rate", nargs="?",
type=float, default=1e-3,
help="learning rate")
parser.add_argument('--batch_size', dest="batch_size", nargs="?",
type=int, const=10, default=10, help="batch size")
parser.add_argument('--num_models', dest="num_models", nargs="?",
type=int, const=32000, default=32000,
help="maximum number of models sampled")
parser.add_argument('--early_stop', dest="early_stop", nargs="?",
type=int, const=20, default=20, help="early stop")
parser.add_argument('--log_dir', dest="log_dir", nargs="?", type=str,
const="./log", default="./log",
help="directory of log")
parser.add_argument('--input_size', dest="input_size", nargs="?",
type=int, const=10, default=10, help="input size")
parser.add_argument('--hidden_size', dest="hidden_size", nargs="?",
type=int, const=64, default=64, help="hidden size")
parser.add_argument('--num_layers', dest="num_layers", nargs="?",
type=int, const=2, default=2, help="num layers")
parser.add_argument('--bl_decay', dest="bl_decay", nargs="?",
type=float, const=0.9, default=0.9,
help="base line decay")
# inception train config
parser.add_argument('--data_path', dest="data_path", nargs="?",
type=str, default="./cifar/pickle-cifar-10",
help="path of data files")
parser.add_argument('--chunk_size', dest="chunk_size", nargs="?",
type=int, const=100, default=100,
help="chunk size")
parse_args = parser.parse_args()
return parse_args
def supervisor(self, mid):
"""
execute cnn training
sample cmd: python -u inception_train/train.py --mid=9 \
--early_stop=20 --data_path=./cifar/pickle-cifar-10
"""
tokens, adjvec = utils.load_action(mid, self.log_dir)
cmd = ("CUDA_VISIBLE_DEVICES=1 python -u inception_train/train.py \
--mid=%d --early_stop=%d --logdir=%s --data_path=%s --chunk_size=%d") % \
(mid, self.early_stop, self.log_dir, self.data_path, self.chunk_size)
print("cmd:{}".format(cmd))
while True:
try:
subprocess.check_call(cmd, shell=True)
break
except subprocess.CalledProcessError as e:
print("[%s] training model #%d exits with exit code %d" %
(utils.stime(), mid, e.returncode), file=sys.stderr)
return
def simple_run(self):
"""
simple run
"""
print("Simple run target is 20")
mid = 0
shadow = 0
is_first = True
while mid <= self.num_models:
actions_to, actions_ad = self.autodl_agent.sample()
rewards = np.count_nonzero(actions_to == 1, axis=1).astype("int32")
# moving average
current_mean_reward = np.mean(rewards)
if is_first:
shadow = current_mean_reward
is_first = False
else:
shadow = shadow * self.bl_decay \
+ current_mean_reward * (1 - self.bl_decay)
self.autodl_agent.learn((np.array(actions_to).astype("int32"),
np.array(actions_ad).astype("int32")),
rewards - shadow)
if mid % 10 == 0:
print('mid=%d, average rewards=%.3f' % (mid, np.mean(rewards)))
mid += 1
def run(self):
"""
run
"""
rewards = []
mid = 0
while mid <= self.num_models:
actions_to, actions_ad = self.autodl_agent.sample()
for action in zip(actions_to, actions_ad):
utils.dump_action(mid, action, self.log_dir)
self.supervisor(mid)
current_reward = utils.load_reward(mid, self.log_dir)
if not np.isnan(current_reward):
rewards.append(current_reward.item())
mid += 1
if len(rewards) % self.batch_size == 0:
print("[%s] step = %d, average accuracy = %.3f" %
(utils.stime(), self.autodl_agent.global_step,
np.mean(rewards)))
rewards_array = np.array(rewards).astype("float32")
if self.total_reward == 0:
self.total_reward = rewards_array.mean()
else:
self.total_reward = self.total_reward * self.bl_decay \
+ (1 - self.bl_decay) * rewards_array.mean()
rewards_array = rewards_array - self.total_reward
self.autodl_agent.learn([actions_to.astype("int32"),
actions_ad.astype("int32")],
rewards_array ** 3)
rewards = []
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
AutoDL Agent Definition
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from parl.framework.agent_base import Agent
class AutoDLAgent(Agent):
"""
AutoDLAgent
"""
def __init__(self, algorithm, parse_args):
"""
init
"""
self.global_step = 0
self.parse_args = parse_args
self.num_nodes = self.parse_args.num_nodes
self.batch_size = self.parse_args.batch_size
super(AutoDLAgent, self).__init__(algorithm)
self.inputs_data = np.zeros([self.batch_size,
1]).astype('int32')
def build_program(self):
"""
build program
"""
self.predict_program = fluid.Program()
self.train_program = fluid.Program()
with fluid.program_guard(self.predict_program):
self.predict_inputs = layers.data(
name='input',
append_batch_size=False,
shape=[self.batch_size, 1],
dtype='int32')
self.predict_tokens, self.predict_adjvec = self.alg.define_predict(
self.predict_inputs)
with fluid.program_guard(self.train_program):
self.train_inputs = layers.data(
name='input',
append_batch_size=False,
shape=[self.batch_size, 1],
dtype='int32')
self.actions_to = layers.data(
name='actions_to',
append_batch_size=False,
shape=[self.batch_size,
self.num_nodes * 2],
dtype='int32')
self.actions_ad = layers.data(
name='actions_ad',
append_batch_size=False,
shape=[self.batch_size,
self.num_nodes * (self.num_nodes - 1)],
dtype='int32')
self.rewards = layers.data(
name='rewards',
append_batch_size=False,
shape=[self.batch_size],
dtype='float32')
self.cost = self.alg.define_learn(
obs=self.train_inputs, reward=self.rewards,
action=[self.actions_to, self.actions_ad])
def sample(self):
"""
sample
"""
feed_dict = {'input': self.inputs_data}
[actions_to, actions_ad] = self.fluid_executor.run(
self.predict_program, feed=feed_dict,
fetch_list=[self.predict_tokens, self.predict_adjvec])
return actions_to, actions_ad
def learn(self, actions, reward):
"""
learn
"""
(actions_to, actions_ad) = actions
feed_dict = {'input': self.inputs_data, 'actions_to': actions_to,
'actions_ad': actions_ad, 'rewards': reward}
cost = self.fluid_executor.run(
self.train_program, feed=feed_dict, fetch_list=[self.cost])[0]
self.global_step += 1
return cost
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate pkl files from cifar10
"""
import os
import cPickle as pickle
import random
import numpy as np
import sys
import argparse
def init_parser():
"""
init_parser
"""
parser = argparse.ArgumentParser(description='Data generator')
parser.add_argument('--chunk_size', dest="chunk_size", nargs="?",
type=int, default=100,
help="size of chunk")
parser.add_argument('--input_dir', dest="input_dir", nargs="?",
type=str, default='./cifar-10-batches-py',
help="path of input")
parser.add_argument('--output_dir', dest="output_dir", nargs="?",
type=str, default='./pickle-cifar-10',
help="path of output")
parse_args, unknown_flags = parser.parse_known_args()
return parse_args
def get_file_names(input_dir):
"""
get all file names located in dir_path
"""
sub_name = 'data_batch'
files = os.listdir(input_dir)
names = [each_item for each_item in files if sub_name in each_item]
return names
def check_output_dir(output_dir):
"""
check exist of output dir
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
def get_datasets(input_dir, chunk_size):
"""
get image datasets
chunk_size is the number of each class
"""
total_size = chunk_size * 10
names = get_file_names(input_dir)
img_count = 0
datasets = []
class_map = {i: 0 for i in range(10)}
for name in names:
print("Reading file " + name)
batch = pickle.load(open(input_dir + "/" + name, 'rb'))
data = batch['data']
labels = batch.get('labels', batch.get('fine_labels', None))
assert labels is not None
data_tuples = zip(data, labels)
for data in data_tuples:
if class_map[data[1]] < chunk_size:
datasets.append(data)
class_map[data[1]] += 1
img_count += 1
if img_count >= total_size:
random.shuffle(datasets)
for k, v in class_map.items():
print("label:{} count:{}".format(k, v))
return np.array(datasets)
random.shuffle(datasets)
return np.array(datasets)
def dump_pkl(datasets, output_dir):
"""
dump_pkl
"""
chunk_size = parse_args.chunk_size
for i in range(10):
sub_dataset = datasets[i * chunk_size:(i + 1) * chunk_size, :]
sub_dataset.dump(output_dir + "/" + 'data_batch_' + str(i) + '.pkl')
if __name__ == "__main__":
parse_args = init_parser()
check_output_dir(parse_args.output_dir)
datasets = get_datasets(parse_args.input_dir, parse_args.chunk_size)
dump_pkl(datasets, parse_args.output_dir)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implementation of binomial and multinomial distribution
"""
import paddle.fluid as fluid
import functools
import numpy as np
def create_tmp_var(name, dtype, shape, program=None):
"""
Create variable which is used to store the py_func result
"""
if program is None:
return fluid.default_main_program().current_block().create_var(
name=fluid.unique_name.generate(name),
dtype=dtype, shape=shape)
else:
return program.current_block().create_var(
name=fluid.unique_name.generate(name),
dtype=dtype, shape=shape)
def sigmoid(x):
"""
Sigmoid
"""
return (1 / (1 + np.exp(-x)))
def softmax(x):
"""
Compute softmax values for each sets of scores in x.
"""
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum()
def py_func_bernoulli(input):
"""
Bernoulli python function definition
"""
prob_array = sigmoid(np.array(input))
sample = np.random.binomial(1, prob_array)
return sample
def bernoulli(input_logits, output_shape, program=None):
"""
Bernoulli
"""
# the output_shape is the same as input_logits
samples_var = create_tmp_var(name='binomial_result_var',
dtype='float32', shape=output_shape,
program=program)
fluid.layers.py_func(func=py_func_bernoulli, x=input_logits,
out=samples_var, backward_func=None,
skip_vars_in_backward_input=None)
return samples_var
def py_func_multinomial(logits, num_samples_var):
"""
Multinomial python function definition
Input:
input: list of [logits_array, num_samples_int]
"""
def generate(x, prob_array):
"""
Sample multinomial
"""
sample = np.random.multinomial(1, prob_array)
ret = np.argmax(sample)
return ret
num_samples = int(np.array(num_samples_var)[0])
logits_array = np.array(logits)
if len(logits_array.shape) != 2:
raise Exception("Shape must be rank 2 but is rank {} \
for 'multinomial/Multinomial' (op: 'Multinomial') \
with input shapes:{}".format(len(logits_array.shape),
logits_array.shape))
ret = np.array([])
for logits in logits_array:
prob = softmax(logits)
func = functools.partial(generate, prob_array=prob)
sample = np.zeros(num_samples)
sample = np.array(list(map(func, sample)))
ret = np.append(ret, sample)
ret = ret.reshape(-1, num_samples).astype("int32")
return ret
def multinomial(input_logits, output_shape, num_samples, program=None):
"""
Multinomial
input_logits's dimension is [M * D]
output_shape's dimension is [M * num_samples]
"""
samples_var = create_tmp_var(name='multinomial_result_var',
dtype='int32', shape=output_shape,
program=program)
num_samples_var = fluid.layers.fill_constant(shape=[1], value=num_samples,
dtype='int32')
fluid.layers.py_func(func=py_func_multinomial,
x=[input_logits, num_samples_var],
out=samples_var, backward_func=None,
skip_vars_in_backward_input=None)
return samples_var
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inception Definition
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from absl import flags
import numpy as np
import models.layers as layers
import models.ops as _ops
FLAGS = flags.FLAGS
flags.DEFINE_integer("num_stages", 3, "number of stages")
flags.DEFINE_integer("num_cells", 3, "number of cells per stage")
flags.DEFINE_integer("width", 64, "network width")
flags.DEFINE_integer("ratio", 4, "compression ratio")
num_classes = 10
ops = [
_ops.conv_1x1,
_ops.conv_3x3,
_ops.conv_5x5,
_ops.dilated_3x3,
_ops.conv_1x3_3x1,
_ops.conv_1x5_5x1,
_ops.maxpool_3x3,
_ops.maxpool_5x5,
_ops.avgpool_3x3,
_ops.avgpool_5x5,
]
def net(inputs, output, tokens, adjvec):
"""
create net
"""
num_nodes = len(tokens) // 2
def slice(vec):
"""
slice vec
"""
mat = np.zeros([num_nodes, num_nodes])
def pos(x):
"""
pos
"""
return x * (x - 1) // 2
for i in range(1, num_nodes):
mat[0:i, i] = vec[pos(i):pos(i + 1)]
return mat
normal_to, reduce_to = np.split(tokens, 2)
normal_ad, reduce_ad = map(slice, np.split(adjvec, 2))
x = layers.conv(inputs, FLAGS.width, (3, 3))
c = 1
for _ in range(FLAGS.num_cells):
x = cell(x, normal_to, normal_ad)
c += 1
for _ in range(1, FLAGS.num_stages):
x = cell(x, reduce_to, reduce_ad, downsample=True)
c += 1
for _ in range(1, FLAGS.num_cells):
x = cell(x, normal_to, normal_ad)
c += 1
x = layers.bn_relu(x)
x = layers.global_avgpool(x)
x = layers.dropout(x)
logits = layers.fully_connected(x, num_classes)
x = fluid.layers.softmax_with_cross_entropy(logits, output,
numeric_stable_mode=True)
loss = fluid.layers.reduce_mean(x)
accuracy = fluid.layers.accuracy(input=logits, label=output)
return loss, accuracy
def cell(inputs, tokens, adjmat, downsample=False, name=None):
"""
cell
"""
filters = inputs.shape[1]
d = filters // FLAGS.ratio
num_nodes, tensors = len(adjmat), []
for n in range(num_nodes):
func = ops[tokens[n]]
idx, = np.nonzero(adjmat[:, n])
if len(idx) == 0:
x = layers.bn_relu(inputs)
x = layers.conv(x, d, (1, 1))
x = layers.bn_relu(x)
x = func(x, downsample)
else:
x = fluid.layers.sums([tensors[i] for i in idx])
x = layers.bn_relu(x)
x = func(x)
tensors.append(x)
free_ends, = np.where(~adjmat.any(axis=1))
tensors = [tensors[i] for i in free_ends]
filters = filters * 2 if downsample else filters
x = fluid.layers.concat(tensors, axis=1)
x = layers.conv(x, filters, (1, 1))
return x
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Layers Definition
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import operator
import numpy as np
import paddle.fluid as fluid
from absl import flags
FLAGS = flags.FLAGS
flags.DEFINE_float("weight_decay", 0.0001,
"weight decay")
flags.DEFINE_float("bn_decay", 0.9,
"batch norm decay")
flags.DEFINE_float("relu_leakiness", 0.1,
"relu leakiness")
flags.DEFINE_float("dropout_rate", 0.5,
"dropout rate")
def calc_padding(img_width, stride, dilation, filter_width):
"""
calculate pixels to padding in order to keep input/output size same.
"""
filter_width = dilation * (filter_width - 1) + 1
if img_width % stride == 0:
pad_along_width = max(filter_width - stride, 0)
else:
pad_along_width = max(filter_width - (img_width % stride), 0)
return pad_along_width // 2, pad_along_width - pad_along_width // 2
def conv(inputs,
filters,
kernel,
strides=None,
dilation=None,
num_groups=1,
conv_param=None,
name=None):
"""
normal conv layer
"""
if strides is None:
strides = (1, 1)
if dilation is None:
dilation = (1, 1)
if isinstance(kernel, (tuple, list)):
n = operator.mul(*kernel) * inputs.shape[1]
else:
n = kernel * kernel * inputs.shape[1]
# pad input
padding = (0, 0, 0, 0) \
+ calc_padding(inputs.shape[2], strides[0], dilation[0], kernel[0]) \
+ calc_padding(inputs.shape[3], strides[1], dilation[1], kernel[1])
if sum(padding) > 0:
inputs = fluid.layers.pad(inputs, padding, 0)
param_attr = fluid.param_attr.ParamAttr(
initializer=fluid.initializer.TruncatedNormal(
0.0, scale=np.sqrt(2.0 / n)),
regularizer=fluid.regularizer.L2Decay(FLAGS.weight_decay))
return fluid.layers.conv2d(
inputs,
filters,
kernel,
stride=strides,
padding=0,
dilation=dilation,
groups=num_groups,
param_attr=param_attr if conv_param is None else conv_param,
use_cudnn=False if num_groups == inputs.shape[1] == filters else True,
name=name)
def sep(inputs, filters, kernel, strides=None, dilation=None, name=None):
"""
Separable convolution layer
"""
if strides is None:
strides = (1, 1)
if dilation is None:
dilation = (1, 1)
if isinstance(kernel, (tuple, list)):
n_depth = operator.mul(*kernel)
else:
n_depth = kernel * kernel
n_point = inputs.shape[1]
if isinstance(strides, (tuple, list)):
multiplier = strides[0]
else:
multiplier = strides
depthwise_param = fluid.param_attr.ParamAttr(
initializer=fluid.initializer.TruncatedNormal(
0.0, scale=np.sqrt(2.0 / n_depth)),
regularizer=fluid.regularizer.L2Decay(FLAGS.weight_decay))
pointwise_param = fluid.param_attr.ParamAttr(
initializer=fluid.initializer.TruncatedNormal(
0.0, scale=np.sqrt(2.0 / n_point)),
regularizer=fluid.regularizer.L2Decay(FLAGS.weight_decay))
depthwise_conv = conv(
inputs=inputs,
kernel=kernel,
filters=int(filters * multiplier),
strides=strides,
dilation=dilation,
num_groups=int(filters * multiplier),
conv_param=depthwise_param,
name='depthwise_' + name)
return conv(
inputs=depthwise_conv,
kernel=(1, 1),
filters=int(filters * multiplier),
strides=(1, 1),
dilation=dilation,
conv_param=pointwise_param,
name='pointwise_' + name)
def maxpool(inputs, kernel, strides=None, name=None):
"""
maxpool
"""
if strides is None:
strides = (1, 1)
padding = (0, 0, 0, 0) \
+ calc_padding(inputs.shape[2], strides[0], 1, kernel[0]) \
+ calc_padding(inputs.shape[3], strides[1], 1, kernel[1])
if sum(padding) > 0:
inputs = fluid.layers.pad(inputs, padding, 0)
return fluid.layers.pool2d(
inputs, kernel, 'max', strides, pool_padding=0,
ceil_mode=False, name=name)
def avgpool(inputs, kernel, strides=None, name=None):
"""
avgpool
"""
if strides is None:
strides = (1, 1)
padding_pixel = (0, 0, 0, 0)
padding_pixel += calc_padding(inputs.shape[2], strides[0], 1, kernel[0])
padding_pixel += calc_padding(inputs.shape[3], strides[1], 1, kernel[1])
if padding_pixel[4] == padding_pixel[5] and padding_pixel[
6] == padding_pixel[7]:
# same padding pixel num on all sides.
return fluid.layers.pool2d(
inputs,
kernel,
'avg',
strides,
pool_padding=(padding_pixel[4], padding_pixel[6]),
ceil_mode=False)
elif padding_pixel[4] + 1 == padding_pixel[5] \
and padding_pixel[6] + 1 == padding_pixel[7] \
and strides == (1, 1):
# different padding size: first pad then crop.
x = fluid.layers.pool2d(
inputs,
kernel,
'avg',
strides,
pool_padding=(padding_pixel[5], padding_pixel[7]),
ceil_mode=False)
x_shape = x.shape
return fluid.layers.crop(
x,
shape=(-1, x_shape[1], x_shape[2] - 1, x_shape[3] - 1),
offsets=(0, 0, 1, 1), name=name)
else:
# not support. use padding-zero and pool2d.
print("Warning: use zero-padding in avgpool")
outputs = fluid.layers.pad(inputs, padding_pixel, 0)
return fluid.layers.pool2d(
outputs, kernel, 'avg', strides, pool_padding=0,
ceil_mode=False, name=name)
def global_avgpool(inputs, name=None):
"""
global avgpool
"""
return fluid.layers.reduce_mean(inputs, dim=[2, 3], name=name)
def fully_connected(inputs, units, name=None):
"""
fully connected
"""
n = inputs.shape[1]
param_attr = fluid.param_attr.ParamAttr(
initializer=fluid.initializer.TruncatedNormal(
0.0, scale=np.sqrt(2.0 / n)),
regularizer=fluid.regularizer.L2Decay(FLAGS.weight_decay))
return fluid.layers.fc(inputs,
units,
param_attr=param_attr)
def batch_norm(inputs, name=None):
"""
batch norm
"""
param_attr = fluid.param_attr.ParamAttr(
regularizer=fluid.regularizer.L2Decay(FLAGS.weight_decay))
bias_attr = fluid.param_attr.ParamAttr(
regularizer=fluid.regularizer.L2Decay(FLAGS.weight_decay))
return fluid.layers.batch_norm(
inputs, momentum=FLAGS.bn_decay, epsilon=0.001,
param_attr=param_attr,
bias_attr=bias_attr)
def relu(inputs, name=None):
"""
relu
"""
if FLAGS.relu_leakiness:
return fluid.layers.leaky_relu(inputs, FLAGS.relu_leakiness, name=name)
return fluid.layers.relu(inputs, name=name)
def bn_relu(inputs, name=None):
"""
batch norm + relu layer
"""
output = batch_norm(inputs)
return fluid.layers.relu(output, name=name)
def dropout(inputs, name=None):
"""
dropout layer
"""
return fluid.layers.dropout(inputs, dropout_prob=FLAGS.dropout_rate,
dropout_implementation='upscale_in_train',
name=name)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base Ops Definition
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import models.layers as layers
def conv_1x1(inputs, downsample=False):
"""
conv_1x1
"""
return conv_base(inputs, (1, 1), downsample=downsample)
def conv_2x2(inputs, downsample=False):
"""
conv_2x2
"""
return conv_base(inputs, (2, 2), downsample=downsample)
def conv_3x3(inputs, downsample=False):
"""
conv_3x3
"""
return conv_base(inputs, (3, 3), downsample=downsample)
def conv_4x4(inputs, downsample=False):
"""
conv_4x4
"""
return conv_base(inputs, (4, 4), downsample=downsample)
def conv_5x5(inputs, downsample=False):
"""
conv_5x5
"""
return conv_base(inputs, (5, 5), downsample=downsample)
def dilated_2x2(inputs, downsample=False):
"""
dilated_2x2
"""
return conv_base(inputs, (2, 2), (2, 2), downsample)
def dilated_3x3(inputs, downsample=False):
"""
dilated_3x3
"""
return conv_base(inputs, (3, 3), (2, 2), downsample)
def conv_1x2_2x1(inputs, downsample=False):
"""
conv_1x2_2x1
"""
return pair_base(inputs, 2, downsample)
def conv_1x3_3x1(inputs, downsample=False):
"""
conv_1x3_3x1
"""
return pair_base(inputs, 3, downsample)
def conv_1x4_4x1(inputs, downsample=False):
"""
conv_1x4_4x1
"""
return pair_base(inputs, 4, downsample)
def conv_1x5_5x1(inputs, downsample=False):
"""
conv_1x5_5x1
"""
return pair_base(inputs, 5, downsample)
def sep_2x2(inputs, downsample=False):
"""
sep_2x2
"""
return sep_base(inputs, (2, 2), downsample=downsample)
def sep_3x3(inputs, downsample=False):
"""
sep_3x3
"""
return sep_base(inputs, (3, 3), downsample=downsample)
def sep_4x4(inputs, downsample=False):
"""
sep_4x4
"""
return sep_base(inputs, (4, 4), downsample=downsample)
def sep_5x5(inputs, downsample=False):
"""
sep_5x5
"""
return sep_base(inputs, (5, 5), downsample=downsample)
def maxpool_2x2(inputs, downsample=False):
"""
maxpool_2x2
"""
return maxpool_base(inputs, (2, 2), downsample)
def maxpool_3x3(inputs, downsample=False):
"""
maxpool_3x3
"""
return maxpool_base(inputs, (3, 3), downsample)
def maxpool_4x4(inputs, downsample=False):
"""
maxpool_4x4
"""
return maxpool_base(inputs, (4, 4), downsample)
def maxpool_5x5(inputs, downsample=False):
"""
maxpool_5x5
"""
return maxpool_base(inputs, (5, 5), downsample)
def avgpool_2x2(inputs, downsample=False):
"""
avgpool_2x2
"""
return avgpool_base(inputs, (2, 2), downsample)
def avgpool_3x3(inputs, downsample=False):
"""
avgpool_3x3
"""
return avgpool_base(inputs, (3, 3), downsample)
def avgpool_4x4(inputs, downsample=False):
"""
avgpool_4x4
"""
return avgpool_base(inputs, (4, 4), downsample)
def avgpool_5x5(inputs, downsample=False):
"""
avgpool_5x5
"""
return avgpool_base(inputs, (5, 5), downsample)
def conv_base(inputs, kernel, dilation=None, downsample=False):
"""
conv_base
"""
if dilation is None:
dilation = (1, 1)
filters = inputs.shape[1]
if downsample:
output = layers.conv(inputs, filters * 2, kernel, (2, 2))
else:
output = layers.conv(inputs, filters, kernel, dilation=dilation)
return output
def pair_base(inputs, kernel, downsample=False):
"""
pair_base
"""
filters = inputs.shape[1]
if downsample:
output = layers.conv(inputs, filters, (1, kernel), (1, 2))
output = layers.conv(output, filters, (kernel, 1), (2, 1))
output = layers.conv(output, filters * 2, (1, 1))
else:
output = layers.conv(inputs, filters, (1, kernel))
output = layers.conv(output, filters, (kernel, 1))
return output
def sep_base(inputs, kernel, dilation=None, downsample=False):
"""
sep_base
"""
if dilation is None:
dilation = (1, 1)
filters = inputs.shape[1]
if downsample:
output = layers.sep(inputs, filters * 2, kernel, (2, 2))
else:
output = layers.sep(inputs, filters, kernel, dilation=dilation)
return output
def maxpool_base(inputs, kernel, downsample=False):
"""
maxpool_base
"""
if downsample:
filters = inputs.shape[1]
output = layers.maxpool(inputs, kernel, (2, 2))
output = layers.conv(output, filters * 2, (1, 1))
else:
output = layers.maxpool(inputs, kernel)
return output
def avgpool_base(inputs, kernel, downsample=False):
"""
avgpool_base
"""
if downsample:
filters = inputs.shape[1]
output = layers.avgpool(inputs, kernel, (2, 2))
output = layers.conv(output, filters * 2, (1, 1))
else:
output = layers.avgpool(inputs, kernel)
return output
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Network Definition
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cPickle as cp
import paddle.fluid as fluid
import paddle.fluid.layers.ops as ops
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
import math
from paddle.fluid.initializer import init_on_cpu
from models import inception
from absl import flags
FLAGS = flags.FLAGS
flags.DEFINE_float("lr_max", 0.1,
"initial learning rate")
flags.DEFINE_float("lr_min", 0.0001,
"limiting learning rate")
flags.DEFINE_integer("batch_size", 128,
"batch size")
flags.DEFINE_integer("T_0", 200,
"number of epochs")
flags.DEFINE_integer("chunk_size", 100,
"chunk size")
class CIFARModel(object):
"""
CIFARModel class
"""
def __init__(self, tokens, adjvec, im_shape):
"""
CIFARModel init
"""
chunk_size = FLAGS.chunk_size
self.batch_size = FLAGS.batch_size
self.tokens = tokens
self.adjvec = adjvec
self.im_shape = im_shape
max_step = chunk_size * 9 * FLAGS.T_0 // FLAGS.batch_size
test_batch = chunk_size // FLAGS.batch_size
def cosine_decay():
"""
Applies cosine decay to the learning rate.
"""
global_step = _decay_step_counter()
with init_on_cpu():
frac = (1 + ops.cos(global_step / max_step * math.pi)) / 2
return FLAGS.lr_min + (FLAGS.lr_max - FLAGS.lr_min) * frac
self.lr_strategy = cosine_decay
def fn_model(self, py_reader):
"""
fn model
"""
self.image, self.label = fluid.layers.read_file(py_reader)
self.loss, self.accuracy = inception.net(
self.image, self.label, self.tokens, self.adjvec)
return self.loss, self.accuracy
def build_input(self, image_shape, is_train):
"""
build_input
"""
name = 'train_reader' if is_train else 'test_reader'
py_reader = fluid.layers.py_reader(
capacity=64,
shapes=[[-1] + image_shape, [-1, 1]],
lod_levels=[0, 0],
dtypes=["float32", "int64"],
use_double_buffer=True,
name=name)
return py_reader
def build_program(self, main_prog, startup_prog, is_train):
"""
build_program
"""
out = []
with fluid.program_guard(main_prog, startup_prog):
py_reader = self.build_input(self.im_shape, is_train)
if is_train:
with fluid.unique_name.guard():
loss, accuracy = self.fn_model(py_reader)
optimizer = fluid.optimizer.Momentum(
learning_rate=self.lr_strategy(),
momentum=0.9,
use_nesterov=True)
optimizer.minimize(loss)
out = [py_reader, loss, accuracy]
else:
with fluid.unique_name.guard():
loss, accuracy = self.fn_model(py_reader)
out = [py_reader, loss, accuracy]
return out
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data preprocess
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from PIL import Image
from PIL import ImageOps
from PIL import ImageEnhance
import numpy as np
FLAGS = flags.FLAGS
flags.DEFINE_boolean("random_flip_left_right", True,
"random flip left and right")
flags.DEFINE_boolean("random_flip_up_down", False,
"random flip up and down")
flags.DEFINE_boolean("random_brightness", False,
"randomly adjust brightness")
image_size = 32
def augmentation(sample, is_training):
"""
augmentation
"""
image_array = sample.reshape(3, image_size, image_size)
rgb_array = np.transpose(image_array, (1, 2, 0))
img = Image.fromarray(rgb_array, 'RGB')
if is_training:
# pad and crop
img = ImageOps.expand(img, (4, 4, 4, 4), fill=0) # pad to 40 * 40 * 3
left_top = np.random.randint(9, size=2) # rand 0 - 8
img = img.crop((left_top[0], left_top[1], left_top[0] + image_size,
left_top[1] + image_size))
if FLAGS.random_flip_left_right:
if np.random.randint(2):
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if FLAGS.random_flip_up_down:
if np.random.randint(2):
img = img.transpose(Image.FLIP_TOP_BOTTOM)
if FLAGS.random_brightness:
delta = np.random.uniform(-0.3, 0.3) + 1.
img = ImageEnhance.Brightness(img).enhance(delta)
img = np.array(img).astype(np.float32)
# per_image_standardization
img_float = img / 255.0
num_pixels = img_float.size
img_mean = img_float.mean()
img_std = img_float.std()
scale = np.maximum(np.sqrt(num_pixels), img_std)
img = (img_float - img_mean) / scale
img = np.transpose(img, (2, 0, 1))
return img
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Based on:
# --------------------------------------------------------
# DARTS
# Copyright (c) 2018, Hanxiao Liu.
# Licensed under the Apache License, Version 2.0;
# --------------------------------------------------------
"""
CIFAR-10 dataset.
This module will download dataset from
https://www.cs.toronto.edu/~kriz/cifar.html and parse train/test set into
paddle reader creators.
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes,
with 6000 images per class. There are 50000 training images
and 10000 test images.
"""
import numpy as np
try:
import cPickle as pickle
except ImportError:
import pickle
import random
import utils
import paddle.fluid as fluid
import os
from preprocess import augmentation
def reader_creator_filepath(filename, sub_name, is_training,
batch_size, data_list):
"""
reader creator
"""
dataset = []
for name in data_list:
print("Reading file " + name)
file_path = os.path.join(filename, name)
batch_data = pickle.load(open(file_path))
dataset.append(batch_data)
datasets = np.concatenate(dataset)
if is_training:
np.random.shuffle(datasets)
def read_batch(datasets, is_training):
"""
read batch
"""
for sample, label in datasets:
im = augmentation(sample, is_training)
yield im, [int(label)]
def reader():
"""
get reader
"""
batch_data = []
batch_label = []
for data, label in read_batch(datasets, is_training):
batch_data.append(data)
batch_label.append(label)
if len(batch_data) == batch_size:
batch_data = np.array(batch_data, dtype='float32')
batch_label = np.array(batch_label, dtype='int64')
batch_out = [[batch_data, batch_label]]
yield batch_out
batch_data = []
batch_label = []
if len(batch_data) != 0:
batch_data = np.array(batch_data, dtype='float32')
batch_label = np.array(batch_label, dtype='int64')
batch_out = [[batch_data, batch_label]]
yield batch_out
batch_data = []
batch_label = []
return reader
def train10(data, batch_size, data_list):
"""
CIFAR-10 training set creator.
It returns a reader creator, each sample in the reader is image pixels in
[0, 1] and label in [0, 9].
:return: Training reader creator
:rtype: callable
"""
return reader_creator_filepath(data, 'data_batch', True,
batch_size, data_list)
def test10(data, batch_size, data_list):
"""
CIFAR-10 test set creator.
It returns a reader creator, each sample in the reader is image pixels in
[0, 1] and label in [0, 9].
:return: Test reader creator.
:rtype: callable
"""
return reader_creator_filepath(data, 'test_batch', False,
batch_size, data_list)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Trainer Definition
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import reader
import sys
import os
import time
import paddle.fluid as fluid
import utils
import cPickle as cp
from absl import flags
from absl import app
from nn import CIFARModel
FLAGS = flags.FLAGS
flags.DEFINE_string("data_path",
"./cifar/pickle-cifar-10",
"data path")
flags.DEFINE_string("logdir", "log",
"logging directory")
flags.DEFINE_integer("mid", 0,
"model id")
flags.DEFINE_integer("early_stop", 20,
"early stop")
image_size = 32
def main(_):
"""
main
"""
image_shape = [3, image_size, image_size]
files = os.listdir(FLAGS.data_path)
names = [each_item for each_item in files]
np.random.shuffle(names)
train_list = names[:9]
test_list = names[-1]
tokens, adjvec = utils.load_action(FLAGS.mid)
model = CIFARModel(tokens, adjvec, image_shape)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
startup = fluid.Program()
train_prog = fluid.Program()
test_prog = fluid.Program()
train_vars = model.build_program(train_prog, startup, True)
test_vars = model.build_program(test_prog, startup, False)
exe.run(startup)
train_accuracy, epoch_id = train(model, FLAGS.early_stop,
train_prog, train_vars, exe, train_list)
if epoch_id < FLAGS.early_stop:
utils.dump_reward(FLAGS.mid, train_accuracy)
else:
test_accuracy = test(model, test_prog, test_vars, exe, [test_list])
utils.dump_reward(FLAGS.mid, test_accuracy)
def train(model, epoch_num, train_prog, train_vars, exe, data_list):
"""
train
"""
train_py_reader, loss_train, acc_train = train_vars
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 1
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = True
train_exe = fluid.ParallelExecutor(
main_program=train_prog,
use_cuda=True,
loss_name=loss_train.name,
exec_strategy=exec_strategy,
build_strategy=build_strategy)
train_reader = reader.train10(FLAGS.data_path, FLAGS.batch_size, data_list)
train_py_reader.decorate_paddle_reader(train_reader)
train_fetch_list = [loss_train, acc_train]
epoch_start_time = time.time()
for epoch_id in range(epoch_num):
train_py_reader.start()
epoch_end_time = time.time()
if epoch_id > 0:
print("Epoch {}, total time {}".format(epoch_id - 1, epoch_end_time
- epoch_start_time))
epoch_start_time = epoch_end_time
start_time = time.time()
step_id = 0
try:
while True:
prev_start_time = start_time
start_time = time.time()
loss_v, acc_v = train_exe.run(
fetch_list=[v.name for v in train_fetch_list])
if np.isnan(np.array(loss_v).mean()):
format_str = "[%s] jobs done, step = %d, loss = nan"
print(format_str % (utils.stime(), step_id))
return np.array(acc_v).mean(), epoch_id
print("Epoch {}, Step {}, loss {}, acc {}, time {}".format(
epoch_id, step_id, np.array(loss_v).mean(),
np.array(acc_v).mean(), start_time - prev_start_time))
step_id += 1
sys.stdout.flush()
except fluid.core.EOFException:
train_py_reader.reset()
return np.array(acc_v).mean(), epoch_id
def test(model, test_prog, test_vars, exe, data_list):
"""
test
"""
test_py_reader, loss_test, acc_test = test_vars
test_prog = test_prog.clone(for_test=True)
objs = utils.AvgrageMeter()
test_reader = reader.test10(FLAGS.data_path, FLAGS.batch_size, data_list)
test_py_reader.decorate_paddle_reader(test_reader)
test_py_reader.start()
test_fetch_list = [acc_test]
test_start_time = time.time()
step_id = 0
try:
while True:
prev_test_start_time = test_start_time
test_start_time = time.time()
acc_v, = exe.run(
test_prog, fetch_list=test_fetch_list)
objs.update(np.array(acc_v), np.array(acc_v).shape[0])
step_id += 1
except fluid.core.EOFException:
test_py_reader.reset()
print("test acc {0}".format(objs.avg))
return objs.avg
if __name__ == '__main__':
app.run(main)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
"""
Utils Definition
"""
import os
import pickle
import time
from absl import flags
FLAGS = flags.FLAGS
def stime():
"""
stime
"""
return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
def load_action(mid):
"""
load action by mid
"""
filename = os.path.join(FLAGS.logdir, "actions", "%d.pkl" % mid)
return pickle.load(open(filename, "rb"))
def dump_action(mid, action):
"""
dump action
"""
filename = os.path.join(FLAGS.logdir, "actions", "%d.pkl" % mid)
pickle.dump(action, open(filename, "wb"))
def dump_reward(mid, reward):
"""
dump reward
"""
filename = os.path.join(FLAGS.logdir, "rewards", "%d.pkl" % mid)
pickle.dump(reward, open(filename, "wb"))
class AvgrageMeter(object):
"""
AvgrageMeter for test
"""
def __init__(self):
"""
init
"""
self.reset()
def reset(self):
"""
reset
"""
self.avg = 0
self.sum = 0
self.cnt = 0
def update(self, val, n=1):
"""
update
"""
self.sum += val * n
self.cnt += n
self.avg = self.sum / self.cnt
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
AutoDL main definition
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import traceback
import autodl
if __name__ == "__main__":
try:
autodl_exe = autodl.AutoDL()
autodl_exe.run()
except Exception as e:
print(str(e))
traceback.print_exc()
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PolicyModel definition
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
import paddle.fluid as fluid
from parl.framework.model_base import Model
import distribute_generator
class LstmUnit(object):
"""
implementation of lstm unit
"""
def __init__(self, input_size, hidden_size, num_layers=1,
init_scale=0.1):
"""
init
"""
self.weight_1_arr = []
self.bias_1_arr = []
for i in range(num_layers):
weight_1 = fluid.layers.create_parameter(
[input_size + hidden_size, hidden_size * 4],
dtype="float32",
name="fc_weight1_" + str(i),
default_initializer=fluid.initializer.UniformInitializer(
low=-init_scale,
high=init_scale))
input_size = hidden_size
self.weight_1_arr.append(weight_1)
bias_1 = fluid.layers.create_parameter(
[hidden_size * 4],
dtype="float32",
name="fc_bias1_" + str(i),
default_initializer=fluid.initializer.Constant(0.0))
self.bias_1_arr.append(bias_1)
self.num_layers = num_layers
self.hidden_size = hidden_size
def lstm_step(self, inputs, hidden, cell):
"""
lstm step
"""
hidden_array = []
cell_array = []
for i in range(self.num_layers):
hidden_temp = fluid.layers.slice(hidden, axes=[0], starts=[i],
ends=[i + 1])
hidden_temp = fluid.layers.reshape(hidden_temp,
shape=[-1, self.hidden_size])
hidden_array.append(hidden_temp)
cell_temp = fluid.layers.slice(cell, axes=[0], starts=[i],
ends=[i + 1])
cell_temp = fluid.layers.reshape(cell_temp,
shape=[-1, self.hidden_size])
cell_array.append(cell_temp)
last_hidden_array = []
step_input = inputs
for k in range(self.num_layers):
pre_hidden = hidden_array[k]
pre_cell = cell_array[k]
weight = self.weight_1_arr[k]
bias = self.bias_1_arr[k]
nn = fluid.layers.concat([step_input, pre_hidden], 1)
gate_input = fluid.layers.matmul(x=nn, y=weight)
gate_input = fluid.layers.elementwise_add(gate_input, bias)
i, j, f, o = fluid.layers.split(gate_input, num_or_sections=4,
dim=-1)
c = pre_cell * fluid.layers.sigmoid(f) + fluid.layers.sigmoid(i) \
* fluid.layers.tanh(j)
m = fluid.layers.tanh(c) * fluid.layers.sigmoid(o)
hidden_array[k] = m
cell_array[k] = c
step_input = m
last_hidden = fluid.layers.concat(hidden_array, axis=0)
last_hidden = fluid.layers.reshape(last_hidden, shape=[
self.num_layers, -1, self.hidden_size])
last_cell = fluid.layers.concat(cell_array, axis=0)
last_cell = fluid.layers.reshape(
last_cell,
shape=[self.num_layers, -1, self.hidden_size])
return step_input, last_hidden, last_cell
def __call__(self, inputs, hidden, cell):
"""
lstm step call
"""
return self.lstm_step(inputs, hidden, cell)
class PolicyModel(Model):
"""
PolicyModel
"""
def __init__(self, parser_args):
"""
construct rnn net
"""
self.parser_args = parser_args
def policy(self, inputs):
"""
policy function is used by `define_predict` in PolicyGradient
"""
[tokens, softmax, adjvec, sigmoid] = self.build_rnn(inputs)
return [tokens, softmax, adjvec, sigmoid]
def build_rnn(self, inputs):
"""
build rnn net
"""
batch_size = self.parser_args.batch_size
input_size = self.parser_args.input_size
hidden_size = self.parser_args.hidden_size
num_layers = self.parser_args.num_layers
num_nodes = self.parser_args.num_nodes
num_tokens = self.parser_args.num_tokens
depth = max(num_nodes - 1, num_tokens)
lstm_unit = LstmUnit(input_size, hidden_size, num_layers)
def encode_token(inp):
"""
encode token
"""
token = fluid.layers.assign(inp)
token.stop_gradient = True
token = fluid.layers.one_hot(token, depth)
return token
def encode_adj(adj, step):
"""
encode adj
"""
adj = fluid.layers.cast(adj, dtype='float32')
adj_pad = fluid.layers.pad(x=adj, paddings=[0, 0, 0, depth - step],
pad_value=0.0)
return adj_pad
def decode_token(hidden):
"""
decode token
"""
initializer = fluid.initializer.TruncatedNormalInitializer(
scale=np.sqrt(2.0 / self.parser_args.hidden_size))
param_attr = fluid.ParamAttr(initializer=initializer)
logits = fluid.layers.fc(hidden, num_tokens, param_attr=param_attr)
temper = 5.0
tanh_c = 2.5
logits = fluid.layers.tanh(logits / temper) * tanh_c
token = distribute_generator.multinomial(logits,
[batch_size, 1], 1)
return token, fluid.layers.unsqueeze(logits, axes=[1])
def decode_adj(hidden, step):
"""
decode adj
"""
initializer = fluid.initializer.TruncatedNormalInitializer(
scale=np.sqrt(2.0 / self.parser_args.hidden_size))
param_attr = fluid.ParamAttr(initializer=initializer)
logits = fluid.layers.fc(hidden, step, param_attr=param_attr)
temper = 5.0
tanh_c = 2.5
logits = fluid.layers.tanh(logits / temper) * tanh_c
adj = distribute_generator.bernoulli(logits,
output_shape=logits.shape)
return adj, logits
tokens = []
softmax = []
adjvec = []
sigmoid = []
def rnn_block(hidden, last_hidden, last_cell):
"""
rnn block
"""
last_output, last_hidden, last_cell = lstm_unit(
hidden, last_hidden, last_cell)
token, logits = decode_token(last_output)
tokens.append(token)
softmax.append(logits)
for step in range(1, num_nodes):
token_vec = encode_token(token)
last_output, last_hidden, last_cell = lstm_unit(
token_vec, last_hidden, last_cell)
adj, logits = decode_adj(last_output, step)
adjvec.append(adj)
sigmoid.append(logits)
adj_vec = encode_adj(adj, step)
last_output, last_hidden, last_cell = lstm_unit(
adj_vec, last_hidden, last_cell)
token, logits = decode_token(last_output)
tokens.append(token)
softmax.append(logits)
return token, last_hidden, last_cell
init_hidden = fluid.layers.fill_constant(
shape=[num_layers, batch_size, hidden_size],
value=0.0, dtype='float32')
init_cell = fluid.layers.fill_constant(
shape=[num_layers, batch_size, hidden_size],
value=0.0, dtype='float32')
hidden = encode_adj(inputs, 1)
token, last_hidden, last_cell = rnn_block(hidden, init_hidden,
init_cell)
hidden = encode_token(token)
token, last_hidden, last_cell = rnn_block(hidden, last_hidden,
last_cell)
token_out = fluid.layers.concat(tokens, axis=1)
softmax_out = fluid.layers.concat(softmax, axis=1)
adjvec_out = fluid.layers.concat(adjvec, axis=1)
sigmoid_out = fluid.layers.concat(sigmoid, axis=1)
return [token_out, softmax_out, adjvec_out, sigmoid_out]
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
AutoDL definition
"""
import paddle.fluid as fluid
from parl.framework.algorithm_base import Algorithm
import paddle.fluid.layers as layers
import os
import sys
class ReinforcePolicyGradient(Algorithm):
"""
Implement REINFORCE policy gradient for autoDL
"""
def __init__(self, model, hyperparas):
"""
"""
Algorithm.__init__(self, model, hyperparas)
self.model = model
self.lr = hyperparas['lr']
def define_predict(self, obs):
"""
use policy model self.model to predict the action probability
obs is `inputs`
"""
with fluid.unique_name.guard():
[tokens, softmax, adjvec, sigmoid] = self.model.policy(obs)
return tokens, adjvec
def define_learn(self, obs, action, reward):
"""
update policy model self.model with policy gradient algorithm
obs is `inputs`
"""
tokens = action[0]
adjvec = action[1]
with fluid.unique_name.guard():
[_, softmax, _, sigmoid] = self.model.policy(obs)
reshape_softmax = layers.reshape(
softmax,
[-1, self.model.parser_args.num_tokens])
reshape_tokens = layers.reshape(tokens, [-1, 1])
reshape_tokens.stop_gradient = True
raw_neglogp_to = layers.softmax_with_cross_entropy(
soft_label=False,
logits=reshape_softmax,
label=fluid.layers.cast(x=reshape_tokens, dtype="int64"))
action_to_shape_sec = self.model.parser_args.num_nodes * 2
neglogp_to = layers.reshape(fluid.layers.cast(
raw_neglogp_to, dtype="float32"),
[-1, action_to_shape_sec])
adjvec = layers.cast(x=adjvec, dtype='float32')
neglogp_ad = layers.sigmoid_cross_entropy_with_logits(
x=sigmoid, label=adjvec)
neglogp = layers.elementwise_add(
x=layers.reduce_sum(neglogp_to, dim=1),
y=layers.reduce_sum(neglogp_ad, dim=1))
reward = layers.cast(reward, dtype="float32")
cost = layers.reduce_mean(
fluid.layers.elementwise_mul(x=neglogp, y=reward))
optimizer = fluid.optimizer.Adam(learning_rate=self.lr)
train_op = optimizer.minimize(cost)
return cost
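For reference, the cost computed in `define_learn` above is the standard REINFORCE objective: the summed negative log-probabilities of the sampled tokens and adjacency entries, scaled by the reward and averaged over the batch. A rough NumPy sketch of the same quantity (illustrative only, with assumed shapes; not part of the repository):
```
import numpy as np

def reinforce_cost(token_logits, tokens, adj_logits, adj, reward):
    """Illustrative REINFORCE cost: mean over the batch of
    (sum of negative log-probs of the sampled actions) * reward."""
    # token_logits: [batch, steps, num_tokens]; tokens: [batch, steps] (int)
    probs = np.exp(token_logits - token_logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    neglogp_to = -np.log(
        np.take_along_axis(probs, tokens[..., None], axis=-1)).squeeze(-1)
    # adj_logits, adj: [batch, num_edges]; Bernoulli negative log-likelihood
    p_edge = 1.0 / (1.0 + np.exp(-adj_logits))
    neglogp_ad = -(adj * np.log(p_edge) + (1 - adj) * np.log(1 - p_edge))
    neglogp = neglogp_to.sum(axis=1) + neglogp_ad.sum(axis=1)
    return np.mean(neglogp * reward)
```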
export FLAGS_fraction_of_gpu_memory_to_use=0.98
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fast_eager_deletion_mode=1
CUDA_VISIBLE_DEVICES=0 python -u main.py > main.log 2>&1 &
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
AutoDL main definition
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import traceback
import autodl
if __name__ == "__main__":
try:
autodl_exe = autodl.AutoDL()
autodl_exe.simple_run()
except Exception as e:
print(str(e))
traceback.print_exc()
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
AutoDL definition
"""
import os
import time
import pickle
def stime():
"""
stime
"""
return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
def prepare(log_dir, category=""):
"""
prepare directory
"""
subdir = os.path.join(log_dir, category)
if not os.path.exists(subdir):
os.mkdir(subdir)
def dump_action(mid, action, log_dir):
"""
dump action
"""
filename = os.path.join(log_dir, "actions", "%d.pkl" % mid)
pickle.dump(action, open(filename, "wb"))
def load_action(mid, log_dir):
"""
load action
"""
filename = os.path.join(log_dir, "actions", "%d.pkl" % mid)
return pickle.load(open(filename, "rb"))
def dump_reward(mid, reward, log_dir):
"""
dump reward
"""
filename = os.path.join(log_dir, "rewards", "%d.pkl" % mid)
pickle.dump(reward, open(filename, "wb"))
def load_reward(mid, log_dir):
"""
load reward
"""
filename = os.path.join(log_dir, "rewards", "%d.pkl" % mid)
return pickle.load(open(filename, "rb"))
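For reference, a minimal usage sketch of the helpers above (the directory name is arbitrary and not part of the repository):
```
log_dir = "./logs"                    # arbitrary example directory
if not os.path.exists(log_dir):
    os.mkdir(log_dir)
prepare(log_dir, "actions")           # creates ./logs/actions
prepare(log_dir, "rewards")           # creates ./logs/rewards
dump_action(0, {"tokens": [1, 2, 3]}, log_dir)  # writes ./logs/actions/0.pkl
dump_reward(0, 0.73, log_dir)                   # writes ./logs/rewards/0.pkl
assert load_reward(0, log_dir) == 0.73
```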
# LRC Local Rademacher Complexity Regularization
Regularization of Deep Neural Networks (DNNs) for the sake of improving their generalization capability is important and challenging. This directory contains an image classification model based on a novel regularizer rooted in Local Rademacher Complexity (LRC). We appreciate the contribution of [DARTS](https://arxiv.org/abs/1806.09055) to our research. LRC regularization and DARTS are combined in this model to reach an accuracy of 98.01% on the CIFAR-10 dataset. Code accompanying the paper
> [An Empirical Study on Regularization of Deep Neural Networks by Local Rademacher Complexity](https://arxiv.org/abs/1902.00873)\
> Yingzhen Yang, Xingjian Li, Jun Huan.\
> _arXiv:1902.00873_.
......@@ -7,13 +7,21 @@ Regularization of Deep Neural Networks(DNNs) for the sake of improving their gen
---
# Table of Contents
- [Introduction of algorithm](#introduction-of-algorithm)
- [Installation](#installation)
- [Data preparation](#data-preparation)
- [Training](#training)
- [Testing](#testing)
- [Experimental result](#experimental-result)
- [Reference](#reference)
## Introduction of algorithm
Rademacher complexity is well known as a distribution-free complexity measure of a function class, and LRC focuses on a restricted function class, which leads to sharper convergence rates and potentially better generalization. Our LRC-based regularizer is developed by estimating the complexity of the function class centered at the minimizer of the empirical loss of DNNs.
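In code, the regularizer enters as an extra penalty added to the empirical training loss, weighted by the `lrc_loss_lambda` flag exposed by the training script. A schematic sketch, assuming `mixup_loss` and `lrc_loss` are the scalar losses built in `model.py` (not a verbatim excerpt of the repository):
```
# Schematic only: combine the empirical (mixup) loss with the LRC penalty.
# lrc_loss_lambda is the command-line weight (0.7 in the model configuration below).
total_loss = mixup_loss + lrc_loss_lambda * lrc_loss
```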
## Installation
Running the sample code in this directory requires PaddlePaddle Fluid v1.3.0 or later. If the PaddlePaddle on your device is lower than this version, please follow the instructions in the [installation document](http://www.paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/install/index_cn.html#paddlepaddle) and make an update.
## Data preparation
......@@ -30,13 +38,8 @@ The dataset will be downloaded to `dataset/cifar/cifar-10-batches-py` in the sam
After data preparation, one can start the training step by:
sh run_cifar.sh
- Set ```export CUDA_VISIBLE_DEVICES=0``` in ```run_cifar.sh``` to specify the GPU used for training.
- For more help on arguments:
......@@ -44,7 +47,7 @@ After data preparation, one can start the training step by:
**data reader introduction:**
* Data reader is defined in `reader_cifar.py`.
* Reshape the images to 32 * 32.
* In the training stage, images are padded to 40 * 40 and randomly cropped back to the original size.
* In the training stage, images are randomly flipped horizontally.
......@@ -54,19 +57,40 @@ After data preparation, one can start the training step by:
**model configuration:**
* Use auxiliary loss and auxiliary\_weight=0.4.
* Use dropout and drop\_path\_prob=0.2.
* Set lrc\_loss\_lambda=0.7.
**training strategy:**
* Use momentum optimizer with momentum=0.9.
* Weight decay is 0.0003.
* Use cosine decay with init\_lr=0.025.
* Total epoch is 600.
* Use the Xavier initializer for conv2d weights, the Constant initializer for batch norm weights and the Normal initializer for fc weights.
* Initialize biases in batch norm and fc layers to zero and do not add a bias to conv2d.
* Use global L2 norm to clip gradient.
* Other configurations are set in `run_cifar.sh`
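The cosine learning-rate schedule referenced in the training strategy above follows the formula implemented in `learning_rate.py`. A plain-Python sketch of that formula (illustrative, not the Fluid implementation):
```
import math

def cosine_decay_lr(init_lr, epoch, num_epoch):
    """Learning rate falls from init_lr to 0 over half a cosine period."""
    return init_lr * (math.cos(epoch * math.pi / num_epoch) + 1) / 2

# e.g. init_lr=0.025 and 600 epochs, as listed above
print(cosine_decay_lr(0.025, 0, 600))    # 0.025 at the start
print(cosine_decay_lr(0.025, 300, 600))  # 0.0125 halfway through training
```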
## Testing
One can start the testing step by:
sh run_cifar_test.sh
- Set ```export CUDA_VISIBLE_DEVICES=0``` in ```run_cifar_test.sh``` to specify the GPU used for testing.
- For more help on arguments:
python test_mixup.py --help
After obtaining the six models, one can get the ensembled model by:
python voting.py
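`voting.py` accumulates each model's predicted class probabilities (optionally with per-model weights) and takes the argmax over the summed probabilities. A simplified, unweighted sketch using the prediction files written by `run_cifar_test.sh` (illustrative only):
```
import numpy as np
import pickle

prob_files = ["paddle_predict/prob_test_%d.pkl" % i for i in range(6)]
pred = np.zeros((10000, 10))            # CIFAR-10 test set: 10000 images, 10 classes
for f in prob_files:
    pred += pickle.load(open(f, "rb"))  # accumulate per-model probabilities
labels = np.load("labels.npz")["arr_0"]
print((np.argmax(pred, axis=1) == labels).mean())  # ensembled acc-1
```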
## Experimental result
Experimental results are shown below:
| Model | base lr | batch size | model id | acc-1 |
| :--------------- | :--------: | :------------: | :------------------: |------: |
| [model_0](https://paddlemodels.bj.bcebos.com/autodl/lrc_model_0.tar.gz) | 0.01 | 64 | 0 | 97.12% |
| [model_1](https://paddlemodels.bj.bcebos.com/autodl/lrc_model_1.tar.gz) | 0.02 | 80 | 0 | 97.34% |
| [model_2](https://paddlemodels.bj.bcebos.com/autodl/lrc_model_2.tar.gz) | 0.015 | 80 | 1 | 97.31% |
| [model_3](https://paddlemodels.bj.bcebos.com/autodl/lrc_model_3.tar.gz) | 0.02 | 80 | 1 | 97.52% |
| [model_4](https://paddlemodels.bj.bcebos.com/autodl/lrc_model_4.tar.gz) | 0.03 | 80 | 1 | 97.30% |
| [model_5](https://paddlemodels.bj.bcebos.com/autodl/lrc_model_5.tar.gz) | 0.015 | 64 | 2 | 97.32% |
ensembled model acc-1=98.01%
## Reference
......
# LRC: Local Rademacher Complexity Regularization
Choosing a good regularizer to improve generalization in deep neural networks is important and challenging. This directory contains an image classification model with a novel regularizer based on Local Rademacher Complexity (LRC). We are grateful to [DARTS](https://arxiv.org/abs/1806.09055) for its contribution to this research. The model combines LRC regularization with the DARTS network and reaches 98.01% accuracy on the CIFAR-10 dataset. Code accompanying the paper
> [An Empirical Study on Regularization of Deep Neural Networks by Local Rademacher Complexity](https://arxiv.org/abs/1902.00873)\
> Yingzhen Yang, Xingjian Li, Jun Huan.\
> _arXiv:1902.00873_.
......@@ -7,13 +7,21 @@
---
# Table of Contents
- [Introduction of algorithm](#introduction-of-algorithm)
- [Installation](#installation)
- [Data preparation](#data-preparation)
- [Training](#training)
- [Testing](#testing)
- [Experimental result](#experimental-result)
- [Reference](#reference)
## Introduction of algorithm
Following existing work on local Rademacher complexity, our method only considers the Rademacher complexity within a ball around the minimizer of the empirical loss. Using recent estimation techniques for Rademacher complexity, we derive a closed-form expression of this quantity for the hinge loss and the cross-entropy loss, name it the local Rademacher regularization term, and add it to the empirical loss. Applying this regularization on top of mixup and model ensembling yields the best accuracy obtained so far on CIFAR-10.
## Installation
Running the sample code in this directory requires PaddlePaddle Fluid v1.3.0 or later. If the PaddlePaddle in your environment is lower than this version, please follow the instructions in the [installation document](http://www.paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/install/index_cn.html#paddlepaddle) to update PaddlePaddle.
## Data preparation
......@@ -21,27 +29,22 @@
sh ./dataset/download.sh
Please make sure your environment has internet access. The dataset will be downloaded to `dataset/cifar/cifar-10-batches-py` in the same directory as `train.py`. If the download fails, you can manually download cifar-10-python.tar.gz from https://www.cs.toronto.edu/~kriz/cifar.html and extract it to the location above.
## Training
After data preparation, training can be started with:
sh run_cifar.sh
- Set ```export CUDA_VISIBLE_DEVICES=0``` in ```run_cifar.sh``` to specify the GPU used for training.
- For more help on arguments:
python train_mixup.py --help
**Data reader introduction:**
* The data reader is defined in `reader_cifar.py`.
* Input images are resized to 32 * 32.
* During training, images are padded to 40 * 40 and randomly cropped back to the original size.
* During training, images are randomly flipped horizontally.
......@@ -51,19 +54,41 @@
**Model configuration:**
* Use auxiliary loss with a weight of 0.4.
* Use dropout with a drop rate of 0.2.
* Set lrc\_loss\_lambda to 0.7.
**Training strategy:**
* Train with the momentum optimizer, momentum=0.9.
* Weight decay is 0.0001.
* Use cosine learning rate decay with an initial learning rate of 0.025.
* Train for 600 epochs in total.
* Use Xavier initialization for conv2d weights, constant initialization for batch norm weights and Gaussian initialization for fc weights.
* Initialize biases in batch norm and fc layers to a constant and do not add a bias to conv2d.
* Clip gradients by their global L2 norm.
* Other configurations are set in run_cifar.sh.
## Testing
Testing can be started with:
sh run_cifar_test.sh
- Set ```export CUDA_VISIBLE_DEVICES=0``` in ```run_cifar_test.sh``` to specify the GPU used for testing.
- For more help on arguments:
python test_mixup.py --help
After obtaining the six models, run the following script to get the ensembled model:
python voting.py
## Experimental result
Experimental results are shown below:
| Model | base lr | batch size | model id | acc-1 |
| :--------------- | :--------: | :------------: | :------------------: |------: |
| [model_0](https://paddlemodels.bj.bcebos.com/autodl/lrc_model_0.tar.gz) | 0.01 | 64 | 0 | 97.12% |
| [model_1](https://paddlemodels.bj.bcebos.com/autodl/lrc_model_1.tar.gz) | 0.02 | 80 | 0 | 97.34% |
| [model_2](https://paddlemodels.bj.bcebos.com/autodl/lrc_model_2.tar.gz) | 0.015 | 80 | 1 | 97.31% |
| [model_3](https://paddlemodels.bj.bcebos.com/autodl/lrc_model_3.tar.gz) | 0.02 | 80 | 1 | 97.52% |
| [model_4](https://paddlemodels.bj.bcebos.com/autodl/lrc_model_4.tar.gz) | 0.03 | 80 | 1 | 97.30% |
| [model_5](https://paddlemodels.bj.bcebos.com/autodl/lrc_model_5.tar.gz) | 0.015 | 64 | 2 | 97.32% |
Ensembled model acc-1 = 98.01%
## Reference
......
......@@ -114,9 +114,33 @@ MY_DARTS = Genotype(
reduce_concat=range(2, 6))
MY_DARTS_list = [
Genotype(
normal=[('sep_conv_3x3', 0), ('skip_connect', 1), ('sep_conv_3x3', 0),
('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
('skip_connect', 0), ('sep_conv_3x3', 2)],
normal_concat=range(2, 6),
reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2),
('max_pool_3x3', 0), ('skip_connect', 3), ('avg_pool_3x3', 1),
('skip_connect', 2), ('skip_connect', 3)],
reduce_concat=range(2, 6)),
Genotype(
normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('skip_connect', 0),
('dil_conv_3x3', 2), ('skip_connect', 0), ('sep_conv_3x3', 1),
('skip_connect', 0), ('skip_connect', 1)],
normal_concat=range(2, 6),
reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2),
('dil_conv_3x3', 0), ('skip_connect', 3), ('skip_connect', 2),
('skip_connect', 3), ('skip_connect', 2)],
reduce_concat=range(2, 6)),
Genotype(
normal=[('sep_conv_3x3', 0), ('skip_connect', 1), ('skip_connect', 0),
('dil_conv_5x5', 1), ('skip_connect', 0), ('sep_conv_3x3', 1),
('skip_connect', 0), ('sep_conv_3x3', 1)],
normal_concat=range(2, 6),
reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('max_pool_3x3', 0),
('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2),
('skip_connect', 2), ('skip_connect', 3)],
reduce_concat=range(2, 6))
]
DARTS = MY_DARTS_list[0]
......@@ -42,12 +42,12 @@ def cosine_decay(learning_rate, num_epoch, steps_one_epoch):
* math.pi / num_epoch) + 1)/2
return decayed_lr
def cosine_with_warmup_decay(learning_rate, lr_min, steps_one_epoch,
warmup_epochs, total_epoch, num_gpu):
global_step = _decay_step_counter()
epoch_idx = fluid.layers.floor(global_step / steps_one_epoch)
lr = fluid.layers.create_global_var(
shape=[1],
value=0.0,
......@@ -59,16 +59,17 @@ def cosine_with_warmup_decay(learning_rate, lr_min, steps_one_epoch,
shape=[1], dtype='float32', value=float(warmup_epochs), force_cpu=True)
num_gpu_var = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=float(num_gpu), force_cpu=True)
batch_idx = global_step - steps_one_epoch * epoch_idx
with fluid.layers.control_flow.Switch() as switch:
with switch.case(epoch_idx < warmup_epoch_var):
epoch_ = (batch_idx + 1) / steps_one_epoch
factor = 1 / num_gpu_var * (
epoch_ * (num_gpu_var - 1) / warmup_epoch_var + 1)
decayed_lr = learning_rate * factor * num_gpu_var
fluid.layers.assign(decayed_lr, lr)
epoch_ = (batch_idx + 1) / steps_one_epoch
m = epoch_ / total_epoch
frac = (1 + ops.cos(math.pi * m)) / 2
cosine_lr = (lr_min + (learning_rate - lr_min) * frac) * num_gpu_var
with switch.default():
......
......@@ -226,7 +226,6 @@ class NetworkCIFAR(object):
name='test_reader')
return py_reader
def forward(self, init_channel, is_train):
self.training = is_train
self.logits_aux = None
......@@ -246,7 +245,7 @@ class NetworkCIFAR(object):
initializer=Normal(scale=1e-3),
name='classifier.weight'),
bias_attr=ParamAttr(
initializer=Constant(0),
name='classifier.bias'))
return self.logits, self.logits_aux
......@@ -255,8 +254,6 @@ class NetworkCIFAR(object):
self.non_label_reshape, self.rad_var = fluid.layers.read_file(py_reader)
self.logits, self.logits_aux = self.forward(init_channels, True)
self.mixup_loss = self.mixup_loss(aux, aux_w)
return self.mixup_loss
def test_model(self, py_reader, init_channels):
......@@ -301,8 +298,7 @@ class NetworkCIFAR(object):
y_diff_label = fluid.layers.reshape(
y_diff_label_reshape, shape=(1, -1, 1))
y_diff_non_label = fluid.layers.reshape(
y_diff_non_label_reshape, shape=(1, -1, self.class_num - 1))
y_diff_ = y_diff_non_label - y_diff_label
y_diff_ = fluid.layers.transpose(y_diff_, perm=[1, 2, 0])
......@@ -318,6 +314,7 @@ class NetworkCIFAR(object):
return lrc_loss_mean
def AuxiliaryHeadImageNet(input, num_classes, aux_name='auxiliary_head'):
relu_a = fluid.layers.relu(input)
pool_a = fluid.layers.pool2d(relu_a, 5, 'avg', pool_stride=3)
......@@ -376,7 +373,7 @@ def Stem0Conv(input, C_out):
initializer=Xavier(
uniform=False, fan_in=0), name='stem0.0.weight'),
bias_attr=False)
relu_a = fluid.layers.batch_norm(
conv_a,
param_attr=ParamAttr(
initializer=Constant(1.), name='stem0.1.weight'),
......@@ -385,9 +382,8 @@ def Stem0Conv(input, C_out):
moving_mean_name='stem0.1.running_mean',
moving_variance_name='stem0.1.running_var',
act='relu')
conv_b = fluid.layers.conv2d(
relu_a,
C_out,
3,
stride=2,
......@@ -407,6 +403,7 @@ def Stem0Conv(input, C_out):
return bn_b
def Stem1Conv(input, C_out):
relu_a = fluid.layers.relu(input)
conv_a = fluid.layers.conv2d(
......@@ -429,6 +426,7 @@ def Stem1Conv(input, C_out):
moving_variance_name='stem1.2.running_var')
return bn_a
class NetworkImageNet(object):
def __init__(self, C, class_num, layers, genotype):
self.class_num = class_num
......@@ -461,8 +459,7 @@ class NetworkImageNet(object):
capacity=64,
shapes=[[-1] + image_shape, [-1, 1]],
lod_levels=[0, 0],
dtypes=["float32", "int64"],
use_double_buffer=True,
name='train_reader')
else:
......@@ -475,7 +472,6 @@ class NetworkImageNet(object):
name='test_reader')
return py_reader
def forward(self, is_train):
self.training = is_train
self.logits_aux = None
......@@ -495,7 +491,7 @@ class NetworkImageNet(object):
initializer=Normal(scale=1e-3),
name='classifier.weight'),
bias_attr=ParamAttr(
initializer=Constant(0),
name='classifier.bias'))
return self.logits, self.logits_aux
......@@ -504,10 +500,6 @@ class NetworkImageNet(object):
loss = fluid.layers.cross_entropy(prob, self.label)
loss_mean = fluid.layers.reduce_mean(loss)
prob_aux = fluid.layers.softmax(self.logits_aux, use_cudnn=False)
loss_aux = fluid.layers.cross_entropy(prob_aux, self.label)
loss_aux_mean = fluid.layers.reduce_mean(loss_aux)
......@@ -527,4 +519,3 @@ class NetworkImageNet(object):
acc_1 = fluid.layers.accuracy(self.logits, self.label, k=1)
acc_5 = fluid.layers.accuracy(self.logits, self.label, k=5)
return prob, acc_1, acc_5
......@@ -312,7 +312,8 @@ def FactorizedReduce(input, C_out, name='', affine=True):
bias_attr=False)
h_end = relu_a.shape[2]
w_end = relu_a.shape[3]
slice_a = fluid.layers.slice(
input=relu_a, axes=[2, 3], starts=[1, 1], ends=[h_end, w_end])
conv2d_b = fluid.layers.conv2d(
slice_a,
C_out // 2,
......
......@@ -126,7 +126,10 @@ def reader_creator_filepath(filename, sub_name, is_training, args):
datasets = []
for name in names:
print("Reading file " + name)
try:
batch = pickle.load(open(filename + name, 'rb'), encoding='latin1')
except TypeError:
batch = pickle.load(open(filename + name, 'rb'))
data = batch['data']
labels = batch.get('labels', batch.get('fine_labels', None))
assert labels is not None
......@@ -177,8 +180,7 @@ def reader_creator_filepath(filename, sub_name, is_training, args):
generate_reshape_label(batch_label, len(batch_data))
rad_var = generate_bernoulli_number(len(batch_data))
mixed_x, y_a, y_b, lam = utils.mixup_data(
batch_data, batch_label, len(batch_data), args.mix_alpha)
batch_out = [[mixed_x, y_a, y_b, lam, flatten_label, \
flatten_non_label, rad_var]]
yield batch_out
......
export FLAGS_fraction_of_gpu_memory_to_use=0.9
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fast_eager_deletion_mode=1
nohup env CUDA_VISIBLE_DEVICES=0 python -u train_mixup.py --batch_size=64 --auxiliary --mix_alpha=0.9 --model_id=0 --cutout --lrc_loss_lambda=0.5 --weight_decay=0.0002 --learning_rate=0.01 --save_model_path=model_0 > lrc_model_0.log 2>&1 &
nohup env CUDA_VISIBLE_DEVICES=1 python -u train_mixup.py --batch_size=64 --auxiliary --mix_alpha=0.6 --model_id=0 --cutout --lrc_loss_lambda=0.5 --weight_decay=0.0002 --learning_rate=0.02 --save_model_path=model_1 > lrc_model_1.log 2>&1 &
nohup env CUDA_VISIBLE_DEVICES=2 python -u train_mixup.py --batch_size=80 --auxiliary --mix_alpha=0.5 --model_id=1 --cutout --lrc_loss_lambda=0.5 --weight_decay=0.0002 --learning_rate=0.015 --save_model_path=model_2 > lrc_model_2.log 2>&1 &
nohup env CUDA_VISIBLE_DEVICES=3 python -u train_mixup.py --batch_size=80 --auxiliary --mix_alpha=0.6 --model_id=1 --cutout --lrc_loss_lambda=0.5 --weight_decay=0.0002 --learning_rate=0.02 --save_model_path=model_3 > lrc_model_3.log 2>&1 &
nohup env CUDA_VISIBLE_DEVICES=4 python -u train_mixup.py --batch_size=80 --auxiliary --mix_alpha=0.8 --model_id=1 --cutout --lrc_loss_lambda=0.5 --weight_decay=0.0002 --learning_rate=0.03 --save_model_path=model_4 > lrc_model_4.log 2>&1 &
nohup env CUDA_VISIBLE_DEVICES=5 python -u train_mixup.py --batch_size=64 --auxiliary --mix_alpha=0.5 --model_id=2 --cutout --lrc_loss_lambda=0.5 --weight_decay=0.0002 --learning_rate=0.015 --save_model_path=model_5 > lrc_model_5.log 2>&1 &
export FLAGS_fraction_of_gpu_memory_to_use=0.6
nohup env CUDA_VISIBLE_DEVICES=0 python -u test_mixup.py --batch_size=64 --auxiliary --model_id=0 --pretrained_model=model_0/final/ --dump_path=paddle_predict/prob_test_0.pkl > lrc_test_0.log 2>&1 &
nohup env CUDA_VISIBLE_DEVICES=1 python -u test_mixup.py --batch_size=64 --auxiliary --model_id=0 --pretrained_model=model_1/final/ --dump_path=paddle_predict/prob_test_1.pkl > lrc_test_1.log 2>&1 &
nohup env CUDA_VISIBLE_DEVICES=2 python -u test_mixup.py --batch_size=80 --auxiliary --model_id=1 --pretrained_model=model_2/final/ --dump_path=paddle_predict/prob_test_2.pkl > lrc_test_2.log 2>&1 &
nohup env CUDA_VISIBLE_DEVICES=3 python -u test_mixup.py --batch_size=80 --auxiliary --model_id=1 --pretrained_model=model_3/final/ --dump_path=paddle_predict/prob_test_3.pkl > lrc_test_3.log 2>&1 &
nohup env CUDA_VISIBLE_DEVICES=4 python -u test_mixup.py --batch_size=80 --auxiliary --model_id=1 --pretrained_model=model_4/final/ --dump_path=paddle_predict/prob_test_4.pkl > lrc_test_4.log 2>&1 &
nohup env CUDA_VISIBLE_DEVICES=5 python -u test_mixup.py --batch_size=64 --auxiliary --model_id=2 --pretrained_model=model_5/final/ --dump_path=paddle_predict/prob_test_5.pkl > lrc_test_5.log 2>&1 &
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
#
# Based on:
# --------------------------------------------------------
# DARTS
# Copyright (c) 2018, Hanxiao Liu.
# Licensed under the Apache License, Version 2.0;
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from learning_rate import cosine_decay
import numpy as np
import argparse
from model import NetworkCIFAR as Network
import reader_cifar as reader
import sys
import os
import time
import logging
import genotypes
import paddle.fluid as fluid
import shutil
import utils
parser = argparse.ArgumentParser("cifar")
# yapf: disable
parser.add_argument('--data', type=str, default='./dataset/cifar/cifar-10-batches-py/', help='location of the data corpus')
parser.add_argument('--batch_size', type=int, default=96, help='batch size')
parser.add_argument('--model_id', type=int, help='model id')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
parser.add_argument('--pretrained_model', type=str, default='/model_0/final/', help='pretrained model to load')
parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use')
parser.add_argument('--dump_path', type=str, default='prob_test_0.pkl', help='dump path')
# yapf: enable
args = parser.parse_args()
CIFAR_CLASSES = 10
dataset_train_size = 50000
image_size = 32
genotypes.DARTS = genotypes.MY_DARTS_list[args.model_id]
print(genotypes.DARTS)
def main():
image_shape = [3, image_size, image_size]
devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
devices_num = len(devices.split(","))
logging.info("args = %s", args)
genotype = eval("genotypes.%s" % args.arch)
model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
args.auxiliary, genotype)
test(model, args, image_shape)
def build_program(args, is_train, model, im_shape):
out = []
py_reader = model.build_input(im_shape, is_train)
prob, acc_1, acc_5 = model.test_model(py_reader, args.init_channels)
out = [py_reader, prob, acc_1, acc_5]
return out
def test(model, args, im_shape):
test_py_reader, prob, acc_1, acc_5 = build_program(args, False, model,
im_shape)
test_prog = fluid.default_main_program().clone(for_test=True)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# yapf: disable
if args.pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(args.pretrained_model, var.name))
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
# yapf: enable
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 1
compile_program = fluid.compiler.CompiledProgram(
test_prog).with_data_parallel(exec_strategy=exec_strategy)
test_reader = reader.test10(args)
test_py_reader.decorate_paddle_reader(test_reader)
test_fetch_list = [prob, acc_1, acc_5]
prob = []
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()
test_py_reader.start()
test_start_time = time.time()
step_id = 0
try:
while True:
prev_test_start_time = test_start_time
test_start_time = time.time()
prob_v, acc_1_v, acc_5_v = exe.run(compile_program, fetch_list=test_fetch_list)
prob.append(list(np.array(prob_v)))
top1.update(np.array(acc_1_v), np.array(prob_v).shape[0])
top5.update(np.array(acc_5_v), np.array(prob_v).shape[0])
if step_id % args.report_freq == 0:
print('prob shape:', np.array(prob_v).shape)
print("Step {}, acc_1 {}, acc_5 {}, time {}".format(
step_id,
np.array(acc_1_v),
np.array(acc_5_v), test_start_time - prev_test_start_time))
step_id += 1
except fluid.core.EOFException:
test_py_reader.reset()
np.concatenate(prob).dump(args.dump_path)
print("top1 {0}, top5 {1}".format(top1.avg, top5.avg))
if __name__ == '__main__':
main()
......@@ -38,66 +38,30 @@ import utils
import math
parser = argparse.ArgumentParser("cifar")
# yapf: disable
parser.add_argument('--data', type=str, default='./dataset/cifar/cifar-10-batches-py/', help='location of the data corpus')
parser.add_argument('--batch_size', type=int, default=96, help='batch size')
parser.add_argument('--pretrained_model', type=str, default=None, help='pretrained model to load')
parser.add_argument('--model_id', type=int, help='model id')
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--epochs', type=int, default=600, help='num of training epochs')
parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
parser.add_argument('--save_model_path', type=str, default='saved_models', help='path to save the model')
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--lr_exp_decay', action='store_true', default=False, help='use exponential_decay learning_rate')
parser.add_argument('--mix_alpha', type=float, default=0.5, help='mixup alpha')
parser.add_argument('--lrc_loss_lambda', default=0, type=float, help='lrc_loss_lambda')
# yapf: enable
args = parser.parse_args()
......@@ -130,11 +94,10 @@ def build_program(main_prog, startup_prog, args, is_train, model, im_shape,
args.auxiliary, args.auxiliary_weight,
args.lrc_loss_lambda)
optimizer = fluid.optimizer.Momentum(
learning_rate=cosine_decay(args.learning_rate, args.epochs,
steps_one_epoch),
regularization=fluid.regularizer.L2Decay(args.weight_decay),
momentum=args.momentum)
optimizer.minimize(loss)
out = [py_reader, loss]
else:
......@@ -146,25 +109,32 @@ def build_program(main_prog, startup_prog, args, is_train, model, im_shape,
def train(model, args, im_shape, steps_one_epoch):
startup_prog = fluid.Program()
train_prog = fluid.Program()
test_prog = fluid.Program()
train_py_reader, loss_train = build_program(
train_prog, startup_prog, args, True, model, im_shape, steps_one_epoch)
test_py_reader, prob, acc_1, acc_5 = build_program(
test_prog, startup_prog, args, False, model, im_shape, steps_one_epoch)
test_prog = test_prog.clone(for_test=True)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_prog)
if args.pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(args.pretrained_model, var.name))
fluid.io.load_vars(
exe,
args.pretrained_model,
main_program=train_prog,
predicate=if_exist)
......@@ -175,21 +145,24 @@ def train(model, args, im_shape, steps_one_epoch):
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 1
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = True
compile_program = fluid.compiler.CompiledProgram(
train_prog).with_data_parallel(
loss_name=loss_train.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
train_reader = reader.train10(args)
test_reader = reader.test10(args)
train_py_reader.decorate_paddle_reader(train_reader)
test_py_reader.decorate_paddle_reader(test_reader)
fluid.clip.set_gradient_clip(
fluid.clip.GradientClipByGlobalNorm(args.grad_clip), program=train_prog)
train_fetch_list = [loss_train]
def save_model(postfix, main_prog):
model_path = os.path.join(args.save_model_path, postfix)
......@@ -199,8 +172,6 @@ def train(model, args, im_shape, steps_one_epoch):
def test(epoch_id):
test_fetch_list = [prob, acc_1, acc_5]
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()
test_py_reader.start()
......@@ -210,8 +181,8 @@ def train(model, args, im_shape, steps_one_epoch):
while True:
prev_test_start_time = test_start_time
test_start_time = time.time()
prob_v, acc_1_v, acc_5_v = exe.run(test_prog,
fetch_list=test_fetch_list)
top1.update(np.array(acc_1_v), np.array(prob_v).shape[0])
top5.update(np.array(acc_5_v), np.array(prob_v).shape[0])
if step_id % args.report_freq == 0:
......@@ -242,7 +213,8 @@ def train(model, args, im_shape, steps_one_epoch):
while True:
prev_start_time = start_time
start_time = time.time()
loss_v, = exe.run(
compile_program,
fetch_list=[v.name for v in train_fetch_list])
print("Epoch {}, Step {}, loss {}, time {}".format(epoch_id, step_id, \
np.array(loss_v).mean(), start_time-prev_start_time))
......@@ -250,8 +222,10 @@ def train(model, args, im_shape, steps_one_epoch):
sys.stdout.flush()
except fluid.core.EOFException:
train_py_reader.reset()
if epoch_id % 50 == 0:
save_model(str(epoch_id), train_prog)
if epoch_id == args.epochs - 1:
save_model('final', train_prog)
test(epoch_id)
......
import numpy as np
try:
import cPickle as pickle
except ImportError:
import pickle
import sys, os
model_path = 'paddle_predict'
fl = os.listdir(model_path)
labels = np.load('labels.npz')['arr_0']
pred = np.zeros((10000, 10))
fl.sort()
i = 0
weight=1
for f in fl:
if 'init' in f:
continue
print(f)
if i == 1: weight=1.2
if i == 2: weight=0.8
if i == 3: weight=1.3
if i == 4: weight=1.1
if i == 5: weight=0.9
pred += pickle.load(open(os.path.join(model_path, f), 'rb'))
print(np.mean(np.argmax(pred, axis=1) == labels))
i += 1
# Introduction to AutoDL Design
## Content
- [Installation](#Installation)
- [Introduction](#Introduction)
- [Data Preparation](#Data-Preparation)
- [Model Training](#Model-Training)
## Installation
Running the demo code in the current directory requires PaddlePaddle Fluid v1.3.0 or above. If your runtime environment does not meet this requirement, please update PaddlePaddle according to the installation documents.
* Install Python2.7
* Install the dependencies, the [PARL](https://github.com/PaddlePaddle/PARL) framework and the [absl-py](https://github.com/abseil/abseil-py/tree/master/absl) library, as follows:
```
pip install parl
pip install absl-py
```
## Introduction
[AutoDL](http://www.paddlepaddle.org/paddle/ModelAutoDL) is an efficient automated neural architecture design method. It designs high-quality customized neural architectures via reinforcement learning. The system consists of two components: an encoder of the neural architecture, and a critic of the model performance. The encoder encodes a neural architecture using a recurrent neural network, and the critic evaluates the sampled architecture in terms of accuracy, number of model parameters, etc., which are fed back to the encoder. The encoder updates its parameters accordingly and samples a new batch of architectures. After several iterations, the encoder is trained to converge and finds a quality architecture. The open-sourced AutoDL Design is one implementation of the AutoDL technique. Section 2 presents the usage of AutoDL Design. Section 3 presents the framework and examples.
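Schematically, the loop described above can be summarized as follows (pseudocode only, with hypothetical names; it is not the repository's actual API):
```
# Pseudocode of the encoder-critic search loop (hypothetical names).
for step in range(num_iterations):
    tokens, adj = encoder.sample()          # RNN encoder proposes an architecture
    reward = critic.evaluate(tokens, adj)   # build and train the CNN, measure accuracy
    encoder.update(tokens, adj, reward)     # policy-gradient (REINFORCE) update
```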
## Data Preparation
* Clone [PaddlePaddle/AutoDL](https://github.com/PaddlePaddle/AutoDL.git) to your local machine, and enter the AutoDL Design directory.
* Download the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz) training data, unzip it to AutoDL Design/cifar, and generate a small pickle training set of 10 classes with 100 images per class using `dataset_maker.py`:
```
tar zxf cifar-10-python.tar.gz
python dataset_maker.py
```
## Model Training
In the training process, the AutoDL Design agent generates tokens and adjacency matrices used for training, and the trainer uses these tokens and matrices to construct and train convolutional neural networks. The validation accuracy after 20 epochs is used as feedback for the agent, and the agent updates its policy accordingly. After several iterations, the agent learns to find a quality deep neural network.
![Picture](./AutoDL%20Design/img/cnn_net.png)
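To make the representation concrete, an architecture can be described by a token vector (one operation id per node) and an adjacency matrix indicating which earlier nodes feed each node. A small made-up example with 4 nodes (the actual token vocabulary and matrix layout are defined by the repository's code):
```
import numpy as np

tokens = np.array([2, 0, 3, 1])  # hypothetical operation ids chosen by the agent
adj = np.array([[0, 1, 1, 0],    # adj[i][j] = 1: node j consumes the output of node i
                [0, 0, 0, 1],
                [0, 0, 0, 1],
                [0, 0, 0, 0]])
```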
Here we provide the following two tests of the method.
### Test on the convergence of the number of tokens produced
Due to the long training time of the CNN, to test the validity of the agent framework we use the number of "correct" tokens produced as a pseudo reward. The agent learns to produce more "correct" tokens per step. The total length of the token vector is set to 20. Run the following commands:
```
export FLAGS_fraction_of_gpu_memory_to_use=0.98
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fast_eager_deletion_mode=1
CUDA_VISIBLE_DEVICES=0 python -u simple_main.py
```
Expected results:
In the log, `average rewards` gradually converges to 20:
```
Simple run target is 20
mid=0, average rewards=2.500
...
mid=450, average rewards=17.100
mid=460, average rewards=17.000
```
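A rough way to picture this pseudo reward (hypothetical helper, not the repository's code): each sampled token vector is rewarded by how many of its 20 entries hit a fixed target value, so the maximum achievable reward is 20.
```
# Hypothetical pseudo reward for the sanity-check run: the number of
# generated tokens (out of 20) that equal a fixed "correct" value.
def pseudo_reward(tokens, correct_value=1):
    return sum(int(t == correct_value) for t in tokens)

print(pseudo_reward([1, 1, 0, 1] + [0] * 16))  # 3
```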
### Training AutoDL to design CNN
Train AutoDL Design on the small-scale dataset prepared in the previous section:
```
export FLAGS_fraction_of_gpu_memory_to_use=0.98
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fast_eager_deletion_mode=1
CUDA_VISIBLE_DEVICES=0 python -u main.py
```
__Note:__ Training requires two GPUs: the GPU used by the agent is set by `CUDA_VISIBLE_DEVICES=0` (in the command launching `main.py`); the trainer uses the GPU set by `CUDA_VISIBLE_DEVICES=1` (in [autodl.py](https://github.com/PaddlePaddle/AutoDL/blob/master/AutoDL%20Design/autodl.py#L124)).
Expected results:
In the log, `average accuracy` gradually increases:
```
step = 0, average accuracy = 0.633
step = 1, average accuracy = 0.688
step = 2, average accuracy = 0.626
step = 3, average accuracy = 0.682
......
step = 842, average accuracy = 0.823
step = 843, average accuracy = 0.825
step = 844, average accuracy = 0.808
......
```
### Results
![Picture](./AutoDL%20Design/img/search_result.png)
The x-axis is the number of steps, and the y-axis is validation accuracy of the sampled models. The average performance of the sampled models improves over time.