Commit 7e59194b authored by LielinJiang, committed by ruri

Add multiprocess training to image classification (#3482)

* add_multiprocess_train_to_image_classification
Parent 71b7fa84
......@@ -90,6 +90,23 @@ python train.py \
bash run.sh train 模型名
```
**Multiprocess training:**
If you have multiple GPU cards, we strongly recommend multiprocess training; it speeds up training dramatically. Launch it as follows:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch train.py \
--model=ResNet50 \
--batch_size=256 \
--total_images=1281167 \
--class_dim=1000 \
--image_shape=3,224,224 \
--model_save_dir=output/ \
--lr_strategy=piecewise_decay \
--reader_thread=4 \
--lr=0.1
```
Or refer to scripts/train/ResNet50_dist.sh
**Parameter description:**
Environment settings:
......
......@@ -82,6 +82,25 @@ or running run.sh scripts
bash run.sh train model_name
```
**Multiprocess training:**
If you have multiple GPUs, multiprocess training is strongly recommended because it improves training speed dramatically.
You can start multiprocess training with:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch train.py \
--model=ResNet50 \
--batch_size=256 \
--total_images=1281167 \
--class_dim=1000 \
--image_shape=3,224,224 \
--model_save_dir=output/ \
--lr_strategy=piecewise_decay \
--reader_thread=4 \
--lr=0.1
```
or refer to scripts/train/ResNet50_dist.sh
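`paddle.distributed.launch` starts one worker process per GPU listed in `CUDA_VISIBLE_DEVICES` and tells each worker its rank through environment variables. A minimal sketch of how a worker can read them (the variable names are the ones this commit's train.py and utils/dist_utils.py consult):
```
# Illustrative sketch: discovering this process's rank after being
# started by `python -m paddle.distributed.launch`.
import os

trainer_id = int(os.environ.get("PADDLE_TRAINER_ID", 0))      # rank of this process
num_trainers = int(os.environ.get("PADDLE_TRAINERS_NUM", 1))  # world size
if trainer_id == 0:
    # only rank 0 should print logs and save checkpoints
    print("running with %d trainer process(es)" % num_trainers)
```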
**Parameter description:**
Environment settings:
......
......@@ -101,9 +101,9 @@ def eval(args):
exe.run(fluid.default_startup_program())
fluid.io.load_persistables(exe, args.pretrained_model)
imagenet_reader = reader.ImageNetReader()
val_reader = imagenet_reader.val(settings=args)
val_reader = paddle.batch(
reader.val(settings=args), batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
test_info = [[], [], []]
......
......@@ -284,7 +284,7 @@ def prepare_reader(epoch_id, train_py_reader, train_bs, val_bs, trn_dir,
min_scale=min_scale,
shuffle_seed=epoch_id + 1)
train_py_reader.decorate_paddle_reader(
paddle.batch(
fluid.io.batch(
train_reader, batch_size=train_bs))
test_reader = reader.test(
......@@ -292,7 +292,7 @@ def prepare_reader(epoch_id, train_py_reader, train_bs, val_bs, trn_dir,
bs=val_bs * DEVICE_NUM,
sz=img_dim,
rect_val=rect_val)
test_batched_reader = paddle.batch(
test_batched_reader = fluid.io.batch(
test_reader, batch_size=val_bs * DEVICE_NUM)
return test_batched_reader
......
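Several hunks in this commit only swap `paddle.batch` for `fluid.io.batch`. Both wrap a sample-level reader into a batch-level one; here is a pure-Python illustration of that behavior (a sketch of the semantics, not Paddle's implementation):
```
# Group samples from an item-level reader into fixed-size lists,
# optionally dropping the trailing partial batch.
def batch(reader, batch_size, drop_last=False):
    def batch_reader():
        buf = []
        for sample in reader():
            buf.append(sample)
            if len(buf) == batch_size:
                yield buf
                buf = []
        if buf and not drop_last:
            yield buf  # trailing partial batch
    return batch_reader

def samples():
    for i in range(10):
        yield i

print(list(batch(samples, 4)()))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```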
......@@ -88,9 +88,9 @@ def infer(args):
print("model: ", args.model, " is already saved")
exit(0)
test_batch_size = 1
test_reader = paddle.batch(
reader.test(settings=args), batch_size=test_batch_size)
args.test_batch_size = 1
imagenet_reader = reader.ImageNetReader()
test_reader = imagenet_reader.test(settings=args)
feeder = fluid.DataFeeder(place=place, feed_list=[image])
TOPK = args.topk
......
......@@ -92,7 +92,7 @@ def prepare_reader(is_train, pyreader, args, pass_id=1):
bs = args.batch_size / get_device_num()
else:
bs = 16
pyreader.decorate_paddle_reader(paddle.batch(reader, batch_size=bs))
pyreader.decorate_paddle_reader(fluid.io.batch(reader, batch_size=bs))
def build_program(is_train, main_prog, startup_prog, args):
......
......@@ -19,7 +19,7 @@ import functools
import numpy as np
from PIL import Image, ImageEnhance
import paddle
from paddle import fluid
random.seed(0)
np.random.seed(0)
......@@ -190,7 +190,7 @@ def _reader_creator(file_list,
mapper = functools.partial(
process_batch_data, mode=mode, color_jitter=color_jitter, rotate=rotate)
return paddle.reader.xmap_readers(mapper, data_reader, THREAD, BUF_SIZE)
return fluid.io.xmap_readers(mapper, data_reader, THREAD, BUF_SIZE)
def train(batch_size, data_dir=DATA_DIR, shuffle_seed=0, infinite=False):
......
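The same rename applies to `xmap_readers`, which maps a function over a reader's samples with a pool of worker threads. A hedged usage sketch, assuming Paddle 1.6+ where `fluid.io.xmap_readers` aliases the old `paddle.reader.xmap_readers`; `mapper` here is a stand-in for real image decoding:
```
import paddle.fluid as fluid

THREAD, BUF_SIZE = 8, 64  # worker threads and queue size, as in this module

def mapper(sample):
    path, label = sample
    return len(path), label  # stand-in for real image decoding

def source():
    for i in range(4):
        yield ("img_%d.jpg" % i, i)

mapped = fluid.io.xmap_readers(mapper, source, THREAD, BUF_SIZE)
print(sorted(mapped()))  # output order is not guaranteed unless order=True
```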
......@@ -20,6 +20,7 @@ import numpy as np
import cv2
import paddle
from paddle import fluid
from utils.autoaugment import ImageNetPolicy
from PIL import Image
......@@ -163,18 +164,11 @@ def create_mixup_reader(settings, rd):
tmp_l2 = []
tmp_lam = []
batch_size = settings.batch_size
alpha = settings.mixup_alpha
def fetch_data():
data_list = []
for i, item in enumerate(rd()):
data_list.append(item)
if i % batch_size == batch_size - 1:
yield data_list
data_list = []
for item in rd():
yield item
def mixup_data():
for data_list in fetch_data():
......@@ -245,15 +239,35 @@ def process_image(sample, settings, mode, color_jitter, rotate):
elif mode == 'test':
return (img, )
def process_batch_data(input_data, settings, mode, color_jitter, rotate):
batch_data = []
for sample in input_data:
if os.path.isfile(sample[0]):
batch_data.append(
process_image(sample, settings, mode, color_jitter, rotate))
else:
print("File not exist : %s" % sample[0])
return batch_data
class ImageNetReader:
def __init__(self, seed=None):
self.shuffle_seed = seed
def set_shuffle_seed(self, seed):
assert isinstance(seed, int), "shuffle seed must be int"
self.shuffle_seed = seed
def _reader_creator(settings,
def _reader_creator(self, settings,
file_list,
mode,
shuffle=False,
color_jitter=False,
rotate=False,
data_dir=None):
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
batch_size = settings.batch_size / paddle.fluid.core.get_cuda_device_count()
def reader():
def read_file_list():
with open(file_list) as flist:
full_lines = [line.strip() for line in flist]
if mode != "test" and len(full_lines) < settings.batch_size:
......@@ -261,35 +275,49 @@ def _reader_creator(settings,
"Warning: The number of the whole data ({}) is smaller than the batch_size ({}), and drop_last is turnning on, so nothing will feed in program, Terminated now. Please reset batch_size to a smaller number or feed more data!"
.format(len(full_lines), settings.batch_size))
os._exit(1)
if shuffle:
if num_trainers > 1 and mode == "train":
assert self.shuffle_seed is not None, "multiprocess train, shuffle seed must be set!"
np.random.RandomState(self.shuffle_seed).shuffle(full_lines)
elif shuffle:
np.random.shuffle(full_lines)
batch_data = []
for line in full_lines:
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
if not os.path.exists(img_path):
print("Warning: {} doesn't exist!".format(img_path))
if mode == "train" or mode == "val":
yield img_path, int(label)
elif mode == "test":
yield [img_path]
batch_data.append([img_path, int(label)])
if len(batch_data) == batch_size:
if mode == 'train' or mode == 'val' or mode == 'test':
yield batch_data
batch_data = []
return read_file_list
data_reader = reader()
if mode == 'train' and num_trainers > 1:
assert self.shuffle_seed is not None, \
"If num_trainers > 1, the shuffle_seed must be set, because " \
"the order of batch data generated by reader " \
"must be the same in the respective processes."
data_reader = paddle.fluid.contrib.reader.distributed_batch_reader(data_reader)
mapper = functools.partial(
process_image,
process_batch_data,
settings=settings,
mode=mode,
color_jitter=color_jitter,
rotate=rotate)
return paddle.reader.xmap_readers(
return fluid.io.xmap_readers(
mapper,
reader,
data_reader,
settings.reader_thread,
settings.reader_buf_size,
order=False)
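When `num_trainers > 1`, the batch stream above is wrapped in `fluid.contrib.reader.distributed_batch_reader`, which leaves each process only its own share of the identically ordered batches; that is why the code asserts a shared shuffle seed. A toy illustration of the intended semantics (not Paddle's actual implementation):
```
# Trainer i keeps every num_trainers-th batch of an identically
# ordered stream, so trainers consume disjoint data.
import os

def distributed_batch_reader(batch_reader):
    trainer_id = int(os.environ.get("PADDLE_TRAINER_ID", 0))
    num_trainers = int(os.environ.get("PADDLE_TRAINERS_NUM", 1))
    def reader():
        for i, batch in enumerate(batch_reader()):
            if i % num_trainers == trainer_id:
                yield batch
    return reader

batches = lambda: iter([[0, 1], [2, 3], [4, 5], [6, 7]])
print(list(distributed_batch_reader(batches)()))  # trainer 0 of 1: all batches
```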
def train(settings):
def train(self, settings):
"""Create a reader for trainning
Args:
......@@ -307,7 +335,7 @@ def train(settings):
global policy
policy = ImageNetPolicy()
reader = _reader_creator(
reader = self._reader_creator(
settings,
file_list,
'train',
......@@ -318,10 +346,14 @@ def train(settings):
if settings.use_mixup == True:
reader = create_mixup_reader(settings, reader)
reader = fluid.io.batch(
reader,
batch_size=int(settings.batch_size / paddle.fluid.core.get_cuda_device_count()),
drop_last=True)
return reader
def val(settings):
def val(self, settings):
"""Create a reader for eval
Args:
......@@ -336,11 +368,11 @@ def val(settings):
file_list), "{} doesn't exist, please check data list path".format(
file_list)
return _reader_creator(
return self._reader_creator(
settings, file_list, 'val', shuffle=False, data_dir=settings.data_dir)
def test(settings):
def test(self, settings):
"""Create a reader for testing
Args:
......@@ -353,5 +385,5 @@ def test(settings):
assert os.path.isfile(
file_list), "{} doesn't exist, please check data list path".format(
file_list)
return _reader_creator(
return self._reader_creator(
settings, file_list, 'test', shuffle=False, data_dir=settings.data_dir)
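The seed contract that `_reader_creator` asserts is easy to check: shuffling through a seeded private `numpy.random.RandomState`, as the code above does, gives every process the identical permutation of the file list, so the batch streams line up across trainers:
```
# Runnable check: a seeded private RandomState produces the same
# permutation in every process.
import numpy as np

lines = ["img_%d.jpg %d" % (i, i % 10) for i in range(8)]

on_trainer_0 = list(lines)
on_trainer_1 = list(lines)
np.random.RandomState(42).shuffle(on_trainer_0)
np.random.RandomState(42).shuffle(on_trainer_1)
assert on_trainer_0 == on_trainer_1  # identical order on both trainers
```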
##Training details
export CUDA_VISIBLE_DEVICES=0,1,2,3
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.98
#ResNet50:
python -m paddle.distributed.launch train.py \
--model=ResNet50 \
--batch_size=256 \
--total_images=1281167 \
--class_dim=1000 \
--image_shape=3,224,224 \
--model_save_dir=output/ \
--lr_strategy=piecewise_decay \
--num_epochs=120 \
--lr=0.1 \
--reader_thread=4 \
--l2_decay=1e-4
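The `FLAGS_*` exports in this script are ordinary environment variables that Paddle reads at startup to control eager tensor deletion and GPU memory pre-allocation. Presumably they can also be set from Python, as long as that happens before `paddle.fluid` is imported (a hedged sketch, not verified against every Paddle version):
```
import os

os.environ["FLAGS_fast_eager_deletion_mode"] = "1"         # free dead tensors eagerly
os.environ["FLAGS_eager_delete_tensor_gb"] = "0.0"         # delete as soon as possible
os.environ["FLAGS_fraction_of_gpu_memory_to_use"] = "0.98" # pre-allocate most of the GPU

import paddle.fluid as fluid  # must come after the flags are set
```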
......@@ -42,7 +42,6 @@ from utils import *
import models
from build_model import create_model
def build_program(is_train, main_prog, startup_prog, args):
"""build program, and add grad op in program accroding to different mode
......@@ -167,25 +166,20 @@ def train(args):
#init model from checkpoint or pretrained model.
init_model(exe, args, train_prog)
train_reader = reader.train(settings=args)
train_reader = paddle.batch(
train_reader,
batch_size=int(args.batch_size / fluid.core.get_cuda_device_count()),
drop_last=True)
test_reader = reader.val(settings=args)
test_reader = paddle.batch(
test_reader, batch_size=args.test_batch_size, drop_last=True)
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
imagenet_reader = reader.ImageNetReader(0 if num_trainers > 1 else None)
train_reader = imagenet_reader.train(settings=args)
test_reader = imagenet_reader.val(settings=args)
train_py_reader.decorate_sample_list_generator(train_reader, place)
test_py_reader.decorate_sample_list_generator(test_reader, place)
compiled_train_prog = best_strategy_compiled(args, train_prog,
train_fetch_vars[0])
train_fetch_vars[0], exe)
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
for pass_id in range(args.num_epochs):
if num_trainers > 1:
imagenet_reader.set_shuffle_seed(pass_id + (args.random_seed if args.random_seed else 0))
train_batch_id = 0
train_batch_time_record = []
train_batch_metrics_record = []
......@@ -203,7 +197,7 @@ def train(args):
train_batch_metrics_avg = np.mean(
np.array(train_batch_metrics), axis=1)
train_batch_metrics_record.append(train_batch_metrics_avg)
if trainer_id == 0:
print_info(pass_id, train_batch_id, args.print_step,
train_batch_metrics_avg, train_batch_elapse, "batch")
sys.stdout.flush()
......@@ -212,6 +206,7 @@ def train(args):
except fluid.core.EOFException:
train_py_reader.reset()
if trainer_id == 0:
if args.use_ema:
print('ExponentialMovingAverage validate start...')
with ema.apply(exe):
......@@ -226,6 +221,7 @@ def train(args):
def main():
args = parse_args()
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
print_arguments(args)
check_args(args)
train(args)
......
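The train.py changes boil down to two multi-process rules: reset the shared shuffle seed at the top of every epoch, and let only trainer 0 print logs and run validation. A condensed, self-contained sketch of that loop, with stand-ins for the real reader and arguments:
```
import os

trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
num_trainers = int(os.environ.get("PADDLE_TRAINERS_NUM", 1))

class StubReader(object):              # stand-in for reader.ImageNetReader
    def set_shuffle_seed(self, seed):
        print("trainer %d reshuffles with seed %d" % (trainer_id, seed))

imagenet_reader = StubReader()
random_seed = 0                        # stand-in for args.random_seed
for pass_id in range(3):               # stand-in for args.num_epochs
    if num_trainers > 1:
        # identical seed on every trainer => identical shuffle order,
        # which distributed_batch_reader then splits without overlap
        imagenet_reader.set_shuffle_seed(pass_id + random_seed)
    if trainer_id == 0:                # avoid duplicated console output
        print("epoch %d finished" % pass_id)
```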
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle.fluid as fluid
def nccl2_prepare(args, startup_prog, main_prog):
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
envs = args.dist_env
t.transpile(
envs["trainer_id"],
trainers=','.join(envs["trainer_endpoints"]),
current_endpoint=envs["current_endpoint"],
startup_program=startup_prog,
program=main_prog)
def pserver_prepare(args, train_prog, startup_prog):
config = fluid.DistributeTranspilerConfig()
config.slice_var_up = args.split_var
t = fluid.DistributeTranspiler(config=config)
envs = args.dist_env
training_role = envs["training_role"]
t.transpile(
envs["trainer_id"],
program=train_prog,
pservers=envs["pserver_endpoints"],
trainers=envs["num_trainers"],
sync_mode=not args.async_mode,
startup_program=startup_prog)
if training_role == "PSERVER":
pserver_program = t.get_pserver_program(envs["current_endpoint"])
pserver_startup_program = t.get_startup_program(
envs["current_endpoint"],
pserver_program,
startup_program=startup_prog)
return pserver_program, pserver_startup_program
elif training_role == "TRAINER":
train_program = t.get_trainer_program()
return train_program, startup_prog
else:
raise ValueError(
'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
)
def nccl2_prepare_paddle(trainer_id, startup_prog, main_prog):
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
trainers=os.environ.get('PADDLE_TRAINER_ENDPOINTS'),
current_endpoint=os.environ.get('PADDLE_CURRENT_ENDPOINT'),
startup_program=startup_prog,
program=main_prog)
def prepare_for_multi_process(exe, build_strategy, train_prog):
# prepare for multi-process
trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if num_trainers < 2: return
print("PADDLE_TRAINERS_NUM", num_trainers)
print("PADDLE_TRAINER_ID", trainer_id)
build_strategy.num_trainers = num_trainers
build_strategy.trainer_id = trainer_id
# NOTE(zcd): use multi processes to train the model,
# and each process use one GPU card.
startup_prog = fluid.Program()
nccl2_prepare_paddle(trainer_id, startup_prog, train_prog)
# the startup_prog is run twice, but that doesn't matter.
exe.run(startup_prog)
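A hedged sketch of how a training script would call `prepare_for_multi_process` (Paddle 1.x, processes started by `paddle.distributed.launch`; the place and program setup here are placeholders, not this repo's exact train.py):
```
import paddle.fluid as fluid
from utils import dist_utils  # this module

place = fluid.CUDAPlace(0)    # launch maps one visible GPU per process
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

build_strategy = fluid.BuildStrategy()
train_prog = fluid.default_main_program()
# no-op for a single process; otherwise fills in trainer_id /
# num_trainers and runs the NCCL2 transpile + startup step
dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
```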
......@@ -18,7 +18,6 @@ from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers.ops as ops
from paddle.fluid.initializer import init_on_cpu
......@@ -142,7 +141,6 @@ class Optimizer(object):
"""
def __init__(self, args):
self.batch_size = args.batch_size
self.lr = args.lr
self.lr_strategy = args.lr_strategy
......
......@@ -32,6 +32,7 @@ import paddle.fluid as fluid
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.framework import Program, program_guard, name_scope, default_main_program
from paddle.fluid import unique_name, layers
from utils import dist_utils
def print_arguments(args):
"""Print argparse's arguments.
......@@ -376,7 +377,7 @@ def print_info(pass_id, batch_id, print_step, metrics, time_info, info_mode):
raise Exception("Illegal info_mode")
def best_strategy_compiled(args, program, loss):
def best_strategy_compiled(args, program, loss, exe):
"""make a program which wrapped by a compiled program
"""
......@@ -391,6 +392,13 @@ def best_strategy_compiled(args, program, loss):
exec_strategy.num_threads = fluid.core.get_cuda_device_count()
exec_strategy.num_iteration_per_drop_scope = 10
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if num_trainers > 1 and args.use_gpu:
dist_utils.prepare_for_multi_process(exe, build_strategy, program)
# NOTE: for multi-process training, execution is fastest
# when num_threads is 1.
exec_strategy.num_threads = 1
compiled_program = fluid.CompiledProgram(program).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
......
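For reference, a hedged sketch of the strategy objects `best_strategy_compiled` ends up with in the multi-process case (Paddle 1.x APIs; the values mirror the hunk above):
```
import os
import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_iteration_per_drop_scope = 10

num_trainers = int(os.environ.get("PADDLE_TRAINERS_NUM", 1))
if num_trainers > 1:
    # with one process per GPU, a single execution thread per process
    # is fastest; more threads would only add contention
    exec_strategy.num_threads = 1
else:
    exec_strategy.num_threads = fluid.core.get_cuda_device_count()
```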