diff --git a/PaddleCV/PaddleDetection/ppdet/utils/checkpoint.py b/PaddleCV/PaddleDetection/ppdet/utils/checkpoint.py index c1888fcdb6f39ff2bb2c25349f5692f98019b1d8..f6be75ab88201277d5cef0feca757086abc0a5f4 100644 --- a/PaddleCV/PaddleDetection/ppdet/utils/checkpoint.py +++ b/PaddleCV/PaddleDetection/ppdet/utils/checkpoint.py @@ -19,6 +19,7 @@ from __future__ import unicode_literals import os import shutil +import time import numpy as np import paddle.fluid as fluid @@ -44,6 +45,33 @@ def is_url(path): return path.startswith('http://') or path.startswith('https://') +def _get_weight_path(path): + env = os.environ + if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env: + trainer_id = int(env['PADDLE_TRAINER_ID']) + num_trainers = int(env['PADDLE_TRAINERS_NUM']) + if num_trainers <= 1: + path = get_weights_path(path) + else: + from ppdet.utils.download import map_path, WEIGHTS_HOME + weight_path = map_path(path, WEIGHTS_HOME) + lock_path = weight_path + '.lock' + if not os.path.exists(weight_path): + os.makedirs(os.path.dirname(weight_path), exist_ok=True) + with open(lock_path, 'w'): # touch + os.utime(lock_path) + if trainer_id == 0: + get_weights_path(path) + os.remove(lock_path) + else: + while os.path.exists(lock_path): + time.sleep(1) + path = weight_path + else: + path = get_weights_path(path) + return path + + def load_pretrain(exe, prog, path): """ Load model from the given path. @@ -52,8 +80,9 @@ def load_pretrain(exe, prog, path): prog (fluid.Program): load weight to which Program object. path (string): URL string or loca model path. """ + if is_url(path): - path = get_weights_path(path) + path = _get_weight_path(path) if not os.path.exists(path): raise ValueError("Model pretrain path {} does not " @@ -131,7 +160,7 @@ def load_and_fusebn(exe, prog, path): """ logger.info('Load model and fuse batch norm from {}...'.format(path)) if is_url(path): - path = get_weights_path(path) + path = _get_weight_path(path) if not os.path.exists(path): raise ValueError("Model path {} does not exists.".format(path)) diff --git a/PaddleCV/PaddleDetection/ppdet/utils/dist_utils.py b/PaddleCV/PaddleDetection/ppdet/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..32eead4a797ba70cb6980e0368ff9873102680c2 --- /dev/null +++ b/PaddleCV/PaddleDetection/ppdet/utils/dist_utils.py @@ -0,0 +1,41 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +import os + +import paddle.fluid as fluid + + +def nccl2_prepare(trainer_id, startup_prog, main_prog): + config = fluid.DistributeTranspilerConfig() + config.mode = "nccl2" + t = fluid.DistributeTranspiler(config=config) + t.transpile( + trainer_id, + trainers=os.environ.get('PADDLE_TRAINER_ENDPOINTS'), + current_endpoint=os.environ.get('PADDLE_CURRENT_ENDPOINT'), + startup_program=startup_prog, + program=main_prog) + + +def prepare_for_multi_process(exe, build_strategy, startup_prog, main_prog): + trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0)) + num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) + if num_trainers < 2: + return + build_strategy.num_trainers = num_trainers + build_strategy.trainer_id = trainer_id + nccl2_prepare(trainer_id, startup_prog, main_prog) diff --git a/PaddleCV/PaddleDetection/ppdet/utils/download.py b/PaddleCV/PaddleDetection/ppdet/utils/download.py index def4c4d9a353939d5a110812fb929a9889490330..b40e1404d82e3f8013ac43c843daba98c2dd74f9 100644 --- a/PaddleCV/PaddleDetection/ppdet/utils/download.py +++ b/PaddleCV/PaddleDetection/ppdet/utils/download.py @@ -130,6 +130,16 @@ def get_dataset_path(path, annotation, image_dir): "'voc' and 'coco' currently".format(path, osp.split(path)[-1])) +def map_path(url, root_dir): + # parse path after download to decompress under root_dir + fname = url.split('/')[-1] + zip_formats = ['.zip', '.tar', '.gz'] + fpath = fname + for zip_format in zip_formats: + fpath = fpath.replace(zip_format, '') + return osp.join(root_dir, fpath) + + def get_path(url, root_dir, md5sum=None): """ Download from given url to root_dir. if file or directory specified by url is exists under @@ -142,12 +152,7 @@ def get_path(url, root_dir, md5sum=None): md5sum (str): md5 sum of download package """ # parse path after download to decompress under root_dir - fname = url.split('/')[-1] - zip_formats = ['.zip', '.tar', '.gz'] - fpath = fname - for zip_format in zip_formats: - fpath = fpath.replace(zip_format, '') - fullpath = osp.join(root_dir, fpath) + fullpath = map_path(url, root_dir) # For same zip file, decompressed directory name different # from zip file name, rename by following map diff --git a/PaddleCV/PaddleDetection/tools/train.py b/PaddleCV/PaddleDetection/tools/train.py index fc14fa13ead7f7daea967abdb1f33c99c51d7089..8081ce3dc7e5ef41617534fa8369bd867956e6f8 100644 --- a/PaddleCV/PaddleDetection/tools/train.py +++ b/PaddleCV/PaddleDetection/tools/train.py @@ -29,7 +29,7 @@ def set_paddle_flags(**kwargs): os.environ[key] = str(value) -# NOTE(paddle-dev): All of these flags should be set before +# NOTE(paddle-dev): All of these flags should be set before # `import paddle`. Otherwise, it would not take any effect. set_paddle_flags( FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory @@ -40,6 +40,7 @@ from ppdet.core.workspace import load_config, merge_config, create from ppdet.data.data_feed import create_reader from ppdet.utils.cli import print_total_cfg +from ppdet.utils import dist_utils from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results from ppdet.utils.stats import TrainingStats from ppdet.utils.cli import ArgsParser @@ -54,6 +55,15 @@ logger = logging.getLogger(__name__) def main(): + env = os.environ + FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env + if FLAGS.dist: + trainer_id = int(env['PADDLE_TRAINER_ID']) + import random + local_seed = (99 + trainer_id) + random.seed(local_seed) + np.random.seed(local_seed) + cfg = load_config(FLAGS.config) if 'architecture' in cfg: main_arch = cfg.architecture @@ -84,7 +94,11 @@ def main(): else: eval_feed = create(cfg.eval_feed) - place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() + if 'FLAGS_selected_gpus' in env: + device_id = int(env['FLAGS_selected_gpus']) + else: + device_id = 0 + place = fluid.CUDAPlace(device_id) exe = fluid.Executor(place) lr_builder = create('LearningRate') @@ -132,23 +146,27 @@ def main(): build_strategy = fluid.BuildStrategy() sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn' # only enable sync_bn in multi GPU devices - build_strategy.sync_batch_norm = sync_bn and devices_num > 1 and cfg.use_gpu + build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \ + and cfg.use_gpu exec_strategy = fluid.ExecutionStrategy() # iteration number when CompiledProgram tries to drop local execution scopes. # Set it to be 1 to save memory usages, so that unused variables in # local execution scopes can be deleted after each iteration. exec_strategy.num_iteration_per_drop_scope = 1 - - train_compile_program = fluid.compiler.CompiledProgram( - train_prog).with_data_parallel( - loss_name=loss.name, - build_strategy=build_strategy, - exec_strategy=exec_strategy) - if FLAGS.eval: - eval_compile_program = fluid.compiler.CompiledProgram(eval_prog) + if FLAGS.dist: + dist_utils.prepare_for_multi_process( + exe, build_strategy, startup_prog, train_prog) + exec_strategy.num_threads = 1 exe.run(startup_prog) + compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel( + loss_name=loss.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + + if FLAGS.eval: + compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog) fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel' start_iter = 0 @@ -160,8 +178,10 @@ def main(): elif cfg.pretrain_weights: checkpoint.load_pretrain(exe, train_prog, cfg.pretrain_weights) - train_reader = create_reader(train_feed, (cfg.max_iters - start_iter) * - devices_num, FLAGS.dataset_dir) + train_reader = create_reader( + train_feed, + (cfg.max_iters - start_iter) * devices_num, + FLAGS.dataset_dir) train_pyreader.decorate_sample_list_generator(train_reader, place) # whether output bbox is normalized in model output layer @@ -197,7 +217,7 @@ def main(): time_cost = np.mean(time_stat) eta_sec = (cfg.max_iters - it) * time_cost eta = str(datetime.timedelta(seconds=int(eta_sec))) - outs = exe.run(train_compile_program, fetch_list=train_values) + outs = exe.run(compiled_train_prog, fetch_list=train_values) stats = {k: np.array(v).mean() for k, v in zip(train_keys, outs[:-1])} # use tb-paddle to log loss @@ -209,18 +229,19 @@ def main(): train_stats.update(stats) logs = train_stats.log() - if it % cfg.log_iter == 0: + if it % cfg.log_iter == 0 and (not FLAGS.dist or trainer_id == 0): strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format( it, np.mean(outs[-1]), logs, time_cost, eta) logger.info(strs) - if it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1: + if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \ + and (not FLAGS.dist or trainer_id == 0): save_name = str(it) if it != cfg.max_iters - 1 else "model_final" checkpoint.save(exe, train_prog, os.path.join(save_dir, save_name)) if FLAGS.eval: # evaluation - results = eval_run(exe, eval_compile_program, eval_pyreader, + results = eval_run(exe, compiled_eval_prog, eval_pyreader, eval_keys, eval_values, eval_cls) resolution = None if 'mask' in results[0]: