From 2e4465c61e8f3d89acefb0920a358f32a4b8f1fd Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Tue, 26 Nov 2019 15:45:01 +0800 Subject: [PATCH] Add dali preprocessing pipeline for imagenet training (#3960) * Use iterable dataloader for image classification * Support imagenet dataset with plain folder structure * Add dali section to README.md * Improve docs * Remove some flags from docs * Add dali reader for image classification * Fix elapsed time calculation --- PaddleCV/image_classification/README.md | 36 +++ PaddleCV/image_classification/README_en.md | 42 ++++ PaddleCV/image_classification/dali.py | 214 ++++++++++++++++++ PaddleCV/image_classification/train.py | 133 ++++++----- .../image_classification/utils/utility.py | 8 +- 5 files changed, 362 insertions(+), 71 deletions(-) create mode 100644 PaddleCV/image_classification/dali.py diff --git a/PaddleCV/image_classification/README.md b/PaddleCV/image_classification/README.md index 1359b038..052b1c58 100644 --- a/PaddleCV/image_classification/README.md +++ b/PaddleCV/image_classification/README.md @@ -14,6 +14,7 @@ - [进阶使用](#进阶使用) - [Mixup训练](#mixup训练) - [混合精度训练](#混合精度训练) + - [DALI预处理](#DALI预处理) - [自定义数据集](#自定义数据集) - [已发布模型及其性能](#已发布模型及其性能) - [FAQ](#faq) @@ -246,6 +247,41 @@ Mixup相关介绍参考[mixup: Beyond Empirical Risk Minimization](https://arxiv FP16相关内容已经迁移至PaddlePaddle/Fleet 中 +### DALI预处理 + +使用[Nvidia DALI](https://github.com/NVIDIA/DALI)预处理类库可以加速训练并提高GPU利用率。 + +DALI预处理目前支持标准ImageNet处理步骤( random crop -> resize -> flip -> normalize),并且支持列表文件或者文件夹方式的数据集格式。 + +指定`--use_dali=True`即可开启DALI预处理,如下面的例子中,使用DALI训练ShuffleNet v2 0.25x,在8卡v100上,图片吞吐可以达到10000张/秒以上,GPU利用率在85%以上。 + +``` bash +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FLAGS_fraction_of_gpu_memory_to_use=0.80 + +python -m paddle.distributed.launch train.py \ + --model=ShuffleNetV2_x0_25 \ + --batch_size=2048 \ + --class_dim=1000 \ + --image_shape=3,224,224 \ + --lr_strategy=cosine_decay_warmup \ + --num_epochs=240 \ + --lr=0.5 \ + --l2_decay=3e-5 \ + --lower_scale=0.64 \ + --lower_ratio=0.8 \ + --upper_ratio=1.2 \ + --use_dali=True +``` + +更多DALI相关用例请参考[DALI Paddle插件文档](https://docs.nvidia.com/deeplearning/sdk/dali-master-branch-user-guide/docs/plugins/paddle_tutorials.html)。 + +#### 注意事项 + +1. PaddlePaddle需使用1.6或以上的版本,并且需要使用GCC5.4以上编译器编译。 +2. Nvidia DALI需要使用[#1371](https://github.com/NVIDIA/DALI/pull/1371)以后的git版本。请参考[此文档](https://docs.nvidia.com/deeplearning/sdk/dali-master-branch-user-guide/docs/installation.html)安装nightly版本或从源码安装。 +3. 因为DALI使用GPU进行图片预处理,需要占用部分显存,请适当调整 `FLAGS_fraction_of_gpu_memory_to_use`环境变量(如`0.8`)来预留部分显存供DALI使用。 + ### 自定义数据集 PaddlePaddle/Models ImageClassification 支持自定义数据 diff --git a/PaddleCV/image_classification/README_en.md b/PaddleCV/image_classification/README_en.md index bcc92ff1..6a05b7a3 100644 --- a/PaddleCV/image_classification/README_en.md +++ b/PaddleCV/image_classification/README_en.md @@ -15,6 +15,7 @@ English | [中文](README.md) - [Advanced Usage](#advanced-usage) - [Mixup Training](#mixup-training) - [Using Mixed-Precision Training](#using-mixed-precision-training) + - [Preprocessing with Nvidia DALI](#preprocessing-with-nvidia-dali) - [Custom Dataset](#custom-dataset) - [Supported Models and Performances](#supported-models-and-performances) - [Reference](#reference) @@ -238,6 +239,47 @@ Refer to [mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710. Mixed-precision part is moving to PaddlePaddle/Fleet now. +### Preprocessing with Nvidia DALI + +[Nvidia DALI](https://github.com/NVIDIA/DALI) can be used to preprocess input images, which could speed up training and achieve higher GPU utilization. + +At present, DALI preprocessing supports the standard ImageNet pipeline (random crop -> resize -> flip -> normalize), it supports dataset in both file list or plain folder format. + +DALI preprocessing can be enabled with the `--use_dali=True` command line flag. +For example, training ShuffleNet v2 0.25x with the following command should +reach a throughput of over 10000 images/second, and GPU utilization should be +above 85%. + +``` bash +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FLAGS_fraction_of_gpu_memory_to_use=0.80 + +python -m paddle.distributed.launch train.py \ + --model=ShuffleNetV2_x0_25 \ + --batch_size=2048 \ + --class_dim=1000 \ + --image_shape=3,224,224 \ + --lr_strategy=cosine_decay_warmup \ + --num_epochs=240 \ + --lr=0.5 \ + --l2_decay=3e-5 \ + --lower_scale=0.64 \ + --lower_ratio=0.8 \ + --upper_ratio=1.2 \ + --use_dali=True + +``` + +For more details please refer to [Documentation on DALI Paddle Plugin](https://docs.nvidia.com/deeplearning/sdk/dali-master-branch-user-guide/docs/plugins/paddle_tutorials.html). + +#### NOTES +1. PaddlePaddle with version 1.6 or above is required, and it must be compiled +with GCC 5.4 and up. +2. Nvidia DALI should include this PR [#1371](https://github.com/NVIDIA/DALI/pull/1371). Please refer to [this doc](https://docs.nvidia.com/deeplearning/sdk/dali-master-branch-user-guide/docs/installation.html) and install nightly version or build from source. +3. Since DALI utilize the GPU for preprocessing, it will take up some GPU + memory. Please reduce the memory used by paddle by setting the + `FLAGS_fraction_of_gpu_memory_to_use` environment variable to a smaller + number (e.g., `0.8`) ### Custom Dataset diff --git a/PaddleCV/image_classification/dali.py b/PaddleCV/image_classification/dali.py new file mode 100644 index 00000000..2abed264 --- /dev/null +++ b/PaddleCV/image_classification/dali.py @@ -0,0 +1,214 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import os + +from nvidia.dali.pipeline import Pipeline +import nvidia.dali.ops as ops +import nvidia.dali.types as types +from nvidia.dali.plugin.paddle import DALIGenericIterator + +import paddle +from paddle import fluid + + +class HybridTrainPipe(Pipeline): + def __init__(self, file_root, file_list, batch_size, resize_shorter, + crop, min_area, lower, upper, interp, mean, std, + device_id, shard_id=0, num_shards=1, random_shuffle=True, + num_threads=4, seed=42): + super(HybridTrainPipe, self).__init__(batch_size, + num_threads, + device_id, + seed=seed) + self.input = ops.FileReader(file_root=file_root, + file_list=file_list, + shard_id=shard_id, + num_shards=num_shards, + random_shuffle=random_shuffle) + # set internal nvJPEG buffers size to handle full-sized ImageNet images + # without additional reallocations + device_memory_padding = 211025920 + host_memory_padding = 140544512 + self.decode = ops.ImageDecoderRandomCrop( + device='mixed', + output_type=types.RGB, + device_memory_padding=device_memory_padding, + host_memory_padding=host_memory_padding, + random_aspect_ratio=[lower, upper], + random_area=[min_area, 1.0], + num_attempts=100) + self.res = ops.Resize(device='gpu', + resize_x=crop, + resize_y=crop, + interp_type=interp) + self.cmnp = ops.CropMirrorNormalize( + device="gpu", + output_dtype=types.FLOAT, + output_layout=types.NCHW, + crop=(crop, crop), + image_type=types.RGB, + mean=mean, + std=std) + self.coin = ops.CoinFlip(probability=0.5) + self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu") + + def define_graph(self): + rng = self.coin() + jpegs, labels = self.input(name="Reader") + images = self.decode(jpegs) + images = self.res(images) + output = self.cmnp(images.gpu(), mirror=rng) + return [output, self.to_int64(labels.gpu())] + + def __len__(self): + return self.epoch_size("Reader") + + +class HybridValPipe(Pipeline): + def __init__(self, file_root, file_list, batch_size, + resize_shorter, crop, interp, mean, std, + device_id, shard_id=0, num_shards=1, random_shuffle=False, + num_threads=4, seed=42): + super(HybridValPipe, self).__init__(batch_size, + num_threads, + device_id, + seed=seed) + self.input = ops.FileReader(file_root=file_root, + file_list=file_list, + shard_id=shard_id, + num_shards=num_shards, + random_shuffle=random_shuffle) + self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB) + self.res = ops.Resize(device="gpu", + resize_shorter=resize_shorter, + interp_type=interp) + self.cmnp = ops.CropMirrorNormalize( + device="gpu", + output_dtype=types.FLOAT, + output_layout=types.NCHW, + crop=(crop, crop), + image_type=types.RGB, + mean=mean, + std=std) + self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu") + + def define_graph(self): + jpegs, labels = self.input(name="Reader") + images = self.decode(jpegs) + images = self.res(images) + output = self.cmnp(images) + return [output, self.to_int64(labels.gpu())] + + def __len__(self): + return self.epoch_size("Reader") + + +def build(settings, mode='train'): + env = os.environ + assert settings.use_gpu, "gpu training is required for DALI" + assert not settings.use_mixup, "mixup is not supported by DALI reader" + assert not settings.use_aa, "auto augment is not supported by DALI reader" + assert float(env.get('FLAGS_fraction_of_gpu_memory_to_use', 0.92)) < 0.9, \ + "Please leave enough GPU memory for DALI workspace, e.g., by setting" \ + " `export FLAGS_fraction_of_gpu_memory_to_use=0.8`" + + file_root = settings.data_dir + bs = settings.batch_size + assert bs % paddle.fluid.core.get_cuda_device_count() == 0, \ + "batch size must be multiple of number of devices" + batch_size = bs // paddle.fluid.core.get_cuda_device_count() + + mean = [v * 255 for v in settings.image_mean] + std = [v * 255 for v in settings.image_std] + crop = settings.crop_size + resize_shorter = settings.resize_short_size + min_area = settings.lower_scale + lower = settings.lower_ratio + upper = settings.upper_ratio + + interp = settings.interpolation or 1 # default to linear + interp_map = { + 0: types.INTERP_NN, # cv2.INTER_NEAREST + 1: types.INTERP_LINEAR, # cv2.INTER_LINEAR + 2: types.INTERP_CUBIC, # cv2.INTER_CUBIC + 4: types.INTERP_LANCZOS3, # XXX use LANCZOS3 for cv2.INTER_LANCZOS4 + } + assert interp in interp_map, "interpolation method not supported by DALI" + interp = interp_map[interp] + + if mode != 'train': + p = fluid.framework.cuda_places()[0] + place = fluid.core.Place() + place.set_place(p) + device_id = place.gpu_device_id() + file_list = os.path.join(file_root, 'val_list.txt') + if not os.path.exists(file_list): + file_list = None + file_root = os.path.join(file_root, 'val') + pipe = HybridValPipe(file_root, file_list, batch_size, + resize_shorter, crop, interp, mean, std, + device_id=device_id) + pipe.build() + return DALIGenericIterator(pipe, ['feed_image', 'feed_label'], + size=len(pipe), dynamic_shape=True, + fill_last_batch=False, + last_batch_padded=True) + + file_list = os.path.join(file_root, 'train_list.txt') + if not os.path.exists(file_list): + file_list = None + file_root = os.path.join(file_root, 'train') + + if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env: + shard_id = int(env['PADDLE_TRAINER_ID']) + num_shards = int(env['PADDLE_TRAINERS_NUM']) + device_id = int(env['FLAGS_selected_gpus']) + pipe = HybridTrainPipe(file_root, file_list, batch_size, + resize_shorter, crop, min_area, + lower, upper, interp, mean, std, + device_id, shard_id, num_shards, + seed=42 + shard_id) + pipe.build() + pipelines = [pipe] + sample_per_shard = len(pipe) // num_shards + else: + pipelines = [] + places = fluid.framework.cuda_places() + num_shards = len(places) + for idx, p in enumerate(places): + place = fluid.core.Place() + place.set_place(p) + device_id = place.gpu_device_id() + pipe = HybridTrainPipe( + file_root, file_list, batch_size, + resize_shorter, crop, min_area, + lower, upper, interp, mean, std, + device_id, idx, num_shards, seed=42 + idx) + pipe.build() + pipelines.append(pipe) + sample_per_shard = len(pipelines[0]) + + return DALIGenericIterator( + pipelines, ['feed_image', 'feed_label'], size=sample_per_shard) + + +def train(settings): + return build(settings, 'train') + + +def val(settings): + return build(settings, 'val') diff --git a/PaddleCV/image_classification/train.py b/PaddleCV/image_classification/train.py index a7f4cca0..75db7f5b 100755 --- a/PaddleCV/image_classification/train.py +++ b/PaddleCV/image_classification/train.py @@ -94,33 +94,29 @@ def build_program(is_train, main_prog, startup_prog, args): loss_out.append(data_loader) return loss_out - -def validate(args, test_data_loader, exe, test_prog, test_fetch_list, pass_id, +def validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id, train_batch_metrics_record): test_batch_time_record = [] test_batch_metrics_record = [] test_batch_id = 0 - test_data_loader.start() - try: - while True: - t1 = time.time() - test_batch_metrics = exe.run(program=test_prog, - fetch_list=test_fetch_list) - t2 = time.time() - test_batch_elapse = t2 - t1 - test_batch_time_record.append(test_batch_elapse) - - test_batch_metrics_avg = np.mean( - np.array(test_batch_metrics), axis=1) - test_batch_metrics_record.append(test_batch_metrics_avg) + for batch in test_iter: + t1 = time.time() + test_batch_metrics = exe.run(program=test_prog, + feed=batch, + fetch_list=test_fetch_list) + t2 = time.time() + test_batch_elapse = t2 - t1 + test_batch_time_record.append(test_batch_elapse) + + test_batch_metrics_avg = np.mean( + np.array(test_batch_metrics), axis=1) + test_batch_metrics_record.append(test_batch_metrics_avg) + + print_info(pass_id, test_batch_id, args.print_step, + test_batch_metrics_avg, test_batch_elapse, "batch") + sys.stdout.flush() + test_batch_id += 1 - print_info(pass_id, test_batch_id, args.print_step, - test_batch_metrics_avg, test_batch_elapse, "batch") - sys.stdout.flush() - test_batch_id += 1 - - except fluid.core.EOFException: - test_data_loader.reset() #train_epoch_time_avg = np.mean(np.array(train_batch_time_record)) train_epoch_metrics_avg = np.mean( np.array(train_batch_metrics_record), axis=0) @@ -176,75 +172,74 @@ def train(args): exe = fluid.Executor(place) exe.run(startup_prog) + trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0)) + #init model by checkpoint or pretrianed model. init_model(exe, args, train_prog) num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) - imagenet_reader = reader.ImageNetReader(0 if num_trainers > 1 else None) - train_reader = imagenet_reader.train(settings=args) - test_reader = imagenet_reader.val(settings=args) - - train_data_loader.set_sample_list_generator(train_reader, place) - test_data_loader.set_sample_list_generator(test_reader, place) + if args.use_dali: + import dali + train_iter = dali.train(settings=args) + if trainer_id == 0: + test_iter = dali.val(settings=args) + else: + imagenet_reader = reader.ImageNetReader(0 if num_trainers > 1 else None) + train_reader = imagenet_reader.train(settings=args) + test_reader = imagenet_reader.val(settings=args) + places = place + if num_trainers <= 1 and args.use_gpu: + places = fluid.framework.cuda_places() + train_data_loader.set_sample_list_generator(train_reader, places) + test_data_loader.set_sample_list_generator(test_reader, place) compiled_train_prog = best_strategy_compiled(args, train_prog, train_fetch_vars[0], exe) - trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0)) - total_batch_num = 0 #this is for benchmark for pass_id in range(args.num_epochs): - if num_trainers > 1: - imagenet_reader.set_shuffle_seed(pass_id + ( - args.random_seed if args.random_seed else 0)) + if num_trainers > 1 and not args.use_dali: + imagenet_reader.set_shuffle_seed(pass_id + (args.random_seed if args.random_seed else 0)) train_batch_id = 0 train_batch_time_record = [] train_batch_metrics_record = [] - train_data_loader.start() - try: - while True: - if args.max_iter and total_batch_num == args.max_iter: - return - t1 = time.time() - train_batch_metrics = exe.run(compiled_train_prog, - fetch_list=train_fetch_list) - t2 = time.time() - train_batch_elapse = t2 - t1 - train_batch_time_record.append(train_batch_elapse) - train_batch_metrics_avg = np.mean( - np.array(train_batch_metrics), axis=1) - train_batch_metrics_record.append(train_batch_metrics_avg) - if trainer_id == 0: - print_info(pass_id, train_batch_id, args.print_step, - train_batch_metrics_avg, train_batch_elapse, - "batch") - sys.stdout.flush() - train_batch_id += 1 - total_batch_num = total_batch_num + 1 #this is for benchmark - - ##profiler tools - if args.is_profiler and pass_id == 0 and train_batch_id == 100: - profiler.start_profiler("All") - elif args.is_profiler and pass_id == 0 and train_batch_id == 150: - profiler.stop_profiler("total", args.profiler_path) - return - - except fluid.core.EOFException: - train_data_loader.reset() + if not args.use_dali: + train_iter = train_data_loader() + test_iter = test_data_loader() + + t1 = time.time() + for batch in train_iter: + train_batch_metrics = exe.run(compiled_train_prog, + feed=batch, + fetch_list=train_fetch_list) + t2 = time.time() + train_batch_elapse = t2 - t1 + train_batch_time_record.append(train_batch_elapse) + train_batch_metrics_avg = np.mean( + np.array(train_batch_metrics), axis=1) + train_batch_metrics_record.append(train_batch_metrics_avg) + if trainer_id == 0: + print_info(pass_id, train_batch_id, args.print_step, + train_batch_metrics_avg, train_batch_elapse, "batch") + sys.stdout.flush() + train_batch_id += 1 + t1 = time.time() + + if args.use_dali: + train_iter.reset() if trainer_id == 0 and args.validate: if args.use_ema: print('ExponentialMovingAverage validate start...') with ema.apply(exe): - validate(args, test_data_loader, exe, test_prog, - test_fetch_list, pass_id, - train_batch_metrics_record) + validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id, train_batch_metrics_record) print('ExponentialMovingAverage validate over!') - validate(args, test_data_loader, exe, test_prog, test_fetch_list, - pass_id, train_batch_metrics_record) + validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id, train_batch_metrics_record) #For now, save model per epoch. if pass_id % args.save_step == 0: save_model(args, exe, train_prog, pass_id) + if args.use_dali: + test_iter.reset() def main(): args = parse_args() diff --git a/PaddleCV/image_classification/utils/utility.py b/PaddleCV/image_classification/utils/utility.py index c14593c7..a336320c 100644 --- a/PaddleCV/image_classification/utils/utility.py +++ b/PaddleCV/image_classification/utils/utility.py @@ -114,6 +114,7 @@ def parse_args(): parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step") # READER AND PREPROCESS + add_arg('use_dali', bool, False, "Whether to use nvidia DALI for preprocessing") add_arg('lower_scale', float, 0.08, "The value of lower_scale in ramdom_crop") add_arg('lower_ratio', float, 3./4., "The value of lower_ratio in ramdom_crop") add_arg('upper_ratio', float, 4./3., "The value of upper_ratio in ramdom_crop") @@ -328,14 +329,17 @@ def create_data_loader(is_train, args): feed_list=[feed_image, feed_y_a, feed_y_b, feed_lam], capacity=64, use_double_buffer=True, - iterable=False) + iterable=True) return data_loader, [feed_image, feed_y_a, feed_y_b, feed_lam] else: + if args.use_dali: + return None, [feed_image, feed_label] + data_loader = fluid.io.DataLoader.from_generator( feed_list=[feed_image, feed_label], capacity=64, use_double_buffer=True, - iterable=False) + iterable=True) return data_loader, [feed_image, feed_label] -- GitLab