提交 f4d74a90 编写于 作者: Y Yang Zhang 提交者: GitHub

Add multi process training to ppdet (#3328)

* Initial support for distributed training

* Housekeep on rank 0 only

* Conform to models convention
上级 d6e52889
......@@ -19,6 +19,7 @@ from __future__ import unicode_literals
import os
import shutil
import time
import numpy as np
import paddle.fluid as fluid
......@@ -44,6 +45,33 @@ def is_url(path):
return path.startswith('http://') or path.startswith('https://')
def _get_weight_path(path):
    """Resolve ``path`` to a weight location, fetching once per multi-process job.

    When ``PADDLE_TRAINERS_NUM`` and ``PADDLE_TRAINER_ID`` are both set and
    more than one trainer is running, only trainer 0 calls
    ``get_weights_path``; the other trainers poll a ``.lock`` file next to
    the target until trainer 0 removes it. In every other case the call is
    delegated straight to ``get_weights_path``.

    Args:
        path (string): weight location, passed to ``get_weights_path`` /
            ``map_path`` unchanged.

    Returns:
        The value returned by ``get_weights_path`` in the single-process
        case, otherwise ``map_path(path, WEIGHTS_HOME)``.
    """
    env = os.environ
    if 'PADDLE_TRAINERS_NUM' not in env or 'PADDLE_TRAINER_ID' not in env:
        return get_weights_path(path)

    trainer_id = int(env['PADDLE_TRAINER_ID'])
    num_trainers = int(env['PADDLE_TRAINERS_NUM'])
    if num_trainers <= 1:
        return get_weights_path(path)

    import errno
    from ppdet.utils.download import map_path, WEIGHTS_HOME
    weight_path = map_path(path, WEIGHTS_HOME)
    lock_path = weight_path + '.lock'
    if not os.path.exists(weight_path):
        # `exist_ok=True` is Python 3 only; tolerate a concurrently created
        # directory explicitly so this also runs on Python 2.
        try:
            os.makedirs(os.path.dirname(weight_path))
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
        with open(lock_path, 'w'):  # touch
            # Python 2 requires the `times` argument; None means "now".
            os.utime(lock_path, None)
        if trainer_id == 0:
            try:
                get_weights_path(path)
            finally:
                # Always release the lock, even when the fetch fails;
                # otherwise the waiting trainers would spin forever.
                os.remove(lock_path)
        else:
            # NOTE(review): a trainer that reaches this point after trainer 0
            # has already removed the lock re-creates it above and then waits
            # indefinitely -- confirm whether launch ordering rules this out.
            while os.path.exists(lock_path):
                time.sleep(1)
    return weight_path
def load_pretrain(exe, prog, path):
"""
Load model from the given path.
......@@ -52,8 +80,9 @@ def load_pretrain(exe, prog, path):
prog (fluid.Program): load weight to which Program object.
path (string): URL string or local model path.
"""
if is_url(path):
path = get_weights_path(path)
path = _get_weight_path(path)
if not os.path.exists(path):
raise ValueError("Model pretrain path {} does not "
......@@ -131,7 +160,7 @@ def load_and_fusebn(exe, prog, path):
"""
logger.info('Load model and fuse batch norm from {}...'.format(path))
if is_url(path):
path = get_weights_path(path)
path = _get_weight_path(path)
if not os.path.exists(path):
raise ValueError("Model path {} does not exists.".format(path))
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import os
import paddle.fluid as fluid
def nccl2_prepare(trainer_id, startup_prog, main_prog):
    """Transpile the given programs with a "nccl2"-mode DistributeTranspiler.

    The trainer endpoint list and the current endpoint are read from the
    ``PADDLE_TRAINER_ENDPOINTS`` and ``PADDLE_CURRENT_ENDPOINT`` environment
    variables (``None`` is passed through when a variable is unset).

    Args:
        trainer_id: id of this trainer, forwarded to ``transpile``.
        startup_prog: program passed as ``startup_program``.
        main_prog: program passed as ``program``.
    """
    transpiler_config = fluid.DistributeTranspilerConfig()
    transpiler_config.mode = "nccl2"
    transpiler = fluid.DistributeTranspiler(config=transpiler_config)

    all_endpoints = os.environ.get('PADDLE_TRAINER_ENDPOINTS')
    this_endpoint = os.environ.get('PADDLE_CURRENT_ENDPOINT')
    transpiler.transpile(
        trainer_id,
        trainers=all_endpoints,
        current_endpoint=this_endpoint,
        startup_program=startup_prog,
        program=main_prog)
def prepare_for_multi_process(exe, build_strategy, startup_prog, main_prog):
    """Configure ``build_strategy`` and transpile for multi-process training.

    Reads ``PADDLE_TRAINER_ID`` (default 0) and ``PADDLE_TRAINERS_NUM``
    (default 1) from the environment. With fewer than two trainers this is a
    no-op; otherwise both values are stored on ``build_strategy`` and the
    programs are handed to ``nccl2_prepare``.

    Args:
        exe: not used by this function.
        build_strategy: object that receives ``num_trainers`` and
            ``trainer_id`` attributes.
        startup_prog: startup program forwarded to ``nccl2_prepare``.
        main_prog: main program forwarded to ``nccl2_prepare``.
    """
    env = os.environ
    rank = int(env.get('PADDLE_TRAINER_ID', 0))
    world_size = int(env.get('PADDLE_TRAINERS_NUM', 1))
    if world_size >= 2:
        build_strategy.num_trainers = world_size
        build_strategy.trainer_id = rank
        nccl2_prepare(rank, startup_prog, main_prog)
......@@ -130,6 +130,16 @@ def get_dataset_path(path, annotation, image_dir):
"'voc' and 'coco' currently".format(path, osp.split(path)[-1]))
def map_path(url, root_dir):
    """Map ``url`` to the path its content occupies under ``root_dir``.

    Takes the part of ``url`` after the last '/', removes every occurrence
    of '.zip', '.tar' and '.gz' from it (in that order), and joins the
    result onto ``root_dir``.

    Args:
        url (str): download url (or any '/'-separated path).
        root_dir (str): directory the result is placed under.

    Returns:
        str: ``root_dir`` joined with the stripped file name.
    """
    stem = url.rsplit('/', 1)[-1]
    for archive_ext in ('.zip', '.tar', '.gz'):
        stem = stem.replace(archive_ext, '')
    return osp.join(root_dir, stem)
def get_path(url, root_dir, md5sum=None):
""" Download from given url to root_dir.
if file or directory specified by url exists under
......@@ -142,12 +152,7 @@ def get_path(url, root_dir, md5sum=None):
md5sum (str): md5 sum of download package
"""
# parse path after download to decompress under root_dir
fname = url.split('/')[-1]
zip_formats = ['.zip', '.tar', '.gz']
fpath = fname
for zip_format in zip_formats:
fpath = fpath.replace(zip_format, '')
fullpath = osp.join(root_dir, fpath)
fullpath = map_path(url, root_dir)
# For same zip file, decompressed directory name different
# from zip file name, rename by following map
......
......@@ -29,7 +29,7 @@ def set_paddle_flags(**kwargs):
os.environ[key] = str(value)
# NOTE(paddle-dev): All of these flags should be set before
# NOTE(paddle-dev): All of these flags should be set before
# `import paddle`. Otherwise, it would not take any effect.
set_paddle_flags(
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
......@@ -40,6 +40,7 @@ from ppdet.core.workspace import load_config, merge_config, create
from ppdet.data.data_feed import create_reader
from ppdet.utils.cli import print_total_cfg
from ppdet.utils import dist_utils
from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
from ppdet.utils.stats import TrainingStats
from ppdet.utils.cli import ArgsParser
......@@ -54,6 +55,15 @@ logger = logging.getLogger(__name__)
def main():
env = os.environ
FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
if FLAGS.dist:
trainer_id = int(env['PADDLE_TRAINER_ID'])
import random
local_seed = (99 + trainer_id)
random.seed(local_seed)
np.random.seed(local_seed)
cfg = load_config(FLAGS.config)
if 'architecture' in cfg:
main_arch = cfg.architecture
......@@ -84,7 +94,11 @@ def main():
else:
eval_feed = create(cfg.eval_feed)
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
if 'FLAGS_selected_gpus' in env:
device_id = int(env['FLAGS_selected_gpus'])
else:
device_id = 0
place = fluid.CUDAPlace(device_id)
exe = fluid.Executor(place)
lr_builder = create('LearningRate')
......@@ -132,23 +146,27 @@ def main():
build_strategy = fluid.BuildStrategy()
sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
# only enable sync_bn in multi GPU devices
build_strategy.sync_batch_norm = sync_bn and devices_num > 1 and cfg.use_gpu
build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
and cfg.use_gpu
exec_strategy = fluid.ExecutionStrategy()
# iteration number when CompiledProgram tries to drop local execution scopes.
# Set it to be 1 to save memory usages, so that unused variables in
# local execution scopes can be deleted after each iteration.
exec_strategy.num_iteration_per_drop_scope = 1
train_compile_program = fluid.compiler.CompiledProgram(
train_prog).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
if FLAGS.eval:
eval_compile_program = fluid.compiler.CompiledProgram(eval_prog)
if FLAGS.dist:
dist_utils.prepare_for_multi_process(
exe, build_strategy, startup_prog, train_prog)
exec_strategy.num_threads = 1
exe.run(startup_prog)
compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
if FLAGS.eval:
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)
fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
start_iter = 0
......@@ -160,8 +178,10 @@ def main():
elif cfg.pretrain_weights:
checkpoint.load_pretrain(exe, train_prog, cfg.pretrain_weights)
train_reader = create_reader(train_feed, (cfg.max_iters - start_iter) *
devices_num, FLAGS.dataset_dir)
train_reader = create_reader(
train_feed,
(cfg.max_iters - start_iter) * devices_num,
FLAGS.dataset_dir)
train_pyreader.decorate_sample_list_generator(train_reader, place)
# whether output bbox is normalized in model output layer
......@@ -197,7 +217,7 @@ def main():
time_cost = np.mean(time_stat)
eta_sec = (cfg.max_iters - it) * time_cost
eta = str(datetime.timedelta(seconds=int(eta_sec)))
outs = exe.run(train_compile_program, fetch_list=train_values)
outs = exe.run(compiled_train_prog, fetch_list=train_values)
stats = {k: np.array(v).mean() for k, v in zip(train_keys, outs[:-1])}
# use tb-paddle to log loss
......@@ -209,18 +229,19 @@ def main():
train_stats.update(stats)
logs = train_stats.log()
if it % cfg.log_iter == 0:
if it % cfg.log_iter == 0 and (not FLAGS.dist or trainer_id == 0):
strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
it, np.mean(outs[-1]), logs, time_cost, eta)
logger.info(strs)
if it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1:
if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \
and (not FLAGS.dist or trainer_id == 0):
save_name = str(it) if it != cfg.max_iters - 1 else "model_final"
checkpoint.save(exe, train_prog, os.path.join(save_dir, save_name))
if FLAGS.eval:
# evaluation
results = eval_run(exe, eval_compile_program, eval_pyreader,
results = eval_run(exe, compiled_eval_prog, eval_pyreader,
eval_keys, eval_values, eval_cls)
resolution = None
if 'mask' in results[0]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册