Unverified commit b6ba314d, authored by Bai Yifan, committed by GitHub

[Cherry-Pick] Refine reader config (#267)

* refine reader config
Parent a82c9df4
@@ -42,6 +42,7 @@ python search.py # DARTS first-order approximation search
 python search.py --unrolled=True # DARTS second-order approximation search
 python search.py --method='PC-DARTS' --batch_size=256 --learning_rate=0.1 --arch_learning_rate=6e-4 --epochs_no_archopt=15 # PC-DARTS search
 ```
+If you are running in a Docker environment, make sure shared memory is large enough for the multiprocess dataloader; if you hit shared-memory problems, set `--use_multiprocess=False`.
 You can also run the architecture search on multiple GPUs. Taking 4 GPUs (GPU ids 0-3) as an example, the launch command is as follows:
......
@@ -140,32 +140,10 @@ def train_search(batch_size, train_portion, is_shuffle, args):
     split_point = int(np.floor(train_portion * len(datasets)))
     train_datasets = datasets[:split_point]
     val_datasets = datasets[split_point:]
-    train_readers = []
-    val_readers = []
-    n = int(math.ceil(len(train_datasets) // args.num_workers)
-            ) if args.use_multiprocess else len(train_datasets)
-    train_datasets_lists = [
-        train_datasets[i:i + n] for i in range(0, len(train_datasets), n)
+    reader = [
+        reader_generator(train_datasets, batch_size, True, True, args),
+        reader_generator(val_datasets, batch_size, True, True, args)
     ]
-    val_datasets_lists = [
-        val_datasets[i:i + n] for i in range(0, len(val_datasets), n)
-    ]
-    for pid in range(len(train_datasets_lists)):
-        train_readers.append(
-            reader_generator(train_datasets_lists[pid], batch_size, True, True,
-                             args))
-        val_readers.append(
-            reader_generator(val_datasets_lists[pid], batch_size, True, True,
-                             args))
-    if args.use_multiprocess:
-        reader = [
-            paddle.reader.multiprocess_reader(train_readers, False),
-            paddle.reader.multiprocess_reader(val_readers, False)
-        ]
-    else:
-        reader = [train_readers[0], val_readers[0]]
     return reader
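For context, the `reader_generator` calls above return plain batch generators that `DataLoader.set_batch_generator` can consume directly. A minimal sketch of that shape (a hypothetical simplification; the real helper in this repo also applies CIFAR-10 preprocessing and augmentation):

```python
import numpy as np

def reader_generator(datasets, batch_size, is_train, is_shuffle, args=None):
    # Hypothetical sketch: `datasets` is a list of (image, label) samples.
    def reader():
        data = list(datasets)
        if is_shuffle:
            np.random.shuffle(data)
        # Drop the trailing partial batch; yield one numpy batch per step.
        for i in range(0, len(data) - batch_size + 1, batch_size):
            batch = data[i:i + batch_size]
            images = np.stack([s[0] for s in batch]).astype('float32')
            labels = np.array([s[1] for s in batch]).astype('int64')
            yield [images, labels]
    return reader
```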
@@ -174,18 +152,8 @@ def train_valid(batch_size, is_train, is_shuffle, args):
     datasets = cifar10_reader(
         paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
         name, is_shuffle, args)
-    n = int(math.ceil(len(datasets) // args.
-                      num_workers)) if args.use_multiprocess else len(datasets)
-    datasets_lists = [datasets[i:i + n] for i in range(0, len(datasets), n)]
-    multi_readers = []
-    for pid in range(len(datasets_lists)):
-        multi_readers.append(
-            reader_generator(datasets_lists[pid], batch_size, is_train,
-                             is_shuffle, args))
-    if args.use_multiprocess:
-        reader = paddle.reader.multiprocess_reader(multi_readers, False)
-    else:
-        reader = multi_readers[0]
+    reader = reader_generator(datasets, batch_size, is_train, is_shuffle, args)
     return reader
......
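The deleted branch above sharded the sample list across `args.num_workers` sub-readers and merged them with `paddle.reader.multiprocess_reader`; after this commit each split gets a single generator, and process-level parallelism is delegated to the `DataLoader` via `use_multiprocess`. A minimal sketch of the old merge step, assuming Paddle 1.x (`multiprocess_reader` does not work on Windows):

```python
import paddle

def make_shard_reader(shard):
    # One sample-level reader per worker shard (the old per-pid readers).
    def reader():
        for sample in shard:
            yield sample
    return reader

samples = list(range(100))
n = 25  # shard size when num_workers=4
shard_readers = [
    make_shard_reader(samples[i:i + n]) for i in range(0, len(samples), n)
]

# Old approach: run each shard reader in its own process, merge the outputs.
merged = paddle.reader.multiprocess_reader(shard_readers, False)  # use_pipe=False

for sample in merged():
    pass  # samples arrive interleaved from the worker processes
```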
@@ -35,8 +35,7 @@ add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
 add_arg('log_freq', int, 50, "Log frequency.")
-add_arg('use_multiprocess', bool, False, "Whether use multiprocess reader.")
-add_arg('num_workers', int, 4, "The multiprocess reader number.")
+add_arg('use_multiprocess', bool, True, "Whether use multiprocess reader.")
 add_arg('data', str, 'dataset/cifar10',"The dir of dataset.")
 add_arg('batch_size', int, 64, "Minibatch size.")
 add_arg('learning_rate', float, 0.025, "The start learning rate.")
@@ -88,6 +87,7 @@ def main(args):
         unrolled=args.unrolled,
         num_epochs=args.epochs,
         epochs_no_archopt=args.epochs_no_archopt,
+        use_multiprocess=args.use_multiprocess,
         use_data_parallel=args.use_data_parallel,
         save_dir=args.model_save_dir,
         log_freq=args.log_freq)
......
@@ -39,8 +39,7 @@ parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
-add_arg('use_multiprocess', bool, False, "Whether use multiprocess reader.")
-add_arg('num_workers', int, 4, "The multiprocess reader number.")
+add_arg('use_multiprocess', bool, True, "Whether use multiprocess reader.")
 add_arg('data', str, 'dataset/cifar10',"The dir of dataset.")
 add_arg('batch_size', int, 96, "Minibatch size.")
 add_arg('learning_rate', float, 0.025, "The start learning rate.")
@@ -170,17 +169,17 @@ def main(args):
         model = fluid.dygraph.parallel.DataParallel(model, strategy)
     train_loader = fluid.io.DataLoader.from_generator(
-        capacity=1024,
+        capacity=64,
         use_double_buffer=True,
         iterable=True,
         return_list=True,
-        use_multiprocess=True)
+        use_multiprocess=args.use_multiprocess)
     valid_loader = fluid.io.DataLoader.from_generator(
-        capacity=1024,
+        capacity=64,
         use_double_buffer=True,
         iterable=True,
         return_list=True,
-        use_multiprocess=True)
+        use_multiprocess=args.use_multiprocess)
     train_reader = reader.train_valid(
         batch_size=args.batch_size,
......
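Both loader hunks apply the same two changes: a smaller prefetch queue (`capacity=64` instead of 1024) and `use_multiprocess` driven by the new flag instead of hard-coded `True`. The surrounding pattern, sketched standalone under Paddle 1.x dygraph (the batch generator is illustrative):

```python
import numpy as np
import paddle.fluid as fluid

def batch_generator():
    # Illustrative stand-in for reader.train_valid(...): yields numpy batches.
    for _ in range(4):
        images = np.random.rand(96, 3, 32, 32).astype('float32')
        labels = np.random.randint(0, 10, size=(96, 1)).astype('int64')
        yield [images, labels]

place = fluid.CPUPlace()  # CUDAPlace(0) on GPU
with fluid.dygraph.guard(place):
    loader = fluid.io.DataLoader.from_generator(
        capacity=64,             # smaller prefetch queue, as in this commit
        use_double_buffer=True,
        iterable=True,
        return_list=True,
        use_multiprocess=False)  # in the patch this comes from --use_multiprocess
    loader.set_batch_generator(batch_generator, places=place)
    for images, labels in loader():
        pass  # forward/backward would go here
```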
@@ -4,7 +4,7 @@
 DARTSearch
 ---------
-.. py:class:: paddleslim.nas.DARTSearch(model, train_reader, valid_reader, place, learning_rate=0.025, batchsize=64, num_imgs=50000, arch_learning_rate=3e-4, unrolled=False, num_epochs=50, epochs_no_archopt=0, use_data_parallel=False, save_dir='./', log_freq=50)
+.. py:class:: paddleslim.nas.DARTSearch(model, train_reader, valid_reader, place, learning_rate=0.025, batchsize=64, num_imgs=50000, arch_learning_rate=3e-4, unrolled=False, num_epochs=50, epochs_no_archopt=0, use_multiprocess=False, use_data_parallel=False, save_dir='./', log_freq=50)
 `Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/release/1.1.0/paddleslim/nas/darts/train_search.py>`_
@@ -18,11 +18,14 @@ DARTSearch
 - **place** (fluid.CPUPlace()|fluid.CUDAPlace(N)) - The device the program runs on; N is the id of the corresponding GPU.
 - **learning_rate** (float) - Initial learning rate of the model parameters. Default: 0.025.
 - **batchsize** (int) - Batch size of the data used during search. Default: 64.
 - **num_imgs** (int) - Total number of samples in the dataset. Default: 50000.
 - **arch_learning_rate** (float) - Learning rate of the architecture parameters. Default: 3e-4.
 - **unrolled** (bool) - Whether to use the second-order search algorithm. Default: False.
 - **num_epochs** (int) - Number of search epochs. Default: 50.
 - **epochs_no_archopt** (int) - Skip architecture-parameter optimization for the first given number of epochs. Default: 0.
+- **use_multiprocess** (bool) - Whether to use a multiprocess dataloader. Default: False.
 - **use_data_parallel** (bool) - Whether to use data-parallel multi-GPU training. Default: False.
 - **save_dir** (str) - Directory in which model parameters are saved. Default: './'.
 - **log_freq** (int) - Print a log entry every this many steps. Default: 50.
......
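For reference, a minimal usage sketch of the extended signature (assumes the PaddleSlim 1.1 dygraph API; `SuperNetwork` and `build_readers` are placeholders, and `train()` is the search entry point):

```python
import paddle.fluid as fluid
from paddleslim.nas import DARTSearch

place = fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
    model = SuperNetwork()                        # placeholder: search-space network
    train_reader, valid_reader = build_readers()  # placeholder batch generators
    searcher = DARTSearch(
        model,
        train_reader,
        valid_reader,
        place,
        epochs_no_archopt=15,
        use_multiprocess=True,  # the flag added by this commit
        use_data_parallel=False)
    searcher.train()  # run the architecture search
```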
@@ -59,6 +59,7 @@ class DARTSearch(object):
         unrolled(bool): Use one-step unrolled validation loss. Default: False.
         num_epochs(int): Epoch number. Default: 50.
         epochs_no_archopt(int): Epochs skip architecture optimize at begining. Default: 0.
+        use_multiprocess(bool): Whether to use multiprocess in dataloader. Default: False.
         use_data_parallel(bool): Whether to use data parallel mode. Default: False.
         log_freq(int): Log frequency. Default: 50.
@@ -76,6 +77,7 @@ class DARTSearch(object):
                  unrolled=False,
                  num_epochs=50,
                  epochs_no_archopt=0,
+                 use_multiprocess=False,
                  use_data_parallel=False,
                  save_dir='./',
                  log_freq=50):
@@ -90,6 +92,7 @@ class DARTSearch(object):
         self.unrolled = unrolled
         self.epochs_no_archopt = epochs_no_archopt
         self.num_epochs = num_epochs
+        self.use_multiprocess = use_multiprocess
         self.use_data_parallel = use_data_parallel
         self.save_dir = save_dir
         self.log_freq = log_freq
@@ -207,17 +210,17 @@ class DARTSearch(object):
             self.valid_reader)
         train_loader = fluid.io.DataLoader.from_generator(
-            capacity=1024,
+            capacity=64,
             use_double_buffer=True,
             iterable=True,
             return_list=True,
-            use_multiprocess=True)
+            use_multiprocess=self.use_multiprocess)
         valid_loader = fluid.io.DataLoader.from_generator(
-            capacity=1024,
+            capacity=64,
             use_double_buffer=True,
             iterable=True,
             return_list=True,
-            use_multiprocess=True)
+            use_multiprocess=self.use_multiprocess)
         train_loader.set_batch_generator(self.train_reader, places=self.place)
         valid_loader.set_batch_generator(self.valid_reader, places=self.place)
......