未验证 提交 69a9aa6c 编写于 作者: C chengduo 提交者: GitHub

Support multi process training for transformer and yolov3 (#2475)

* add multi process implementation for yolov3
上级 9bd3df44
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle.fluid as fluid
def nccl2_prepare(trainer_id, startup_prog, main_prog):
    """Run the fluid distribute transpiler in "nccl2" mode on the programs.

    The trainer endpoint list and this trainer's own endpoint are taken from
    the ``PADDLE_TRAINER_ENDPOINTS`` and ``PADDLE_CURRENT_ENDPOINT``
    environment variables and forwarded to the transpiler as-is (``None``
    when a variable is not set).

    Args:
        trainer_id: id of the current trainer, passed through to
            ``transpile``.
        startup_prog: program handed to ``transpile`` as ``startup_program``.
        main_prog: program handed to ``transpile`` as ``program``.

    Returns:
        None.
    """
    nccl2_config = fluid.DistributeTranspilerConfig()
    nccl2_config.mode = "nccl2"

    transpiler = fluid.DistributeTranspiler(config=nccl2_config)
    all_endpoints = os.environ.get('PADDLE_TRAINER_ENDPOINTS')
    own_endpoint = os.environ.get('PADDLE_CURRENT_ENDPOINT')
    transpiler.transpile(
        trainer_id,
        trainers=all_endpoints,
        current_endpoint=own_endpoint,
        startup_program=startup_prog,
        program=main_prog)
def prepare_for_multi_process(exe, build_strategy, train_prog):
    """Set up multi-process training when more than one trainer is declared.

    The trainer id and trainer count are read from the ``PADDLE_TRAINER_ID``
    (default 0) and ``PADDLE_TRAINERS_NUM`` (default 1) environment
    variables. With fewer than two trainers the function returns immediately
    and touches none of its arguments.

    Args:
        exe: executor whose ``run`` method is called with the freshly
            created startup program.
        build_strategy: object that receives the ``num_trainers`` and
            ``trainer_id`` attributes.
        train_prog: training program forwarded to ``nccl2_prepare`` as the
            main program.

    Returns:
        None.
    """
    env = os.environ
    rank = int(env.get('PADDLE_TRAINER_ID', 0))
    trainer_count = int(env.get('PADDLE_TRAINERS_NUM', 1))

    # Single-process run: there is nothing to prepare.
    if trainer_count <= 1:
        return

    print("PADDLE_TRAINERS_NUM", trainer_count)
    print("PADDLE_TRAINER_ID", rank)

    build_strategy.num_trainers = trainer_count
    build_strategy.trainer_id = rank

    # NOTE(zcd): the model is trained with multiple processes,
    # and each process uses one GPU card.
    dist_startup = fluid.Program()
    nccl2_prepare(rank, dist_startup, train_prog)

    # Per the original author's note, the startup program ends up being run
    # twice, which is harmless.
    exe.run(dist_startup)
......@@ -28,7 +28,7 @@ import image_utils
from pycocotools.coco import COCO
from data_utils import GeneratorEnqueuer
from config import cfg
import paddle.fluid as fluid
class DataSetReader(object):
"""A class for parsing and read COCO dataset"""
......@@ -146,6 +146,7 @@ class DataSetReader(object):
size=416,
batch_size=None,
shuffle=False,
shuffle_seed=None,
mixup_iter=0,
random_sizes=[],
image=None):
......@@ -225,6 +226,8 @@ class DataSetReader(object):
if mode == 'train':
imgs = self._parse_images_by_mode(mode)
if shuffle:
if shuffle_seed is not None:
np.random.seed(shuffle_seed)
np.random.shuffle(imgs)
read_cnt = 0
total_iter = 0
......@@ -269,6 +272,14 @@ class DataSetReader(object):
batch_out = [(im, im_id, im_shape)]
yield batch_out
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if mode == 'train' and num_trainers > 1:
assert shuffle_seed is not None, \
"If num_trainers > 1, the shuffle_seed must be set, because " \
"the order of batch data generated by reader " \
"must be the same in the respective processes."
reader = fluid.contrib.reader.distributed_batch_reader(reader)
return reader
......@@ -278,16 +289,17 @@ dsr = DataSetReader()
def train(size=416,
batch_size=64,
shuffle=True,
shuffle_seed=None,
total_iter=0,
mixup_iter=0,
random_sizes=[],
num_workers=8,
max_queue=32,
use_multiprocessing=True):
generator = dsr.get_reader('train', size, batch_size, shuffle,
use_multiprocess_reader=True):
generator = dsr.get_reader('train', size, batch_size, shuffle, shuffle_seed,
int(mixup_iter / num_workers), random_sizes)
if not use_multiprocessing:
if not use_multiprocess_reader:
return generator
def infinite_reader():
......@@ -299,7 +311,7 @@ def train(size=416,
cnt = 0
try:
enqueuer = GeneratorEnqueuer(
infinite_reader(), use_multiprocessing=use_multiprocessing)
infinite_reader(), use_multiprocessing=use_multiprocess_reader)
enqueuer.start(max_queue_size=max_queue, workers=num_workers)
generator_out = None
while True:
......
......@@ -17,13 +17,15 @@ from __future__ import division
from __future__ import print_function
import os
def set_paddle_flags(flags):
    """Export default flag values into the process environment.

    Variables that already exist in ``os.environ`` are left untouched, so a
    value exported by the user takes precedence over the default given here.

    Args:
        flags: mapping from environment variable name to its default value;
            each value is converted with ``str()`` before being stored.

    Returns:
        None.
    """
    for name, setting in flags.items():
        if name in os.environ:
            # Respect whatever the caller's environment already defines.
            continue
        os.environ[name] = str(setting)
# Default Paddle runtime flags. set_paddle_flags only fills in variables that
# are missing from the environment, so values exported by the user win.
# The duplicated 'FLAGS_eager_delete_tensor_gb' entry was removed: a repeated
# key in a dict literal is silently collapsed to the last value.
set_paddle_flags({
    'FLAGS_eager_delete_tensor_gb': 0,  # enable gc
    'FLAGS_memory_fraction_of_eager_deletion': 1,
    'FLAGS_fraction_of_gpu_memory_to_use': 0.98
})
......@@ -41,6 +43,21 @@ import reader
from models.yolov3 import YOLOv3
from learning_rate import exponential_with_warmup_decay
from config import cfg
import dist_utils
# Number of trainer processes in the job, read from the PADDLE_TRAINERS_NUM
# environment variable (1 when it is unset).
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))


def get_device_num():
    """Return the number of GPU cards this process should use.

    Resolution order:
      1. When ``num_trainers`` is greater than 1, always 1.
      2. Otherwise the number of non-empty, comma-separated entries in
         ``CUDA_VISIBLE_DEVICES``. Blank entries such as a trailing comma
         are not counted.
      3. If the variable is unset or has no usable entry, the number of
         newline-terminated lines printed by ``nvidia-smi -L``.

    Returns:
        int: the device count.

    Raises:
        OSError: if ``nvidia-smi`` cannot be executed.
        subprocess.CalledProcessError: if ``nvidia-smi`` exits non-zero.
    """
    # Imported here so the function does not depend on a module-level
    # ``import subprocess`` being present in the file.
    import subprocess

    # NOTE(zcd): for multi-process training, each process uses one GPU card.
    if num_trainers > 1:
        return 1

    visible_device = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if visible_device:
        device_ids = [d for d in visible_device.split(',') if d.strip()]
        if device_ids:
            return len(device_ids)

    listing = subprocess.check_output(['nvidia-smi', '-L']).decode()
    return listing.count('\n')
def train():
......@@ -60,8 +77,7 @@ def train():
loss = model.loss()
loss.persistable = True
devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
devices_num = len(devices.split(","))
devices_num = get_device_num()
print("Found {} CUDA devices.".format(devices_num))
learning_rate = cfg.learning_rate
......@@ -81,7 +97,8 @@ def train():
momentum=cfg.momentum)
optimizer.minimize(loss)
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
......@@ -95,15 +112,24 @@ def train():
fluid.io.load_vars(exe, cfg.pretrain, predicate=if_exist)
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False #gc and memory optimize may conflict
build_strategy.memory_optimize = False #gc and memory optimize may conflict
syncbn = cfg.syncbn
if syncbn and devices_num <= 1:
if (syncbn and devices_num <= 1) or num_trainers > 1:
print("Disable syncbn in single device")
syncbn = False
build_strategy.sync_batch_norm = syncbn
exec_strategy = fluid.ExecutionStrategy()
if cfg.use_gpu and num_trainers > 1:
dist_utils.prepare_for_multi_process(exe, build_strategy,
fluid.default_main_program())
exec_strategy.num_threads = 1
compile_program = fluid.compiler.CompiledProgram(fluid.default_main_program(
)).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
random_sizes = [cfg.input_size]
if cfg.random_shape:
......@@ -111,17 +137,23 @@ def train():
total_iter = cfg.max_iter - cfg.start_iter
mixup_iter = total_iter - cfg.no_mixup_iter
shuffle = True
if args.enable_ce:
shuffle = False
# NOTE: If num_trainers > 1, the shuffle_seed must be set, because
# the order of batch data generated by reader
# must be the same in the respective processes.
shuffle_seed = 1 if num_trainers > 1 else None
train_reader = reader.train(
input_size,
batch_size=cfg.batch_size,
shuffle=shuffle,
shuffle_seed=shuffle_seed,
total_iter=total_iter * devices_num,
mixup_iter=mixup_iter * devices_num,
random_sizes=random_sizes,
use_multiprocessing=cfg.use_multiprocess)
use_multiprocess_reader=cfg.use_multiprocess_reader)
py_reader = model.py_reader
py_reader.decorate_paddle_reader(train_reader)
......
......@@ -102,7 +102,7 @@ def parse_args():
add_arg('class_num', int, 80, "Class number.")
add_arg('data_dir', str, 'dataset/coco', "The data root path.")
add_arg('start_iter', int, 0, "Start iteration.")
add_arg('use_multiprocess', bool, True, "add multiprocess.")
add_arg('use_multiprocess_reader', bool, True, "add multiprocess.")
#SOLVER
add_arg('batch_size', int, 8, "Mini-batch size per device.")
add_arg('learning_rate', float, 0.001, "Learning rate.")
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle.fluid as fluid
def nccl2_prepare(trainer_id, startup_prog, main_prog):
    """Run the fluid distribute transpiler in "nccl2" mode on the programs.

    The trainer endpoint list and this trainer's own endpoint are taken from
    the ``PADDLE_TRAINER_ENDPOINTS`` and ``PADDLE_CURRENT_ENDPOINT``
    environment variables and forwarded to the transpiler as-is (``None``
    when a variable is not set).

    Args:
        trainer_id: id of the current trainer, passed through to
            ``transpile``.
        startup_prog: program handed to ``transpile`` as ``startup_program``.
        main_prog: program handed to ``transpile`` as ``program``.

    Returns:
        None.
    """
    nccl2_config = fluid.DistributeTranspilerConfig()
    nccl2_config.mode = "nccl2"

    transpiler = fluid.DistributeTranspiler(config=nccl2_config)
    all_endpoints = os.environ.get('PADDLE_TRAINER_ENDPOINTS')
    own_endpoint = os.environ.get('PADDLE_CURRENT_ENDPOINT')
    transpiler.transpile(
        trainer_id,
        trainers=all_endpoints,
        current_endpoint=own_endpoint,
        startup_program=startup_prog,
        program=main_prog)
def prepare_for_multi_process(exe, build_strategy, train_prog):
    """Set up multi-process training when more than one trainer is declared.

    The trainer id and trainer count are read from the ``PADDLE_TRAINER_ID``
    (default 0) and ``PADDLE_TRAINERS_NUM`` (default 1) environment
    variables. With fewer than two trainers the function returns immediately
    and touches none of its arguments.

    Args:
        exe: executor whose ``run`` method is called with the freshly
            created startup program.
        build_strategy: object that receives the ``num_trainers`` and
            ``trainer_id`` attributes.
        train_prog: training program forwarded to ``nccl2_prepare`` as the
            main program.

    Returns:
        None.
    """
    env = os.environ
    rank = int(env.get('PADDLE_TRAINER_ID', 0))
    trainer_count = int(env.get('PADDLE_TRAINERS_NUM', 1))

    # Single-process run: there is nothing to prepare.
    if trainer_count <= 1:
        return

    print("PADDLE_TRAINERS_NUM", trainer_count)
    print("PADDLE_TRAINER_ID", rank)

    build_strategy.num_trainers = trainer_count
    build_strategy.trainer_id = rank

    # NOTE(zcd): the model is trained with multiple processes,
    # and each process uses one GPU card.
    dist_startup = fluid.Program()
    nccl2_prepare(rank, dist_startup, train_prog)

    # Per the original author's note, the startup program ends up being run
    # twice, which is harmless.
    exe.run(dist_startup)
\ No newline at end of file
......@@ -180,6 +180,7 @@ class DataReader(object):
min_length=0,
max_length=100,
shuffle=True,
shuffle_seed=None,
shuffle_batch=False,
use_token_batch=False,
field_delimiter="\t",
......@@ -199,6 +200,7 @@ class DataReader(object):
self._sort_type = sort_type
self._clip_last_batch = clip_last_batch
self._shuffle = shuffle
self._shuffle_seed = shuffle_seed
self._shuffle_batch = shuffle_batch
self._min_length = min_length
self._max_length = max_length
......@@ -292,6 +294,8 @@ class DataReader(object):
else:
if self._shuffle:
infos = self._sample_infos
if self._shuffle_seed is not None:
self._random.seed(self._shuffle_seed)
self._random.shuffle(infos)
else:
infos = self._sample_infos
......
......@@ -16,6 +16,9 @@ import reader
from config import *
from desc import *
from model import transformer, position_encoding_init
import dist_utils
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
def parse_args():
......@@ -146,6 +149,18 @@ def parse_args():
return args
def get_device_num():
    """Return the number of GPU cards this process should use.

    Resolution order:
      1. When the module-level ``num_trainers`` is greater than 1, always 1.
      2. Otherwise the number of non-empty, comma-separated entries in
         ``CUDA_VISIBLE_DEVICES``. Blank entries such as a trailing comma
         are not counted.
      3. If the variable is unset or has no usable entry, the number of
         newline-terminated lines printed by ``nvidia-smi -L``.

    Returns:
        int: the device count.

    Raises:
        OSError: if ``nvidia-smi`` cannot be executed.
        subprocess.CalledProcessError: if ``nvidia-smi`` exits non-zero.
    """
    # Imported here so the function does not depend on a module-level
    # ``import subprocess`` being present in the file.
    import subprocess

    # NOTE(zcd): for multi-process training, each process uses one GPU card.
    if num_trainers > 1:
        return 1

    visible_device = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if visible_device:
        device_ids = [d for d in visible_device.split(',') if d.strip()]
        if device_ids:
            return len(device_ids)

    listing = subprocess.check_output(['nvidia-smi', '-L']).decode()
    return listing.count('\n')
def append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
current_endpoint):
assert (trainer_id >= 0 and len(worker_endpoints) > 1 and
......@@ -269,6 +284,10 @@ def prepare_data_generator(args,
Data generator wrapper for DataReader. If use py_reader, set the data
provider for py_reader
"""
# NOTE: If num_trainers > 1, the shuffle_seed must be set, because
# the order of batch data generated by reader
# must be the same in the respective processes.
shuffle_seed = 1 if num_trainers > 1 else None
data_reader = reader.DataReader(
fpattern=args.val_file_pattern if is_test else args.train_file_pattern,
src_vocab_fpath=args.src_vocab_fpath,
......@@ -279,6 +298,7 @@ def prepare_data_generator(args,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_seed=shuffle_seed,
shuffle_batch=args.shuffle_batch,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
......@@ -324,8 +344,12 @@ def prepare_data_generator(args,
# to make data on each device have similar token number
data_reader = split(data_reader, count)
if args.use_py_reader:
pyreader.decorate_tensor_provider(
py_reader_provider_wrapper(data_reader, place))
train_reader = py_reader_provider_wrapper(data_reader, place)
if num_trainers > 1:
assert shuffle_seed is not None
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)
pyreader.decorate_tensor_provider(train_reader)
data_reader = None
else: # Data generator for multi-devices
data_reader = stack(data_reader, count)
......@@ -505,6 +529,10 @@ def train_loop(exe,
# build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
build_strategy.fuse_all_optimizer_ops = True
if num_trainers > 1 and args.use_py_reader and TrainTaskConfig.use_gpu:
dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
exec_strategy.num_threads = 1
logging.info("begin executor")
train_exe = fluid.ParallelExecutor(
use_cuda=TrainTaskConfig.use_gpu,
......@@ -527,7 +555,6 @@ def train_loop(exe,
step_idx = 0
init_flag = True
logging.info("begin train")
for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
......@@ -605,6 +632,7 @@ def train_loop(exe,
time_consumed))
else:
logging.info("epoch: %d, consumed %fs" % (pass_id, time_consumed))
if not args.enable_ce:
fluid.io.save_persistables(
exe,
......@@ -637,7 +665,7 @@ def train(args):
else:
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id)
dev_count = fluid.core.get_cuda_device_count()
dev_count = get_device_num()
exe = fluid.Executor(place)
......@@ -775,4 +803,4 @@ if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
args = parse_args()
train(args)
train(args)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册