From d4d5f89fff0527dc573e23fd5a0f6dad2316e97d Mon Sep 17 00:00:00 2001
From: Bai Yifan
Date: Sat, 9 May 2020 16:21:09 +0800
Subject: [PATCH] Refine reader config (#266)

* refine reader config
---
 demo/darts/README.md                 |  1 +
 demo/darts/reader.py                 | 42 ++++------------------
 demo/darts/search.py                 |  4 +--
 demo/darts/train.py                  | 11 ++++----
 docs/zh_cn/api_cn/darts.rst          |  5 +++-
 paddleslim/nas/darts/train_search.py | 11 +++++---
 6 files changed, 24 insertions(+), 50 deletions(-)

diff --git a/demo/darts/README.md b/demo/darts/README.md
index 92ce4079..64f6bf60 100644
--- a/demo/darts/README.md
+++ b/demo/darts/README.md
@@ -42,6 +42,7 @@
 python search.py                  # DARTS first-order approximation search
 python search.py --unrolled=True  # DARTS second-order approximation search
 python search.py --method='PC-DARTS' --batch_size=256 --learning_rate=0.1 --arch_learning_rate=6e-4 --epochs_no_archopt=15  # PC-DARTS search
 ```
+If you run in a Docker environment, make sure shared memory is large enough for the multiprocess dataloader; if you hit shared-memory problems, set `--use_multiprocess=False`.
 
 Multi-card architecture search is also supported. Taking 4 cards (GPU ids 0-3) as an example, the launch command is:
 
diff --git a/demo/darts/reader.py b/demo/darts/reader.py
index 9a3718bc..93f31861 100644
--- a/demo/darts/reader.py
+++ b/demo/darts/reader.py
@@ -140,32 +140,10 @@ def train_search(batch_size, train_portion, is_shuffle, args):
     split_point = int(np.floor(train_portion * len(datasets)))
     train_datasets = datasets[:split_point]
     val_datasets = datasets[split_point:]
-    train_readers = []
-    val_readers = []
-    n = int(math.ceil(len(train_datasets) // args.num_workers)
-            ) if args.use_multiprocess else len(train_datasets)
-    train_datasets_lists = [
-        train_datasets[i:i + n] for i in range(0, len(train_datasets), n)
+    reader = [
+        reader_generator(train_datasets, batch_size, True, True, args),
+        reader_generator(val_datasets, batch_size, True, True, args)
     ]
-    val_datasets_lists = [
-        val_datasets[i:i + n] for i in range(0, len(val_datasets), n)
-    ]
-
-    for pid in range(len(train_datasets_lists)):
-        train_readers.append(
-            reader_generator(train_datasets_lists[pid], batch_size, True, True,
-                             args))
-        val_readers.append(
-            reader_generator(val_datasets_lists[pid], batch_size, True, True,
-                             args))
-    if args.use_multiprocess:
-        reader = [
-            paddle.reader.multiprocess_reader(train_readers, False),
-            paddle.reader.multiprocess_reader(val_readers, False)
-        ]
-    else:
-        reader = [train_readers[0], val_readers[0]]
-
     return reader
 
@@ -174,18 +152,8 @@ def train_valid(batch_size, is_train, is_shuffle, args):
     datasets = cifar10_reader(
         paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
         name, is_shuffle, args)
-    n = int(math.ceil(len(datasets) // args.
-            num_workers)) if args.use_multiprocess else len(datasets)
-    datasets_lists = [datasets[i:i + n] for i in range(0, len(datasets), n)]
-    multi_readers = []
-    for pid in range(len(datasets_lists)):
-        multi_readers.append(
-            reader_generator(datasets_lists[pid], batch_size, is_train,
-                             is_shuffle, args))
-    if args.use_multiprocess:
-        reader = paddle.reader.multiprocess_reader(multi_readers, False)
-    else:
-        reader = multi_readers[0]
+
+    reader = reader_generator(datasets, batch_size, is_train, is_shuffle, args)
     return reader
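For readers skimming the reader.py changes above: `train_search` and `train_valid` no longer shard the dataset across `args.num_workers` sub-readers glued together with `paddle.reader.multiprocess_reader`; each now returns a plain `reader_generator`, and multiprocessing is delegated to the `DataLoader` that consumes it. The sketch below shows the general shape of such a batch generator. It is illustrative only: the in-memory `(image, label)` list layout, the dtypes, and the omission of the demo's augmentation logic are assumptions, not the demo's exact code.

```python
import numpy as np

def reader_generator(datasets, batch_size, is_train, is_shuffle, args=None):
    # Sketch only: `datasets` is assumed to be an in-memory list of
    # (image, label) pairs as produced by the demo's cifar10_reader.
    # The real demo also applies augmentation driven by `args` and
    # `is_train`; both are ignored here.
    def reader():
        if is_shuffle:
            np.random.shuffle(datasets)
        images, labels = [], []
        for image, label in datasets:
            images.append(image)
            labels.append(label)
            if len(images) == batch_size:
                # Yield one already-batched (N,C,H,W) / (N,1) pair, which is
                # the contract DataLoader.set_batch_generator expects.
                yield (np.stack(images).astype('float32'),
                       np.array(labels).reshape(-1, 1).astype('int64'))
                images, labels = [], []
    return reader
```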
diff --git a/demo/darts/search.py b/demo/darts/search.py
index f89408d2..d8a8c484 100644
--- a/demo/darts/search.py
+++ b/demo/darts/search.py
@@ -35,8 +35,7 @@ add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
 add_arg('log_freq', int, 50, "Log frequency.")
-add_arg('use_multiprocess', bool, False, "Whether use multiprocess reader.")
-add_arg('num_workers', int, 4, "The multiprocess reader number.")
+add_arg('use_multiprocess', bool, True, "Whether use multiprocess reader.")
 add_arg('data', str, 'dataset/cifar10',"The dir of dataset.")
 add_arg('batch_size', int, 64, "Minibatch size.")
 add_arg('learning_rate', float, 0.025, "The start learning rate.")
@@ -88,6 +87,7 @@ def main(args):
         unrolled=args.unrolled,
         num_epochs=args.epochs,
         epochs_no_archopt=args.epochs_no_archopt,
+        use_multiprocess=args.use_multiprocess,
         use_data_parallel=args.use_data_parallel,
         save_dir=args.model_save_dir,
         log_freq=args.log_freq)
diff --git a/demo/darts/train.py b/demo/darts/train.py
index 70c48f37..39b31633 100644
--- a/demo/darts/train.py
+++ b/demo/darts/train.py
@@ -39,8 +39,7 @@ parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
-add_arg('use_multiprocess', bool, False, "Whether use multiprocess reader.")
-add_arg('num_workers', int, 4, "The multiprocess reader number.")
+add_arg('use_multiprocess', bool, True, "Whether use multiprocess reader.")
 add_arg('data', str, 'dataset/cifar10',"The dir of dataset.")
 add_arg('batch_size', int, 96, "Minibatch size.")
 add_arg('learning_rate', float, 0.025, "The start learning rate.")
@@ -170,17 +169,17 @@ def main(args):
         model = fluid.dygraph.parallel.DataParallel(model, strategy)
 
     train_loader = fluid.io.DataLoader.from_generator(
-        capacity=1024,
+        capacity=64,
         use_double_buffer=True,
         iterable=True,
         return_list=True,
-        use_multiprocess=True)
+        use_multiprocess=args.use_multiprocess)
     valid_loader = fluid.io.DataLoader.from_generator(
-        capacity=1024,
+        capacity=64,
         use_double_buffer=True,
         iterable=True,
         return_list=True,
-        use_multiprocess=True)
+        use_multiprocess=args.use_multiprocess)
 
     train_reader = reader.train_valid(
         batch_size=args.batch_size,
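The train.py hunk above shows the consumer side: the queue capacity drops from 1024 to 64 (one shared generator no longer needs a per-worker-sized buffer), and multiprocessing becomes user-controlled. A minimal sketch of the wiring, using the same `DataLoader.from_generator` arguments as the patched code; `my_batch_reader` is a hypothetical stand-in for what `reader.train_valid(...)` returns, and the batch shapes are illustrative.

```python
import numpy as np
import paddle.fluid as fluid

def my_batch_reader():
    # Hypothetical stand-in for reader.train_valid(...): yields
    # already-batched (image, label) numpy pairs.
    for _ in range(10):
        yield (np.random.rand(96, 3, 32, 32).astype('float32'),
               np.random.randint(0, 10, (96, 1)).astype('int64'))

place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() \
    else fluid.CPUPlace()

with fluid.dygraph.guard(place):
    loader = fluid.io.DataLoader.from_generator(
        capacity=64,               # smaller queue, as in the patch
        use_double_buffer=True,
        iterable=True,
        return_list=True,
        use_multiprocess=True)     # now driven by --use_multiprocess
    loader.set_batch_generator(my_batch_reader, places=place)

    for image, label in loader():
        pass  # training step goes here
```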
diff --git a/docs/zh_cn/api_cn/darts.rst b/docs/zh_cn/api_cn/darts.rst
index 41f23e9c..c4fd7256 100644
--- a/docs/zh_cn/api_cn/darts.rst
+++ b/docs/zh_cn/api_cn/darts.rst
@@ -4,7 +4,7 @@ DARTSearch
 ---------
 
-.. py:class:: paddleslim.nas.DARTSearch(model, train_reader, valid_reader, place, learning_rate=0.025, batchsize=64, num_imgs=50000, arch_learning_rate=3e-4, unrolled=False, num_epochs=50, epochs_no_archopt=0, use_data_parallel=False, save_dir='./', log_freq=50)
+.. py:class:: paddleslim.nas.DARTSearch(model, train_reader, valid_reader, place, learning_rate=0.025, batchsize=64, num_imgs=50000, arch_learning_rate=3e-4, unrolled=False, num_epochs=50, epochs_no_archopt=0, use_multiprocess=False, use_data_parallel=False, save_dir='./', log_freq=50)
 
 `Source code `_
 
@@ -18,11 +18,14 @@ DARTSearch
 - **place** (fluid.CPUPlace()|fluid.CUDAPlace(N)) - Which device the program runs on; N is the GPU id.
 - **learning_rate** (float) - Initial learning rate of the model parameters. Default: 0.025.
 - **batchsize** (int) - Batch size of the data used during search. Default: 64.
+- **num_imgs** (int) - Total number of samples in the dataset. Default: 50000.
 - **arch_learning_rate** (float) - Learning rate of the architecture parameters. Default: 3e-4.
 - **unrolled** (bool) - Whether to use the second-order (unrolled) search algorithm. Default: False.
 - **num_epochs** (int) - Number of search training epochs. Default: 50.
 - **epochs_no_archopt** (int) - Number of initial epochs during which architecture-parameter optimization is skipped. Default: 0.
+- **use_multiprocess** (bool) - Whether to use a multiprocess dataloader. Default: False.
 - **use_data_parallel** (bool) - Whether to use data-parallel multi-card training. Default: False.
+- **save_dir** (str) - Directory for saving model parameters. Default: './'.
 - **log_freq** (int) - Emit a log line every this many steps. Default: 50.
 
diff --git a/paddleslim/nas/darts/train_search.py b/paddleslim/nas/darts/train_search.py
index c0fd0bfc..c8c70413 100644
--- a/paddleslim/nas/darts/train_search.py
+++ b/paddleslim/nas/darts/train_search.py
@@ -59,6 +59,7 @@ class DARTSearch(object):
         unrolled(bool): Use one-step unrolled validation loss. Default: False.
         num_epochs(int): Epoch number. Default: 50.
         epochs_no_archopt(int): Epochs skip architecture optimize at begining. Default: 0.
+        use_multiprocess(bool): Whether to use multiprocess in dataloader. Default: False.
         use_data_parallel(bool): Whether to use data parallel mode. Default: False.
         log_freq(int): Log frequency. Default: 50.
 
@@ -76,6 +77,7 @@ class DARTSearch(object):
                  unrolled=False,
                  num_epochs=50,
                  epochs_no_archopt=0,
+                 use_multiprocess=False,
                  use_data_parallel=False,
                  save_dir='./',
                  log_freq=50):
@@ -90,6 +92,7 @@ class DARTSearch(object):
         self.unrolled = unrolled
         self.epochs_no_archopt = epochs_no_archopt
         self.num_epochs = num_epochs
+        self.use_multiprocess = use_multiprocess
         self.use_data_parallel = use_data_parallel
         self.save_dir = save_dir
         self.log_freq = log_freq
@@ -207,17 +210,17 @@ class DARTSearch(object):
             self.valid_reader)
 
         train_loader = fluid.io.DataLoader.from_generator(
-            capacity=1024,
+            capacity=64,
             use_double_buffer=True,
             iterable=True,
             return_list=True,
-            use_multiprocess=True)
+            use_multiprocess=self.use_multiprocess)
         valid_loader = fluid.io.DataLoader.from_generator(
-            capacity=1024,
+            capacity=64,
             use_double_buffer=True,
             iterable=True,
             return_list=True,
-            use_multiprocess=True)
+            use_multiprocess=self.use_multiprocess)
 
         train_loader.set_batch_generator(self.train_reader, places=self.place)
         valid_loader.set_batch_generator(self.valid_reader, places=self.place)
-- 
GitLab
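Putting the pieces together, a hedged usage sketch of the extended `DARTSearch` API documented above; `SuperNetwork` and `args` are hypothetical placeholders, `reader` refers to demo/darts/reader.py, and the call pattern mirrors the `main()` shown in the demo/darts/search.py hunk.

```python
import paddle.fluid as fluid
from paddleslim.nas import DARTSearch
import reader  # demo/darts/reader.py

place = fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
    model = SuperNetwork()  # hypothetical over-parameterized search network
    # train_search returns a [train, valid] pair of batch generators;
    # train_portion=0.5 is an assumption about the demo's default split.
    train_reader, valid_reader = reader.train_search(
        batch_size=64, train_portion=0.5, is_shuffle=True, args=args)
    searcher = DARTSearch(
        model,
        train_reader,
        valid_reader,
        place,
        batchsize=64,
        num_epochs=50,
        epochs_no_archopt=0,
        use_multiprocess=True,   # knob added by this patch; default False
        use_data_parallel=False,
        save_dir='./',
        log_freq=50)
    searcher.train()  # same entry point demo/darts/search.py calls
```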