Unverified · Commit 823ca6bb · Author: Bai Yifan · Committer: GitHub

Fix grad_clip in DARTS; grad_clip has been upgraded in Paddle 2.0 (#229)

Parent 388211f3
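The substance of this commit is a gradient-clipping API migration: instead of building a `fluid.dygraph_grad_clip.GradClipByGlobalNorm` object and passing it to `optimizer.minimize(loss, grad_clip=...)` on every step, a `fluid.clip.GradientClipByGlobalNorm` object is now created once and attached to the optimizer at construction time. Below is a minimal, self-contained sketch of the new pattern under fluid dygraph; the `Linear` stand-in model, learning rate, and random input are illustrative only and are not part of this repo.

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear, to_variable

# Minimal sketch of the post-upgrade clipping pattern (assumptions: fluid dygraph mode,
# a toy Linear model standing in for NetworkCIFAR/NetworkImageNet, illustrative hyperparameters).
with fluid.dygraph.guard():
    model = Linear(10, 2)

    # New pattern (this commit): the clip object is handed to the optimizer once.
    clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)  # 5.0 mirrors the repo's grad_clip default
    optimizer = fluid.optimizer.MomentumOptimizer(
        learning_rate=0.025,                                   # placeholder value
        momentum=0.9,
        regularization=fluid.regularizer.L2Decay(3e-4),
        parameter_list=model.parameters(),
        grad_clip=clip)

    x = to_variable(np.random.rand(4, 10).astype('float32'))
    loss = fluid.layers.reduce_mean(model(x))
    loss.backward()
    # Old pattern removed by this commit:
    #   grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(args.grad_clip)
    #   optimizer.minimize(loss, grad_clip=grad_clip)
    optimizer.minimize(loss)  # clipping is now applied inside minimize() via the optimizer's grad_clip
    model.clear_gradients()
```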
@@ -29,15 +29,15 @@ python search.py --method='PC-DARTS' --batch_size=256 --learning_rate=0.1 --arch
 Figure 1: Evolution of the architecture searched on the CIFAR10 dataset; the upper half is the reduction cell and the lower half is the normal cell
 </p>
-The Genotypes obtained with the three search methods have been added to the genotypes.py file; `DARTS_V1`, `DARTS_V2`, and `PC-DARTS` denote the architectures found with the DARTS first-order approximation, the DARTS second-order approximation, and the PC-DARTS search method, respectively.
+The Genotypes obtained with the three search methods have been added to the genotypes.py file; `DARTS_V1`, `DARTS_V2`, and `PC_DARTS` denote the architectures found with the DARTS first-order approximation, the DARTS second-order approximation, and the PC-DARTS search method, respectively.
 ## Evaluation training of the searched architecture
 After obtaining the searched Genotype, it can be trained for evaluation to measure its real performance on a specific dataset
 ```bash
-python train.py --arch='PC-DARTS'          # evaluation training of the searched architecture on CIFAR10
-python train_imagenet.py --arch='PC-DARTS' # evaluation training of the searched architecture on ImageNet
+python train.py --arch='PC_DARTS'          # evaluation training of the searched architecture on CIFAR10
+python train_imagenet.py --arch='PC_DARTS' # evaluation training of the searched architecture on ImageNet
 ```
 The evaluation-training results for the searched `DARTS_V1`, `DARTS_V2`, and `PC-DARTS` are as follows:
@@ -83,7 +83,7 @@ def train_search(batch_size, train_portion, is_shuffle, args):
 Use the following command to visualize a searched Genotype
 ```python
-python visualize.py PC-DARTS
+python visualize.py PC_DARTS
 ```
-`PC-DARTS` here stands for a particular Genotype, which must first be added to genotype.py
+`PC_DARTS` here stands for a particular Genotype, which must first be added to genotype.py
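The README above refers to `Genotype` entries stored in genotypes.py. As a hedged illustration only (the operations and node indices below are made up, not the repo's actual searched structures), a DARTS-style Genotype is conventionally a namedtuple recording the chosen operation and input node for each edge of the normal and reduction cells:

```python
from collections import namedtuple

# Illustrative structure only; the real entries live in this repo's genotypes.py.
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

EXAMPLE_GENOTYPE = Genotype(
    normal=[('sep_conv_3x3', 0), ('skip_connect', 1),
            ('sep_conv_3x3', 0), ('sep_conv_3x3', 1)],  # (operation, input node) per edge
    normal_concat=range(2, 6),                          # intermediate nodes concatenated as the cell output
    reduce=[('max_pool_3x3', 0), ('sep_conv_5x5', 1),
            ('dil_conv_3x3', 0), ('max_pool_3x3', 2)],
    reduce_concat=range(2, 6))
```

Passing a name such as `PC_DARTS` to visualize.py or `--arch` presumably selects the attribute of that name from genotypes.py, which is why the README notes it must be added to the file beforehand.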
@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.initializer import ConstantInitializer, MSRAInitializer
......
@@ -35,7 +35,7 @@ add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
 add_arg('log_freq', int, 50, "Log frequency.")
-add_arg('use_multiprocess', bool, True, "Whether use multiprocess reader.")
+add_arg('use_multiprocess', bool, False, "Whether use multiprocess reader.")
 add_arg('num_workers', int, 4, "The multiprocess reader number.")
 add_arg('data', str, 'dataset/cifar10',"The dir of dataset.")
 add_arg('batch_size', int, 64, "Minibatch size.")
......
@@ -21,26 +21,24 @@ import sys
 import ast
 import argparse
 import functools
 import logging
-FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
-logging.basicConfig(level=logging.INFO, format=FORMAT)
-logger = logging.getLogger(__name__)
 import paddle.fluid as fluid
 from paddle.fluid.dygraph.base import to_variable
-from model import NetworkCIFAR as Network
-from paddleslim.common import AvgrageMeter
+from paddleslim.common import AvgrageMeter, get_logger
 import genotypes
 import reader
+from model import NetworkCIFAR as Network
 sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
 from utility import add_arguments, print_arguments
+logger = get_logger(__name__, level=logging.INFO)
 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
-add_arg('use_multiprocess', bool, True, "Whether use multiprocess reader.")
+add_arg('use_multiprocess', bool, False, "Whether use multiprocess reader.")
 add_arg('num_workers', int, 4, "The multiprocess reader number.")
 add_arg('data', str, 'dataset/cifar10',"The dir of dataset.")
 add_arg('batch_size', int, 96, "Minibatch size.")
@@ -61,7 +59,7 @@ add_arg('auxiliary_weight', float, 0.4, "Weight for auxiliary loss.
 add_arg('drop_path_prob', float, 0.2, "Drop path probability.")
 add_arg('grad_clip', float, 5, "Gradient clipping.")
 add_arg('arch', str, 'DARTS_V2', "Which architecture to use")
-add_arg('report_freq', int, 50, 'Report frequency')
+add_arg('log_freq', int, 50, 'Report frequency')
 add_arg('use_data_parallel', ast.literal_eval, False, "The flag indicating whether to use data parallel mode to train the model.")
 # yapf: enable
@@ -95,9 +93,7 @@ def train(model, train_reader, optimizer, epoch, drop_path_prob, args):
         else:
             loss.backward()
-        grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
-            args.grad_clip)
-        optimizer.minimize(loss, grad_clip=grad_clip)
+        optimizer.minimize(loss)
         model.clear_gradients()
         n = image.shape[0]
@@ -105,7 +101,7 @@ def train(model, train_reader, optimizer, epoch, drop_path_prob, args):
         top1.update(prec1.numpy(), n)
         top5.update(prec5.numpy(), n)
-        if step_id % args.report_freq == 0:
+        if step_id % args.log_freq == 0:
             logger.info(
                 "Train Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
                 format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[0]))
@@ -132,7 +128,7 @@ def valid(model, valid_reader, epoch, args):
         objs.update(loss.numpy(), n)
         top1.update(prec1.numpy(), n)
         top5.update(prec5.numpy(), n)
-        if step_id % args.report_freq == 0:
+        if step_id % args.log_freq == 0:
             logger.info(
                 "Valid Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
                 format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[0]))
@@ -158,11 +154,13 @@ def main(args):
     step_per_epoch = int(args.trainset_num / args.batch_size)
     learning_rate = fluid.dygraph.CosineDecay(args.learning_rate,
                                               step_per_epoch, args.epochs)
+    clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=args.grad_clip)
     optimizer = fluid.optimizer.MomentumOptimizer(
         learning_rate,
         momentum=args.momentum,
         regularization=fluid.regularizer.L2Decay(args.weight_decay),
-        parameter_list=model.parameters())
+        parameter_list=model.parameters(),
+        grad_clip=clip)
     if args.use_data_parallel:
         model = fluid.dygraph.parallel.DataParallel(model, strategy)
......
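For reference, a hypothetical invocation of the updated evaluation-training script; the flag names follow the `add_arg` definitions in the diff above (note the renamed `--log_freq`), and the values simply echo the defaults shown there:

```bash
# Hypothetical example; flag names come from the add_arg lines above, values are the displayed defaults.
python train.py \
    --arch='PC_DARTS' \
    --batch_size 96 \
    --grad_clip 5 \
    --log_freq 50 \
    --use_data_parallel False
```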
@@ -21,20 +21,17 @@ import sys
 import ast
 import argparse
 import functools
 import logging
-FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
-logging.basicConfig(level=logging.INFO, format=FORMAT)
-logger = logging.getLogger(__name__)
 import paddle.fluid as fluid
 from paddle.fluid.dygraph.base import to_variable
-from model import NetworkImageNet as Network
-from paddleslim.common import AvgrageMeter
+from paddleslim.common import AvgrageMeter, get_logger
 import genotypes
 import reader
+from model import NetworkImageNet as Network
 sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
 from utility import add_arguments, print_arguments
+logger = get_logger(__name__, level=logging.INFO)
 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
@@ -62,7 +59,7 @@ add_arg('dropout', float, 0.0, "Dropout probability.")
 add_arg('grad_clip', float, 5, "Gradient clipping.")
 add_arg('label_smooth', float, 0.1, "Label smoothing.")
 add_arg('arch', str, 'DARTS_V2', "Which architecture to use")
-add_arg('report_freq', int, 100, 'Report frequency')
+add_arg('log_freq', int, 100, 'Report frequency')
 add_arg('use_data_parallel', ast.literal_eval, False, "The flag indicating whether to use data parallel mode to train the model.")
 # yapf: enable
@@ -108,9 +105,7 @@ def train(model, train_reader, optimizer, epoch, args):
         else:
             loss.backward()
-        grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
-            args.grad_clip)
-        optimizer.minimize(loss, grad_clip=grad_clip)
+        optimizer.minimize(loss)
         model.clear_gradients()
         n = image.shape[0]
@@ -118,7 +113,7 @@ def train(model, train_reader, optimizer, epoch, args):
         top1.update(prec1.numpy(), n)
         top5.update(prec5.numpy(), n)
-        if step_id % args.report_freq == 0:
+        if step_id % args.log_freq == 0:
             logger.info(
                 "Train Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
                 format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[0]))
@@ -145,7 +140,7 @@ def valid(model, valid_reader, epoch, args):
         objs.update(loss.numpy(), n)
         top1.update(prec1.numpy(), n)
         top5.update(prec5.numpy(), n)
-        if step_id % args.report_freq == 0:
+        if step_id % args.log_freq == 0:
             logger.info(
                 "Valid Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
                 format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[0]))
@@ -174,11 +169,14 @@ def main(args):
         step_per_epoch,
         args.decay_rate,
         staircase=True)
+    clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=args.grad_clip)
     optimizer = fluid.optimizer.MomentumOptimizer(
         learning_rate,
         momentum=args.momentum,
         regularization=fluid.regularizer.L2Decay(args.weight_decay),
-        parameter_list=model.parameters())
+        parameter_list=model.parameters(),
+        grad_clip=clip)
     if args.use_data_parallel:
         model = fluid.dygraph.parallel.DataParallel(model, strategy)
......
@@ -108,8 +108,7 @@ class DARTSearch(object):
             else:
                 loss.backward()
-            grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(5)
-            optimizer.minimize(loss, grad_clip)
+            optimizer.minimize(loss)
             self.model.clear_gradients()
             objs.update(loss.numpy(), n)
@@ -163,11 +162,14 @@ class DARTSearch(object):
             step_per_epoch *= 2
         learning_rate = fluid.dygraph.CosineDecay(
             self.learning_rate, step_per_epoch, self.num_epochs)
+        clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)
         optimizer = fluid.optimizer.MomentumOptimizer(
             learning_rate,
             0.9,
             regularization=fluid.regularizer.L2DecayRegularizer(3e-4),
-            parameter_list=model_parameters)
+            parameter_list=model_parameters,
+            grad_clip=clip)
         if self.use_data_parallel:
             self.model = fluid.dygraph.parallel.DataParallel(self.model,
......