diff --git a/LRC/README.md b/LRC/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..df9af47d4a3876371673cbbfef0ad2553768b9a5
--- /dev/null
+++ b/LRC/README.md
@@ -0,0 +1,74 @@
+# LRC Local Rademacher Complexity Regularization
+Regularizing Deep Neural Networks (DNNs) to improve their generalization capability is important and challenging. This directory contains an image classification model based on a novel regularizer rooted in Local Rademacher Complexity (LRC). We thank [DARTS](https://arxiv.org/abs/1806.09055) for its contribution to our research. The model combines LRC regularization with the DARTS architecture and is trained on the CIFAR-10 dataset. Code accompanying the paper
+> [An Empirical Study on Regularization of Deep Neural Networks by Local Rademacher Complexity](https://arxiv.org/abs/1902.00873)\
+> Yingzhen Yang, Xingjian Li, Jun Huan.\
+> _arXiv:1902.00873_.
+
+---
+# Table of Contents
+
+- [Installation](#installation)
+- [Data preparation](#data-preparation)
+- [Training](#training)
+
+## Installation
+
+Running the sample code in this directory requires PaddlePaddle Fluid v1.2.0 or later. If the PaddlePaddle version on your device is lower than this, please follow the instructions in the [installation document](http://www.paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html#paddlepaddle) and update it.
+
+## Data preparation
+
+When you use the CIFAR-10 dataset for the first time, you can download it with:
+
+    sh ./dataset/download.sh
+
+Please make sure your environment has an internet connection.
+
+The dataset will be downloaded to `dataset/cifar/cifar-10-batches-py` in the same directory as `train.py`. If the automatic download fails, you can download cifar-10-python.tar.gz from https://www.cs.toronto.edu/~kriz/cifar.html yourself and decompress it to the location mentioned above.
+
+
+## Training
+
+After data preparation, training can be started by:
+
+    python -u train_mixup.py \
+    --batch_size=80 \
+    --auxiliary \
+    --weight_decay=0.0003 \
+    --learning_rate=0.025 \
+    --lrc_loss_lambda=0.7 \
+    --cutout
+- Set ```export CUDA_VISIBLE_DEVICES=0``` to specify a single GPU for training.
+- For more help on arguments:
+
+    python train_mixup.py --help
+
+**data reader introduction:**
+
+* The data reader is defined in `reader.py`.
+* Images are reshaped to 32 * 32.
+* In the training stage, images are padded to 40 * 40 and randomly cropped back to the original size.
+* In the training stage, images are randomly flipped horizontally.
+* Images are scaled to [0, 1] and normalized with the CIFAR-10 per-channel mean and standard deviation.
+* In the training stage, cutout is applied to images at random positions.
+* The order of the input images is shuffled during training.
+
+**model configuration:**
+
+* Use the auxiliary loss with auxiliary\_weight=0.4.
+* Use drop path with drop\_path\_prob=0.2.
+* Set lrc\_loss\_lambda=0.7.
+
+**training strategy:**
+
+* Use the momentum optimizer with momentum=0.9.
+* Weight decay is 0.0003.
+* Use cosine learning rate decay with init\_lr=0.025 (see the illustrative sketch after this list).
+* Train for 600 epochs in total.
+* Use the Xavier initializer for conv2d weights, the Constant initializer for batch norm weights and the Normal initializer for fc weights.
+* Initialize batch norm and fc biases to zero and do not add a bias to conv2d layers.
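+
+Below is a minimal, standalone sketch of the cosine learning rate schedule described above. It is illustrative only and not part of the training code; the printed epochs are arbitrary sample points, and the formula mirrors `cosine_decay` in `learning_rate.py` evaluated per epoch:
+
+    import math
+
+    init_lr, total_epochs = 0.025, 600
+    for epoch in [0, 150, 300, 450, 599]:
+        # lr = init_lr * 0.5 * (cos(pi * epoch / total_epochs) + 1)
+        lr = init_lr * (math.cos(math.pi * epoch / total_epochs) + 1) / 2
+        print('epoch {}: lr = {:.5f}'.format(epoch, lr))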
+
+
+## Reference
+
+  - DARTS: Differentiable Architecture Search [`paper`](https://arxiv.org/abs/1806.09055)
+  - Differentiable architecture search in PyTorch [`code`](https://github.com/quark0/darts)
diff --git a/LRC/README_cn.md b/LRC/README_cn.md
new file mode 100644
index 0000000000000000000000000000000000000000..06dc937074de199af31db97ee200e7690443b1b0
--- /dev/null
+++ b/LRC/README_cn.md
@@ -0,0 +1,71 @@
+# LRC Local Rademacher Complexity Regularization
+Choosing a regularizer that improves the generalization ability of deep neural networks is important and challenging. This directory contains an image classification model that uses a novel regularizer based on Local Rademacher Complexity (LRC). We are grateful to the [DARTS](https://arxiv.org/abs/1806.09055) model for its help with this research. The model combines LRC regularization with the DARTS network and achieves excellent results on the CIFAR-10 dataset. The code is released together with the paper
+> [An Empirical Study on Regularization of Deep Neural Networks by Local Rademacher Complexity](https://arxiv.org/abs/1902.00873)\
+> Yingzhen Yang, Xingjian Li, Jun Huan.\
+> _arXiv:1902.00873_.
+
+---
+# Table of Contents
+
+- [Installation](#installation)
+- [Data preparation](#data-preparation)
+- [Training](#training)
+
+## Installation
+
+Running the sample code in this directory requires PaddlePaddle Fluid v1.2.0 or later. If the PaddlePaddle in your environment is lower than this version, please update it following the instructions in the [installation document](http://www.paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html#paddlepaddle).
+
+## Data preparation
+
+When you use the CIFAR-10 dataset for the first time, you can download it with the following command:
+
+    sh ./dataset/download.sh
+
+Please make sure your environment has an internet connection. The data will be downloaded to `dataset/cifar/cifar-10-batches-py` in the same directory as `train.py`. If the download fails, you can download cifar-10-python.tar.gz from https://www.cs.toronto.edu/~kriz/cifar.html yourself and decompress it to the location above.
+
+## Training
+
+After the data is prepared, training can be started with:
+
+    python -u train_mixup.py \
+    --batch_size=80 \
+    --auxiliary \
+    --weight_decay=0.0003 \
+    --learning_rate=0.025 \
+    --lrc_loss_lambda=0.7 \
+    --cutout
+- Set ```export CUDA_VISIBLE_DEVICES=0``` to train on a single GPU.
+- For the optional arguments, see:
+
+    python train_mixup.py --help
+
+**data reader introduction:**
+
+* The data reader is defined in `reader.py`.
+* Input images are reshaped to 32 * 32.
+* During training, images are padded to 40 * 40 and then randomly cropped back to the original input size.
+* During training, images are randomly flipped horizontally.
+* Every image is normalized per pixel.
+* During training, cutout is applied to images at random positions.
+* The order of the input images is shuffled during training.
+
+**model configuration:**
+
+* Use the auxiliary loss with a weight of 0.4.
+* Use drop path with a drop probability of 0.2.
+* Set lrc\_loss\_lambda to 0.7.
+
+**training strategy:**
+
+* Train with the momentum optimizer, momentum=0.9.
+* The weight decay coefficient is 0.0003.
+* Use cosine learning rate decay with an initial learning rate of 0.025.
+* Train for 600 epochs in total.
+* Convolution weights use Xavier initialization, batch norm weights use constant initialization, and fully connected layer weights use Gaussian initialization.
+* Batch norm and fully connected layer biases are initialized to a constant zero, and no bias is set for convolutions.
+
+
+## Reference
+
+  - DARTS: Differentiable Architecture Search [`paper`](https://arxiv.org/abs/1806.09055)
+  - Differentiable Architecture Search in PyTorch [`code`](https://github.com/quark0/darts)
diff --git a/LRC/dataset/download.sh b/LRC/dataset/download.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0981c3b6878421f80d392f314fd0ae836644a63c
--- /dev/null
+++ b/LRC/dataset/download.sh
@@ -0,0 +1,10 @@
+DIR="$( cd "$(dirname "$0")" ; pwd -P )"
+cd "$DIR"
+mkdir cifar
+cd cifar
+# Download the data.
+echo "Downloading..."
+wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
+# Extract the data.
+echo "Extracting..."
+tar zvxf cifar-10-python.tar.gz
diff --git a/LRC/genotypes.py b/LRC/genotypes.py
new file mode 100644
index 0000000000000000000000000000000000000000..349fbd2478a7c2d1bb4cc3dd901b470de3c8b906
--- /dev/null
+++ b/LRC/genotypes.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# Based on: +# -------------------------------------------------------- +# DARTS +# Copyright (c) 2018, Hanxiao Liu. +# Licensed under the Apache License, Version 2.0; +# -------------------------------------------------------- + +from collections import namedtuple + +Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') + +PRIMITIVES = [ + 'none', 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect', 'sep_conv_3x3', + 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5' +] + +NASNet = Genotype( + normal=[ + ('sep_conv_5x5', 1), + ('sep_conv_3x3', 0), + ('sep_conv_5x5', 0), + ('sep_conv_3x3', 0), + ('avg_pool_3x3', 1), + ('skip_connect', 0), + ('avg_pool_3x3', 0), + ('avg_pool_3x3', 0), + ('sep_conv_3x3', 1), + ('skip_connect', 1), + ], + normal_concat=[2, 3, 4, 5, 6], + reduce=[ + ('sep_conv_5x5', 1), + ('sep_conv_7x7', 0), + ('max_pool_3x3', 1), + ('sep_conv_7x7', 0), + ('avg_pool_3x3', 1), + ('sep_conv_5x5', 0), + ('skip_connect', 3), + ('avg_pool_3x3', 2), + ('sep_conv_3x3', 2), + ('max_pool_3x3', 1), + ], + reduce_concat=[4, 5, 6], ) + +AmoebaNet = Genotype( + normal=[ + ('avg_pool_3x3', 0), + ('max_pool_3x3', 1), + ('sep_conv_3x3', 0), + ('sep_conv_5x5', 2), + ('sep_conv_3x3', 0), + ('avg_pool_3x3', 3), + ('sep_conv_3x3', 1), + ('skip_connect', 1), + ('skip_connect', 0), + ('avg_pool_3x3', 1), + ], + normal_concat=[4, 5, 6], + reduce=[ + ('avg_pool_3x3', 0), + ('sep_conv_3x3', 1), + ('max_pool_3x3', 0), + ('sep_conv_7x7', 2), + ('sep_conv_7x7', 0), + ('avg_pool_3x3', 1), + ('max_pool_3x3', 0), + ('max_pool_3x3', 1), + ('conv_7x1_1x7', 0), + ('sep_conv_3x3', 5), + ], + reduce_concat=[3, 4, 6]) + +DARTS_V1 = Genotype( + normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 0), + ('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), + ('sep_conv_3x3', 0), ('skip_connect', 2)], + normal_concat=[2, 3, 4, 5], + reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), + ('max_pool_3x3', 0), ('max_pool_3x3', 0), ('skip_connect', 2), + ('skip_connect', 2), ('avg_pool_3x3', 0)], + reduce_concat=[2, 3, 4, 5]) +DARTS_V2 = Genotype( + normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), + ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), + ('skip_connect', 0), ('dil_conv_3x3', 2)], + normal_concat=[2, 3, 4, 5], + reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), + ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), + ('skip_connect', 2), ('max_pool_3x3', 1)], + reduce_concat=[2, 3, 4, 5]) + +MY_DARTS = Genotype( + normal=[('sep_conv_3x3', 0), ('skip_connect', 1), ('skip_connect', 0), + ('dil_conv_5x5', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), + ('skip_connect', 0), ('sep_conv_3x3', 1)], + normal_concat=range(2, 6), + reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('max_pool_3x3', 0), + ('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2), + ('skip_connect', 2), ('skip_connect', 3)], + reduce_concat=range(2, 6)) + +DARTS = MY_DARTS diff --git a/LRC/learning_rate.py b/LRC/learning_rate.py new file mode 100644 index 0000000000000000000000000000000000000000..3965171b487884d36e4a7447f10f312204803bf8 --- /dev/null +++ b/LRC/learning_rate.py @@ -0,0 +1,43 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Based on:
+# --------------------------------------------------------
+# DARTS
+# Copyright (c) 2018, Hanxiao Liu.
+# Licensed under the Apache License, Version 2.0;
+# --------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.layers.ops as ops
+from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
+import math
+from paddle.fluid.initializer import init_on_cpu
+
+
+def cosine_decay(learning_rate, num_epoch, steps_one_epoch):
+    """Applies cosine decay to the learning rate.
+
+    With epoch = global_step / steps_one_epoch, the schedule is
+    lr = learning_rate * 0.5 * (cos(epoch * pi / num_epoch) + 1).
+    """
+    global_step = _decay_step_counter()
+
+    with init_on_cpu():
+        decayed_lr = learning_rate * \
+            (ops.cos((global_step / steps_one_epoch) \
+            * math.pi / num_epoch) + 1) / 2
+    return decayed_lr
diff --git a/LRC/model.py b/LRC/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..45a403495ecc0b7cc0ac3b541d75702adbef31b2
--- /dev/null
+++ b/LRC/model.py
@@ -0,0 +1,313 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+#
+# Based on:
+# --------------------------------------------------------
+# DARTS
+# Copyright (c) 2018, Hanxiao Liu.
+# Licensed under the Apache License, Version 2.0; +# -------------------------------------------------------- + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import sys +import numpy as np +import time +import functools +import paddle +import paddle.fluid as fluid +from operations import * + + +class Cell(): + def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, + reduction_prev): + print(C_prev_prev, C_prev, C) + + if reduction_prev: + self.preprocess0 = functools.partial(FactorizedReduce, C_out=C) + else: + self.preprocess0 = functools.partial( + ReLUConvBN, C_out=C, kernel_size=1, stride=1, padding=0) + self.preprocess1 = functools.partial( + ReLUConvBN, C_out=C, kernel_size=1, stride=1, padding=0) + if reduction: + op_names, indices = zip(*genotype.reduce) + concat = genotype.reduce_concat + else: + op_names, indices = zip(*genotype.normal) + concat = genotype.normal_concat + print(op_names, indices, concat, reduction) + self._compile(C, op_names, indices, concat, reduction) + + def _compile(self, C, op_names, indices, concat, reduction): + assert len(op_names) == len(indices) + self._steps = len(op_names) // 2 + self._concat = concat + self.multiplier = len(concat) + + self._ops = [] + for name, index in zip(op_names, indices): + stride = 2 if reduction and index < 2 else 1 + op = functools.partial(OPS[name], C=C, stride=stride, affine=True) + self._ops += [op] + self._indices = indices + + def forward(self, s0, s1, drop_prob, is_train, name): + self.training = is_train + preprocess0_name = name + 'preprocess0.' + preprocess1_name = name + 'preprocess1.' + s0 = self.preprocess0(s0, name=preprocess0_name) + s1 = self.preprocess1(s1, name=preprocess1_name) + out = [s0, s1] + for i in range(self._steps): + h1 = out[self._indices[2 * i]] + h2 = out[self._indices[2 * i + 1]] + op1 = self._ops[2 * i] + op2 = self._ops[2 * i + 1] + h3 = op1(h1, name=name + '_ops.' + str(2 * i) + '.') + h4 = op2(h2, name=name + '_ops.' 
+ str(2 * i + 1) + '.') + if self.training and drop_prob > 0.: + if h3 != h1: + h3 = fluid.layers.dropout( + h3, + drop_prob, + dropout_implementation='upscale_in_train') + if h4 != h2: + h4 = fluid.layers.dropout( + h4, + drop_prob, + dropout_implementation='upscale_in_train') + s = h3 + h4 + out += [s] + return fluid.layers.concat([out[i] for i in self._concat], axis=1) + + +def AuxiliaryHeadCIFAR(input, num_classes, aux_name='auxiliary_head'): + relu_a = fluid.layers.relu(input) + pool_a = fluid.layers.pool2d(relu_a, 5, 'avg', 3) + conv2d_a = fluid.layers.conv2d( + pool_a, + 128, + 1, + name=aux_name + '.features.2', + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=aux_name + '.features.2.weight'), + bias_attr=False) + bn_a_name = aux_name + '.features.3' + bn_a = fluid.layers.batch_norm( + conv2d_a, + act='relu', + name=bn_a_name, + param_attr=ParamAttr( + initializer=Constant(1.), name=bn_a_name + '.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=bn_a_name + '.bias'), + moving_mean_name=bn_a_name + '.running_mean', + moving_variance_name=bn_a_name + '.running_var') + conv2d_b = fluid.layers.conv2d( + bn_a, + 768, + 2, + name=aux_name + '.features.5', + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=aux_name + '.features.5.weight'), + bias_attr=False) + bn_b_name = aux_name + '.features.6' + bn_b = fluid.layers.batch_norm( + conv2d_b, + act='relu', + name=bn_b_name, + param_attr=ParamAttr( + initializer=Constant(1.), name=bn_b_name + '.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=bn_b_name + '.bias'), + moving_mean_name=bn_b_name + '.running_mean', + moving_variance_name=bn_b_name + '.running_var') + fc_name = aux_name + '.classifier' + fc = fluid.layers.fc(bn_b, + num_classes, + name=fc_name, + param_attr=ParamAttr( + initializer=Normal(scale=1e-3), + name=fc_name + '.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=fc_name + '.bias')) + return fc + + +def StemConv(input, C_out, kernel_size, padding): + conv_a = fluid.layers.conv2d( + input, + C_out, + kernel_size, + padding=padding, + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), name='stem.0.weight'), + bias_attr=False) + bn_a = fluid.layers.batch_norm( + conv_a, + param_attr=ParamAttr( + initializer=Constant(1.), name='stem.1.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name='stem.1.bias'), + moving_mean_name='stem.1.running_mean', + moving_variance_name='stem.1.running_var') + return bn_a + + +class NetworkCIFAR(object): + def __init__(self, C, class_num, layers, auxiliary, genotype): + self.class_num = class_num + self._layers = layers + self._auxiliary = auxiliary + + stem_multiplier = 3 + self.drop_path_prob = 0 + C_curr = stem_multiplier * C + + C_prev_prev, C_prev, C_curr = C_curr, C_curr, C + self.cells = [] + reduction_prev = False + for i in range(layers): + if i in [layers // 3, 2 * layers // 3]: + C_curr *= 2 + reduction = True + else: + reduction = False + cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, + reduction_prev) + reduction_prev = reduction + self.cells += [cell] + C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr + if i == 2 * layers // 3: + C_to_auxiliary = C_prev + + def forward(self, init_channel, is_train): + self.training = is_train + self.logits_aux = None + num_channel = init_channel * 3 + s0 = StemConv(self.image, num_channel, kernel_size=3, padding=1) + s1 = s0 + for i, cell in enumerate(self.cells): + name = 'cells.' 
+ str(i) + '.'
+            s0, s1 = s1, cell.forward(s0, s1, self.drop_path_prob, is_train,
+                                      name)
+            if i == int(2 * self._layers // 3):
+                if self._auxiliary and self.training:
+                    self.logits_aux = AuxiliaryHeadCIFAR(s1, self.class_num)
+        out = fluid.layers.adaptive_pool2d(s1, (1, 1), "avg")
+        self.logits = fluid.layers.fc(out,
+                                      size=self.class_num,
+                                      param_attr=ParamAttr(
+                                          initializer=Normal(scale=1e-3),
+                                          name='classifier.weight'),
+                                      bias_attr=ParamAttr(
+                                          initializer=Constant(0.),
+                                          name='classifier.bias'))
+        return self.logits, self.logits_aux
+
+    def build_input(self, image_shape, batch_size, is_train):
+        if is_train:
+            py_reader = fluid.layers.py_reader(
+                capacity=64,
+                shapes=[[-1] + image_shape, [-1, 1], [-1, 1], [-1, 1], [-1, 1],
+                        [-1, 1], [-1, batch_size, self.class_num - 1]],
+                lod_levels=[0, 0, 0, 0, 0, 0, 0],
+                dtypes=[
+                    "float32", "int64", "int64", "float32", "int32", "int32",
+                    "float32"
+                ],
+                use_double_buffer=True,
+                name='train_reader')
+        else:
+            py_reader = fluid.layers.py_reader(
+                capacity=64,
+                shapes=[[-1] + image_shape, [-1, 1]],
+                lod_levels=[0, 0],
+                dtypes=["float32", "int64"],
+                use_double_buffer=True,
+                name='test_reader')
+        return py_reader
+
+    def train_model(self, py_reader, init_channels, aux, aux_w, batch_size,
+                    loss_lambda):
+        self.image, self.ya, self.yb, self.lam, self.label_reshape,\
+            self.non_label_reshape, self.rad_var = fluid.layers.read_file(py_reader)
+        self.logits, self.logits_aux = self.forward(init_channels, True)
+        self.mixup_loss = self.mixup_loss(aux, aux_w)
+        self.lrc_loss = self.lrc_loss(batch_size)
+        return self.mixup_loss + loss_lambda * self.lrc_loss
+
+    def test_model(self, py_reader, init_channels):
+        self.image, self.ya = fluid.layers.read_file(py_reader)
+        self.logits, _ = self.forward(init_channels, False)
+        prob = fluid.layers.softmax(self.logits, use_cudnn=False)
+        loss = fluid.layers.cross_entropy(prob, self.ya)
+        acc_1 = fluid.layers.accuracy(self.logits, self.ya, k=1)
+        acc_5 = fluid.layers.accuracy(self.logits, self.ya, k=5)
+        return loss, acc_1, acc_5
+
+    def mixup_loss(self, auxiliary, auxiliary_weight):
+        prob = fluid.layers.softmax(self.logits, use_cudnn=False)
+        loss_a = fluid.layers.cross_entropy(prob, self.ya)
+        loss_b = fluid.layers.cross_entropy(prob, self.yb)
+        loss_a_mean = fluid.layers.reduce_mean(loss_a)
+        loss_b_mean = fluid.layers.reduce_mean(loss_b)
+        loss = self.lam * loss_a_mean + (1 - self.lam) * loss_b_mean
+        if auxiliary:
+            prob_aux = fluid.layers.softmax(self.logits_aux, use_cudnn=False)
+            loss_a_aux = fluid.layers.cross_entropy(prob_aux, self.ya)
+            loss_b_aux = fluid.layers.cross_entropy(prob_aux, self.yb)
+            loss_a_aux_mean = fluid.layers.reduce_mean(loss_a_aux)
+            loss_b_aux_mean = fluid.layers.reduce_mean(loss_b_aux)
+            loss_aux = self.lam * loss_a_aux_mean + (
+                1 - self.lam) * loss_b_aux_mean
+            return loss + auxiliary_weight * loss_aux
+        # Without the auxiliary head, return the plain mixup loss.
+        return loss
+
+    def lrc_loss(self, batch_size):
+        y_diff_reshape = fluid.layers.reshape(self.logits, shape=(-1, 1))
+        label_reshape = fluid.layers.squeeze(self.label_reshape, axes=[1])
+        non_label_reshape = fluid.layers.squeeze(
+            self.non_label_reshape, axes=[1])
+        # The gathered index tensors must not receive gradients.
+        label_reshape.stop_gradient = True
+        non_label_reshape.stop_gradient = True
+
+        y_diff_label_reshape = fluid.layers.gather(y_diff_reshape,
+                                                   label_reshape)
+        y_diff_non_label_reshape = fluid.layers.gather(y_diff_reshape,
+                                                       non_label_reshape)
+        y_diff_label = fluid.layers.reshape(
+            y_diff_label_reshape, shape=(-1, batch_size, 1))
+        y_diff_non_label = fluid.layers.reshape(
+            y_diff_non_label_reshape,
+            shape=(-1,
batch_size, self.class_num - 1)) + y_diff_ = y_diff_non_label - y_diff_label + + y_diff_ = fluid.layers.transpose(y_diff_, perm=[1, 2, 0]) + rad_var_trans = fluid.layers.transpose(self.rad_var, perm=[1, 2, 0]) + rad_y_diff_trans = rad_var_trans * y_diff_ + lrc_loss_sum = fluid.layers.reduce_sum(rad_y_diff_trans, dim=[0, 1]) + lrc_loss_ = fluid.layers.abs(lrc_loss_sum) / (batch_size * + (self.class_num - 1)) + lrc_loss_mean = fluid.layers.reduce_mean(lrc_loss_) + + return lrc_loss_mean diff --git a/LRC/operations.py b/LRC/operations.py new file mode 100644 index 0000000000000000000000000000000000000000..b015722a1bc5dbf682c90812a971f3dbb2cd8c9a --- /dev/null +++ b/LRC/operations.py @@ -0,0 +1,349 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +# +# Based on: +# -------------------------------------------------------- +# DARTS +# Copyright (c) 2018, Hanxiao Liu. +# Licensed under the Apache License, Version 2.0; +# -------------------------------------------------------- +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import sys +import numpy as np +import time +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Xavier +from paddle.fluid.initializer import Normal +from paddle.fluid.initializer import Constant + +OPS = { + 'none' : lambda input, C, stride, name, affine: Zero(input, stride, name), + 'avg_pool_3x3' : lambda input, C, stride, name, affine: fluid.layers.pool2d(input, 3, 'avg', pool_stride=stride, pool_padding=1, name=name), + 'max_pool_3x3' : lambda input, C, stride, name, affine: fluid.layers.pool2d(input, 3, 'max', pool_stride=stride, pool_padding=1, name=name), + 'skip_connect' : lambda input,C, stride, name, affine: Identity(input, name) if stride == 1 else FactorizedReduce(input, C, name=name, affine=affine), + 'sep_conv_3x3' : lambda input,C, stride, name, affine: SepConv(input, C, C, 3, stride, 1, name=name, affine=affine), + 'sep_conv_5x5' : lambda input,C, stride, name, affine: SepConv(input, C, C, 5, stride, 2, name=name, affine=affine), + 'sep_conv_7x7' : lambda input,C, stride, name, affine: SepConv(input, C, C, 7, stride, 3, name=name, affine=affine), + 'dil_conv_3x3' : lambda input,C, stride, name, affine: DilConv(input, C, C, 3, stride, 2, 2, name=name, affine=affine), + 'dil_conv_5x5' : lambda input,C, stride, name, affine: DilConv(input, C, C, 5, stride, 4, 2, name=name, affine=affine), + 'conv_7x1_1x7' : lambda input,C, stride, name, affine: SevenConv(input, C, name=name, affine=affine) +} + + +def ReLUConvBN(input, C_out, kernel_size, stride, padding, name='', + affine=True): + relu_a = fluid.layers.relu(input) + conv2d_a = fluid.layers.conv2d( + relu_a, + C_out, + kernel_size, + stride, + padding, + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'op.1.weight'), + bias_attr=False) + if affine: + reluconvbn_out = 
fluid.layers.batch_norm( + conv2d_a, + param_attr=ParamAttr( + initializer=Constant(1.), name=name + 'op.2.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=name + 'op.2.bias'), + moving_mean_name=name + 'op.2.running_mean', + moving_variance_name=name + 'op.2.running_var') + else: + reluconvbn_out = fluid.layers.batch_norm( + conv2d_a, + param_attr=ParamAttr( + initializer=Constant(1.), + learning_rate=0., + name=name + 'op.2.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), + learning_rate=0., + name=name + 'op.2.bias'), + moving_mean_name=name + 'op.2.running_mean', + moving_variance_name=name + 'op.2.running_var') + return reluconvbn_out + + +def DilConv(input, + C_in, + C_out, + kernel_size, + stride, + padding, + dilation, + name='', + affine=True): + relu_a = fluid.layers.relu(input) + conv2d_a = fluid.layers.conv2d( + relu_a, + C_in, + kernel_size, + stride, + padding, + dilation, + groups=C_in, + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'op.1.weight'), + bias_attr=False, + use_cudnn=False) + conv2d_b = fluid.layers.conv2d( + conv2d_a, + C_out, + 1, + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'op.2.weight'), + bias_attr=False) + if affine: + dilconv_out = fluid.layers.batch_norm( + conv2d_b, + param_attr=ParamAttr( + initializer=Constant(1.), name=name + 'op.3.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=name + 'op.3.bias'), + moving_mean_name=name + 'op.3.running_mean', + moving_variance_name=name + 'op.3.running_var') + else: + dilconv_out = fluid.layers.batch_norm( + conv2d_b, + param_attr=ParamAttr( + initializer=Constant(1.), + learning_rate=0., + name=name + 'op.3.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), + learning_rate=0., + name=name + 'op.3.bias'), + moving_mean_name=name + 'op.3.running_mean', + moving_variance_name=name + 'op.3.running_var') + return dilconv_out + + +def SepConv(input, + C_in, + C_out, + kernel_size, + stride, + padding, + name='', + affine=True): + relu_a = fluid.layers.relu(input) + conv2d_a = fluid.layers.conv2d( + relu_a, + C_in, + kernel_size, + stride, + padding, + groups=C_in, + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'op.1.weight'), + bias_attr=False, + use_cudnn=False) + conv2d_b = fluid.layers.conv2d( + conv2d_a, + C_in, + 1, + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'op.2.weight'), + bias_attr=False) + if affine: + bn_a = fluid.layers.batch_norm( + conv2d_b, + param_attr=ParamAttr( + initializer=Constant(1.), name=name + 'op.3.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=name + 'op.3.bias'), + moving_mean_name=name + 'op.3.running_mean', + moving_variance_name=name + 'op.3.running_var') + else: + bn_a = fluid.layers.batch_norm( + conv2d_b, + param_attr=ParamAttr( + initializer=Constant(1.), + learning_rate=0., + name=name + 'op.3.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), + learning_rate=0., + name=name + 'op.3.bias'), + moving_mean_name=name + 'op.3.running_mean', + moving_variance_name=name + 'op.3.running_var') + + relu_b = fluid.layers.relu(bn_a) + conv2d_d = fluid.layers.conv2d( + relu_b, + C_in, + kernel_size, + 1, + padding, + groups=C_in, + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'op.5.weight'), + bias_attr=False, + use_cudnn=False) + conv2d_e = fluid.layers.conv2d( + conv2d_d, + C_out, + 1, + 
param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'op.6.weight'), + bias_attr=False) + if affine: + sepconv_out = fluid.layers.batch_norm( + conv2d_e, + param_attr=ParamAttr( + initializer=Constant(1.), name=name + 'op.7.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=name + 'op.7.bias'), + moving_mean_name=name + 'op.7.running_mean', + moving_variance_name=name + 'op.7.running_var') + else: + sepconv_out = fluid.layers.batch_norm( + conv2d_e, + param_attr=ParamAttr( + initializer=Constant(1.), + learning_rate=0., + name=name + 'op.7.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), + learning_rate=0., + name=name + 'op.7.bias'), + moving_mean_name=name + 'op.7.running_mean', + moving_variance_name=name + 'op.7.running_var') + return sepconv_out + + +def SevenConv(input, C_out, stride, name='', affine=True): + relu_a = fluid.layers.relu(input) + conv2d_a = fluid.layers.conv2d( + relu_a, + C_out, (1, 7), (1, stride), (0, 3), + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'op.1.weight'), + bias_attr=False) + conv2d_b = fluid.layers.conv2d( + conv2d_a, + C_out, (7, 1), (stride, 1), (3, 0), + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'op.2.weight'), + bias_attr=False) + if affine: + out = fluid.layers.batch_norm( + conv2d_b, + param_attr=ParamAttr( + initializer=Constant(1.), name=name + 'op.3.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=name + 'op.3.bias'), + moving_mean_name=name + 'op.3.running_mean', + moving_variance_name=name + 'op.3.running_var') + else: + out = fluid.layers.batch_norm( + conv2d_b, + param_attr=ParamAttr( + initializer=Constant(1.), + learning_rate=0., + name=name + 'op.3.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), + learning_rate=0., + name=name + 'op.3.bias'), + moving_mean_name=name + 'op.3.running_mean', + moving_variance_name=name + 'op.3.running_var') + + +def Identity(input, name=''): + return input + + +def Zero(input, stride, name=''): + ones = np.ones(input.shape[-2:]) + ones[::stride, ::stride] = 0 + ones = fluid.layers.assign(ones) + return input * ones + + +def FactorizedReduce(input, C_out, name='', affine=True): + relu_a = fluid.layers.relu(input) + conv2d_a = fluid.layers.conv2d( + relu_a, + C_out // 2, + 1, + 2, + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'conv_1.weight'), + bias_attr=False) + h_end = relu_a.shape[2] + w_end = relu_a.shape[3] + slice_a = fluid.layers.slice(relu_a, [2, 3], [1, 1], [h_end, w_end]) + conv2d_b = fluid.layers.conv2d( + slice_a, + C_out // 2, + 1, + 2, + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'conv_2.weight'), + bias_attr=False) + out = fluid.layers.concat([conv2d_a, conv2d_b], axis=1) + if affine: + out = fluid.layers.batch_norm( + out, + param_attr=ParamAttr( + initializer=Constant(1.), name=name + 'bn.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=name + 'bn.bias'), + moving_mean_name=name + 'bn.running_mean', + moving_variance_name=name + 'bn.running_var') + else: + out = fluid.layers.batch_norm( + out, + param_attr=ParamAttr( + initializer=Constant(1.), + learning_rate=0., + name=name + 'bn.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), + learning_rate=0., + name=name + 'bn.bias'), + moving_mean_name=name + 'bn.running_mean', + moving_variance_name=name + 'bn.running_var') + return out diff --git a/LRC/reader.py 
b/LRC/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..20b32b504e9245c4ff3892f08736d800080daab4 --- /dev/null +++ b/LRC/reader.py @@ -0,0 +1,187 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rig hts Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Based on: +# -------------------------------------------------------- +# DARTS +# Copyright (c) 2018, Hanxiao Liu. +# Licensed under the Apache License, Version 2.0; +# -------------------------------------------------------- +""" +CIFAR-10 dataset. +This module will download dataset from +https://www.cs.toronto.edu/~kriz/cifar.html and parse train/test set into +paddle reader creators. +The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, +with 6000 images per class. There are 50000 training images and 10000 test images. +""" + +from PIL import Image +from PIL import ImageOps +import numpy as np + +import cPickle +import random +import utils +import paddle.fluid as fluid +import time +import os +import functools +import paddle.reader + +__all__ = ['train10', 'test10'] + +image_size = 32 +image_depth = 3 +half_length = 8 + +CIFAR_MEAN = [0.4914, 0.4822, 0.4465] +CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] + + +def generate_reshape_label(label, batch_size, CIFAR_CLASSES=10): + reshape_label = np.zeros((batch_size, 1), dtype='int32') + reshape_non_label = np.zeros( + (batch_size * (CIFAR_CLASSES - 1), 1), dtype='int32') + num = 0 + for i in range(batch_size): + label_i = label[i] + reshape_label[i] = label_i + i * CIFAR_CLASSES + for j in range(CIFAR_CLASSES): + if label_i != j: + reshape_non_label[num] = \ + j + i * CIFAR_CLASSES + num += 1 + return reshape_label, reshape_non_label + + +def generate_bernoulli_number(batch_size, CIFAR_CLASSES=10): + rcc_iters = 50 + rad_var = np.zeros((rcc_iters, batch_size, CIFAR_CLASSES - 1)) + for i in range(rcc_iters): + bernoulli_num = np.random.binomial(size=batch_size, n=1, p=0.5) + bernoulli_map = np.array([]) + ones = np.ones((CIFAR_CLASSES - 1, 1)) + for batch_id in range(batch_size): + num = bernoulli_num[batch_id] + var_id = 2 * ones * num - 1 + bernoulli_map = np.append(bernoulli_map, var_id) + rad_var[i] = bernoulli_map.reshape((batch_size, CIFAR_CLASSES - 1)) + return rad_var.astype('float32') + + +def preprocess(sample, is_training, args): + image_array = sample.reshape(3, image_size, image_size) + rgb_array = np.transpose(image_array, (1, 2, 0)) + img = Image.fromarray(rgb_array, 'RGB') + + if is_training: + # pad and ramdom crop + img = ImageOps.expand(img, (4, 4, 4, 4), fill=0) # pad to 40 * 40 * 3 + left_top = np.random.randint(9, size=2) # rand 0 - 8 + img = img.crop((left_top[0], left_top[1], left_top[0] + image_size, + left_top[1] + image_size)) + if np.random.randint(2): + img = img.transpose(Image.FLIP_LEFT_RIGHT) + + img = np.array(img).astype(np.float32) + + # per_image_standardization + img_float = img / 255.0 + img = (img_float - CIFAR_MEAN) / CIFAR_STD + + if is_training and args.cutout: + center = 
np.random.randint(image_size, size=2) + offset_width = max(0, center[0] - half_length) + offset_height = max(0, center[1] - half_length) + target_width = min(center[0] + half_length, image_size) + target_height = min(center[1] + half_length, image_size) + + for i in range(offset_height, target_height): + for j in range(offset_width, target_width): + img[i][j][:] = 0.0 + + img = np.transpose(img, (2, 0, 1)) + return img + + +def reader_creator_filepath(filename, sub_name, is_training, args): + files = os.listdir(filename) + names = [each_item for each_item in files if sub_name in each_item] + names.sort() + datasets = [] + for name in names: + print("Reading file " + name) + batch = cPickle.load(open(filename + name, 'rb')) + data = batch['data'] + labels = batch.get('labels', batch.get('fine_labels', None)) + assert labels is not None + dataset = zip(data, labels) + datasets.extend(dataset) + random.shuffle(datasets) + + def read_batch(datasets, args): + for sample, label in datasets: + im = preprocess(sample, is_training, args) + yield im, [int(label)] + + def reader(): + batch_data = [] + batch_label = [] + for data, label in read_batch(datasets, args): + batch_data.append(data) + batch_label.append(label) + if len(batch_data) == args.batch_size: + batch_data = np.array(batch_data, dtype='float32') + batch_label = np.array(batch_label, dtype='int64') + if is_training: + flatten_label, flatten_non_label = \ + generate_reshape_label(batch_label, args.batch_size) + rad_var = generate_bernoulli_number(args.batch_size) + mixed_x, y_a, y_b, lam = utils.mixup_data( + batch_data, batch_label, args.batch_size, + args.mix_alpha) + batch_out = [[mixed_x, y_a, y_b, lam, flatten_label, \ + flatten_non_label, rad_var]] + yield batch_out + else: + batch_out = [[batch_data, batch_label]] + yield batch_out + batch_data = [] + batch_label = [] + + return reader + + +def train10(args): + """ + CIFAR-10 training set creator. + It returns a reader creator, each sample in the reader is image pixels in + [0, 1] and label in [0, 9]. + :return: Training reader creator + :rtype: callable + """ + + return reader_creator_filepath(args.data, 'data_batch', True, args) + + +def test10(args): + """ + CIFAR-10 test set creator. + It returns a reader creator, each sample in the reader is image pixels in + [0, 1] and label in [0, 9]. + :return: Test reader creator. + :rtype: callable + """ + return reader_creator_filepath(args.data, 'test_batch', False, args) diff --git a/LRC/run.sh b/LRC/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..9f1a045d49789c3e9aebbc2a73b84b11da471b5a --- /dev/null +++ b/LRC/run.sh @@ -0,0 +1,8 @@ +CUDA_VISIBLE_DEVICES=0 python -u train_mixup.py \ +--batch_size=80 \ +--auxiliary \ +--weight_decay=0.0003 \ +--learning_rate=0.025 \ +--lrc_loss_lambda=0.7 \ +--cutout + diff --git a/LRC/train_mixup.py b/LRC/train_mixup.py new file mode 100644 index 0000000000000000000000000000000000000000..de752c84bcf9276aa83540d60370517e66c0704f --- /dev/null +++ b/LRC/train_mixup.py @@ -0,0 +1,247 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. 
+#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +# +# Based on: +# -------------------------------------------------------- +# DARTS +# Copyright (c) 2018, Hanxiao Liu. +# Licensed under the Apache License, Version 2.0; +# -------------------------------------------------------- + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from learning_rate import cosine_decay +import numpy as np +import argparse +from model import NetworkCIFAR as Network +import reader +import sys +import os +import time +import logging +import genotypes +import paddle.fluid as fluid +import shutil +import utils +import cPickle as cp + +parser = argparse.ArgumentParser("cifar") +parser.add_argument( + '--data', + type=str, + default='./dataset/cifar/cifar-10-batches-py/', + help='location of the data corpus') +parser.add_argument('--batch_size', type=int, default=96, help='batch size') +parser.add_argument( + '--learning_rate', type=float, default=0.025, help='init learning rate') +parser.add_argument('--momentum', type=float, default=0.9, help='momentum') +parser.add_argument( + '--weight_decay', type=float, default=3e-4, help='weight decay') +parser.add_argument( + '--report_freq', type=float, default=50, help='report frequency') +parser.add_argument( + '--epochs', type=int, default=600, help='num of training epochs') +parser.add_argument( + '--init_channels', type=int, default=36, help='num of init channels') +parser.add_argument( + '--layers', type=int, default=20, help='total number of layers') +parser.add_argument( + '--model_path', + type=str, + default='saved_models', + help='path to save the model') +parser.add_argument( + '--auxiliary', + action='store_true', + default=False, + help='use auxiliary tower') +parser.add_argument( + '--auxiliary_weight', + type=float, + default=0.4, + help='weight for auxiliary loss') +parser.add_argument( + '--cutout', action='store_true', default=False, help='use cutout') +parser.add_argument( + '--cutout_length', type=int, default=16, help='cutout length') +parser.add_argument( + '--drop_path_prob', type=float, default=0.2, help='drop path probability') +parser.add_argument('--save', type=str, default='EXP', help='experiment name') +parser.add_argument( + '--arch', type=str, default='DARTS', help='which architecture to use') +parser.add_argument( + '--grad_clip', type=float, default=5, help='gradient clipping') +parser.add_argument( + '--lr_exp_decay', + action='store_true', + default=False, + help='use exponential_decay learning_rate') +parser.add_argument('--mix_alpha', type=float, default=0.5, help='mixup alpha') +parser.add_argument( + '--lrc_loss_lambda', default=0, type=float, help='lrc_loss_lambda') +parser.add_argument( + '--loss_type', + default=1, + type=float, + help='loss_type 0: cross entropy 1: multi margin loss 2: max margin loss') + +args = parser.parse_args() + +CIFAR_CLASSES = 10 +dataset_train_size = 50000 +image_size = 32 + + +def main(): + image_shape = [3, image_size, image_size] + devices = os.getenv("CUDA_VISIBLE_DEVICES") or "" + devices_num = len(devices.split(",")) + logging.info("args = %s", args) + genotype = 
eval("genotypes.%s" % args.arch) + model = Network(args.init_channels, CIFAR_CLASSES, args.layers, + args.auxiliary, genotype) + steps_one_epoch = dataset_train_size / (devices_num * args.batch_size) + train(model, args, image_shape, steps_one_epoch) + + +def build_program(main_prog, startup_prog, args, is_train, model, im_shape, + steps_one_epoch): + out = [] + with fluid.program_guard(main_prog, startup_prog): + py_reader = model.build_input(im_shape, args.batch_size, is_train) + if is_train: + with fluid.unique_name.guard(): + loss = model.train_model(py_reader, args.init_channels, + args.auxiliary, args.auxiliary_weight, + args.batch_size, args.lrc_loss_lambda) + optimizer = fluid.optimizer.Momentum( + learning_rate=cosine_decay(args.learning_rate, \ + args.epochs, steps_one_epoch), + regularization=fluid.regularizer.L2Decay(\ + args.weight_decay), + momentum=args.momentum) + optimizer.minimize(loss) + out = [py_reader, loss] + else: + with fluid.unique_name.guard(): + loss, acc_1, acc_5 = model.test_model(py_reader, + args.init_channels) + out = [py_reader, loss, acc_1, acc_5] + return out + + +def train(model, args, im_shape, steps_one_epoch): + train_startup_prog = fluid.Program() + test_startup_prog = fluid.Program() + train_prog = fluid.Program() + test_prog = fluid.Program() + + train_py_reader, loss_train = build_program(train_prog, train_startup_prog, + args, True, model, im_shape, + steps_one_epoch) + + test_py_reader, loss_test, acc_1, acc_5 = build_program( + test_prog, test_startup_prog, args, False, model, im_shape, + steps_one_epoch) + + test_prog = test_prog.clone(for_test=True) + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(train_startup_prog) + exe.run(test_startup_prog) + + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.num_threads = 1 + train_exe = fluid.ParallelExecutor( + main_program=train_prog, + use_cuda=True, + loss_name=loss_train.name, + exec_strategy=exec_strategy) + train_reader = reader.train10(args) + test_reader = reader.test10(args) + train_py_reader.decorate_paddle_reader(train_reader) + test_py_reader.decorate_paddle_reader(test_reader) + + fluid.clip.set_gradient_clip(fluid.clip.GradientClipByNorm(args.grad_clip)) + fluid.memory_optimize(fluid.default_main_program()) + + def save_model(postfix, main_prog): + model_path = os.path.join(args.model_path, postfix) + if os.path.isdir(model_path): + shutil.rmtree(model_path) + fluid.io.save_persistables(exe, model_path, main_program=main_prog) + + def test(epoch_id): + test_fetch_list = [loss_test, acc_1, acc_5] + objs = utils.AvgrageMeter() + top1 = utils.AvgrageMeter() + top5 = utils.AvgrageMeter() + test_py_reader.start() + test_start_time = time.time() + step_id = 0 + try: + while True: + prev_test_start_time = test_start_time + test_start_time = time.time() + loss_test_v, acc_1_v, acc_5_v = exe.run( + test_prog, fetch_list=test_fetch_list) + objs.update(np.array(loss_test_v), args.batch_size) + top1.update(np.array(acc_1_v), args.batch_size) + top5.update(np.array(acc_5_v), args.batch_size) + if step_id % args.report_freq == 0: + print("Epoch {}, Step {}, acc_1 {}, acc_5 {}, time {}". 
+ format(epoch_id, step_id, + np.array(acc_1_v), + np.array(acc_5_v), test_start_time - + prev_test_start_time)) + step_id += 1 + except fluid.core.EOFException: + test_py_reader.reset() + print("Epoch {0}, top1 {1}, top5 {2}".format(epoch_id, top1.avg, + top5.avg)) + + train_fetch_list = [loss_train] + epoch_start_time = time.time() + for epoch_id in range(args.epochs): + model.drop_path_prob = args.drop_path_prob * epoch_id / args.epochs + train_py_reader.start() + epoch_end_time = time.time() + if epoch_id > 0: + print("Epoch {}, total time {}".format(epoch_id - 1, epoch_end_time + - epoch_start_time)) + epoch_start_time = epoch_end_time + epoch_end_time + start_time = time.time() + step_id = 0 + try: + while True: + prev_start_time = start_time + start_time = time.time() + loss_v, = train_exe.run( + fetch_list=[v.name for v in train_fetch_list]) + print("Epoch {}, Step {}, loss {}, time {}".format(epoch_id, step_id, \ + np.array(loss_v).mean(), start_time-prev_start_time)) + step_id += 1 + sys.stdout.flush() + except fluid.core.EOFException: + train_py_reader.reset() + if epoch_id % 50 == 0 or epoch_id == args.epochs - 1: + save_model(str(epoch_id), train_prog) + test(epoch_id) + + +if __name__ == '__main__': + main() diff --git a/LRC/utils.py b/LRC/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4002b57c6e91f9a4f7992156c4fa07f9e55d628c --- /dev/null +++ b/LRC/utils.py @@ -0,0 +1,55 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Based on: +# -------------------------------------------------------- +# DARTS +# Copyright (c) 2018, Hanxiao Liu. +# Licensed under the Apache License, Version 2.0; +# -------------------------------------------------------- + +import os +import sys +import time +import math +import numpy as np + + +def mixup_data(x, y, batch_size, alpha=1.0): + '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda''' + if alpha > 0.: + lam = np.random.beta(alpha, alpha) + else: + lam = 1. + index = np.random.permutation(batch_size) + + mixed_x = lam * x + (1 - lam) * x[index, :] + y_a, y_b = y, y[index] + return mixed_x.astype('float32'), y_a.astype('int64'),\ + y_b.astype('int64'), np.array(lam, dtype='float32') + + +class AvgrageMeter(object): + def __init__(self): + self.reset() + + def reset(self): + self.avg = 0 + self.sum = 0 + self.cnt = 0 + + def update(self, val, n=1): + self.sum += val * n + self.cnt += n + self.avg = self.sum / self.cnt
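+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative only, not used by training): exercises
+    # mixup_data and AvgrageMeter on random data; the batch size and alpha
+    # below are arbitrary example values.
+    np.random.seed(0)
+    batch_size = 4
+    x = np.random.rand(batch_size, 3, 32, 32)
+    y = np.random.randint(0, 10, size=(batch_size, 1))
+    mixed_x, y_a, y_b, lam = mixup_data(x, y, batch_size, alpha=0.5)
+    print('lam = {}, mixed_x shape = {}'.format(float(lam), mixed_x.shape))
+    meter = AvgrageMeter()
+    for step_loss in [0.9, 0.7, 0.5]:
+        meter.update(step_loss, n=batch_size)
+    print('running average loss = {}'.format(meter.avg))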