From 7a65af0c3953a8b83d0afaaa7f355a00d0d04f39 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Mon, 16 Nov 2020 18:28:02 +0800 Subject: [PATCH] update save load (#1702) --- ppdet/utils/checkpoint.py | 167 ++++++++++++++++++++++++-------------- tools/eval.py | 18 +++- tools/infer.py | 4 +- tools/train.py | 49 +++++++---- 4 files changed, 160 insertions(+), 78 deletions(-) diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index 419f5d611..cf107e0fd 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -11,89 +25,124 @@ import numpy as np import paddle import paddle.fluid as fluid from .download import get_weights_path +import logging +logger = logging.getLogger(__name__) -def get_ckpt_path(path): - if path.startswith('http://') or path.startswith('https://'): - env = os.environ - if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env: - trainer_id = int(env['PADDLE_TRAINER_ID']) - num_trainers = int(env['PADDLE_TRAINERS_NUM']) - if num_trainers <= 1: - path = get_weights_path(path) - else: - from ppdet.utils.download import map_path, WEIGHTS_HOME - weight_path = map_path(path, WEIGHTS_HOME) - lock_path = weight_path + '.lock' - if not os.path.exists(weight_path): - try: - os.makedirs(os.path.dirname(weight_path)) - except OSError as e: - if e.errno != errno.EEXIST: - raise - with open(lock_path, 'w'): # touch - os.utime(lock_path, None) - if trainer_id == 0: - get_weights_path(path) - os.remove(lock_path) - else: - while os.path.exists(lock_path): - time.sleep(1) - path = weight_path - else: +def is_url(path): + """ + Whether path is URL. + Args: + path (string): URL string or not. 
+ """ + return path.startswith('http://') or path.startswith('https://') + + +def get_weight_path(path): + env = os.environ + if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env: + trainer_id = int(env['PADDLE_TRAINER_ID']) + num_trainers = int(env['PADDLE_TRAINERS_NUM']) + if num_trainers <= 1: path = get_weights_path(path) + else: + from ppdet.utils.download import map_path, WEIGHTS_HOME + weight_path = map_path(path, WEIGHTS_HOME) + lock_path = weight_path + '.lock' + if not os.path.exists(weight_path): + try: + os.makedirs(os.path.dirname(weight_path)) + except OSError as e: + if e.errno != errno.EEXIST: + raise + with open(lock_path, 'w'): # touch + os.utime(lock_path, None) + if trainer_id == 0: + get_weights_path(path) + os.remove(lock_path) + else: + while os.path.exists(lock_path): + time.sleep(1) + path = weight_path + else: + path = get_weights_path(path) + + return path + +def _strip_postfix(path): + path, ext = os.path.splitext(path) + assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \ + "Unknown postfix {} from weights".format(ext) return path -def load_dygraph_ckpt(model, - optimizer=None, - pretrain_ckpt=None, - ckpt=None, - ckpt_type=None, - exclude_params=[], - load_static_weights=False): +def load_weight(model, weight, optimizer=None): + if is_url(weight): + weight = get_weight_path(weight) + + path = _strip_postfix(weight) + pdparam_path = path + '.pdparams' + if not os.path.exists(pdparam_path): + raise ValueError("Model pretrain path {} does not " + "exists.".format(pdparam_path)) + + param_state_dict = paddle.load(pdparam_path) + model.set_dict(param_state_dict) + + if optimizer is not None and os.path.exists(path + '.pdopt'): + optim_state_dict = paddle.load(path + '.pdopt') + optimizer.set_state_dict(optim_state_dict) + return + + +def load_pretrain_weight(model, + pretrain_weight, + load_static_weights=False, + weight_type='pretrain'): + assert weight_type in ['pretrain', 'finetune'] + if is_url(pretrain_weight): + pretrain_weight = get_weight_path(pretrain_weight) + + path = _strip_postfix(pretrain_weight) + if not (os.path.isdir(path) or os.path.isfile(path) or + os.path.exists(path + '.pdparams')): + raise ValueError("Model pretrain path {} does not " + "exists.".format(path)) + + model_dict = model.state_dict() - assert ckpt_type in ['pretrain', 'resume', 'finetune', None] - if ckpt_type == 'pretrain' and ckpt is None: - ckpt = pretrain_ckpt - ckpt = get_ckpt_path(ckpt) - assert os.path.exists(ckpt), "Path {} does not exist.".format(ckpt) if load_static_weights: - pre_state_dict = fluid.load_program_state(ckpt) + pre_state_dict = paddle.static.load_program_state(path) param_state_dict = {} - model_dict = model.state_dict() for key in model_dict.keys(): weight_name = model_dict[key].name if weight_name in pre_state_dict.keys(): - print('Load weight: {}, shape: {}'.format( + logger.info('Load weight: {}, shape: {}'.format( weight_name, pre_state_dict[weight_name].shape)) param_state_dict[key] = pre_state_dict[weight_name] else: param_state_dict[key] = model_dict[key] model.set_dict(param_state_dict) - return model - param_state_dict, optim_state_dict = fluid.load_dygraph(ckpt) + return - if len(exclude_params) != 0: - for k in exclude_params: - param_state_dict.pop(k, None) - - if ckpt_type == 'pretrain': + param_state_dict = paddle.load(path + '.pdparams') + if weight_type == 'pretrain': model.backbone.set_dict(param_state_dict) else: + ignore_set = set() + for name, weight in model_dict: + if name in param_state_dict: + if weight.shape != 
param_state_dict[name].shape: + param_state_dict.pop(name, None) model.set_dict(param_state_dict) - - if ckpt_type == 'resume': - assert optim_state_dict, "Can't Resume Last Training's Optimizer State!!!" - optimizer.set_dict(optim_state_dict) - return model + return -def save_dygraph_ckpt(model, optimizer, save_dir, save_name): +def save_model(model, optimizer, save_dir, save_name): if not os.path.exists(save_dir): os.makedirs(save_dir) save_path = os.path.join(save_dir, save_name) - fluid.dygraph.save_dygraph(model.state_dict(), save_path) - fluid.dygraph.save_dygraph(optimizer.state_dict(), save_path) - print("Save checkpoint:", save_dir) + paddle.save(model.state_dict(), save_path + ".pdparams") + paddle.save(optimizer.state_dict(), save_path + ".pdopt") + logger.info("Save checkpoint: {}".format(save_dir)) diff --git a/tools/eval.py b/tools/eval.py index 30a8cc693..b5d874df2 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -19,7 +33,7 @@ from ppdet.core.workspace import load_config, merge_config, create from ppdet.utils.check import check_gpu, check_version, check_config from ppdet.utils.cli import ArgsParser from ppdet.utils.eval_utils import get_infer_results, eval_results -from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt +from ppdet.utils.checkpoint import load_weight import logging FORMAT = '%(asctime)s-%(levelname)s: %(message)s' logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -51,7 +65,7 @@ def run(FLAGS, cfg, place): model = create(cfg.architecture) # Init Model - model = load_dygraph_ckpt(model, ckpt=cfg.weights) + load_weight(model, cfg.weights) # Data Reader dataset = cfg.EvalDataset diff --git a/tools/infer.py b/tools/infer.py index 8868235b2..e11a58d0e 100755 --- a/tools/infer.py +++ b/tools/infer.py @@ -34,7 +34,7 @@ from ppdet.utils.check import check_gpu, check_version, check_config from ppdet.utils.visualizer import visualize_results from ppdet.utils.cli import ArgsParser from ppdet.data.reader import create_reader -from ppdet.utils.checkpoint import load_dygraph_ckpt +from ppdet.utils.checkpoint import load_weight from ppdet.utils.eval_utils import get_infer_results import logging FORMAT = '%(asctime)s-%(levelname)s: %(message)s' @@ -141,7 +141,7 @@ def run(FLAGS, cfg): use_default_label) # Init Model - model = load_dygraph_ckpt(model, ckpt=cfg.weights) + load_weight(model, cfg.weights) # Data Reader test_reader = create_reader(cfg.TestDataset, cfg.TestReader) diff --git a/tools/train.py b/tools/train.py index 3804d8480..83a9196ab 100755 --- a/tools/train.py +++ b/tools/train.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -21,7 +35,7 @@ from ppdet.core.workspace import load_config, merge_config, create from ppdet.utils.stats import TrainingStats from ppdet.utils.check import check_gpu, check_version, check_config from ppdet.utils.cli import ArgsParser -from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt +from ppdet.utils.checkpoint import load_weight, load_pretrain_weight, save_model from paddle.distributed import ParallelEnv import logging FORMAT = '%(asctime)s-%(levelname)s: %(message)s' @@ -32,7 +46,7 @@ logger = logging.getLogger(__name__) def parse_args(): parser = ArgsParser() parser.add_argument( - "-ckpt_type", + "--weight_type", default='pretrain', type=str, help="Loading Checkpoints only support 'pretrain', 'finetune', 'resume'." @@ -116,12 +130,12 @@ def run(FLAGS, cfg, place): optimizer = create('OptimizerBuilder')(lr, model.parameters()) # Init Model & Optimzer - model = load_dygraph_ckpt( - model, - optimizer, - cfg.pretrain_weights, - ckpt_type=FLAGS.ckpt_type, - load_static_weights=cfg.get('load_static_weights', False)) + if FLAGS.weight_type == 'resume': + load_weight(model, cfg.pretrain_weights, optimizer) + else: + load_pretrain_weight(model, cfg.pretrain_weights, + cfg.get('load_static_weights', False), + FLAGS.weight_type) # Parallel Model if ParallelEnv().nranks > 1: @@ -132,13 +146,17 @@ def run(FLAGS, cfg, place): time_stat = deque(maxlen=cfg.log_iter) start_time = time.time() end_time = time.time() + # Run Train + start_epoch = optimizer.state_dict()['LR_Scheduler']['last_epoch'] for e_id in range(int(cfg.epoch)): + cur_eid = e_id + start_epoch for iter_id, data in enumerate(train_loader): start_time = end_time end_time = time.time() time_stat.append(end_time - start_time) time_cost = np.mean(time_stat) - eta_sec = (cfg.epoch * step_per_epoch - iter_id) * time_cost + eta_sec = ( + (cfg.epoch - cur_eid) * step_per_epoch - iter_id) * time_cost eta = str(datetime.timedelta(seconds=int(eta_sec))) # Model Forward @@ -162,22 +180,23 @@ def run(FLAGS, cfg, place): if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0: # Log state - if iter_id == 0: + if e_id == 0 and iter_id == 0: train_stats = TrainingStats(cfg.log_iter, outputs.keys()) train_stats.update(outputs) logs = train_stats.log() if iter_id % cfg.log_iter == 0: - strs = 'Epoch:{}: iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format( - e_id, iter_id, curr_lr, logs, time_cost, eta) + ips = float(cfg['TrainReader']['batch_size']) / time_cost + strs = 'Epoch:{}: iter: {}, lr: {:.6f}, {}, eta: {}, batch_cost: {:.5f} sec, ips: {:.5f} images/sec'.format( + cur_eid, iter_id, curr_lr, logs, eta, time_cost, ips) logger.info(strs) # Save Stage - if ParallelEnv().local_rank == 0 and e_id % cfg.snapshot_epoch == 0: + if ParallelEnv().local_rank == 0 and cur_eid % cfg.snapshot_epoch == 0: cfg_name = os.path.basename(FLAGS.config).split('.')[0] - save_name = str(e_id + 1) if e_id + 1 != int( + save_name = str(cur_eid) if cur_eid + 1 != int( cfg.epoch) else "model_final" save_dir = 
os.path.join(cfg.save_dir, cfg_name) - save_dygraph_ckpt(model, optimizer, save_dir, save_name) + save_model(model, optimizer, save_dir, save_name) def main(): -- GitLab
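Usage note (editorial, not part of the patch): the diff above replaces load_dygraph_ckpt and save_dygraph_ckpt with three helpers defined in ppdet/utils/checkpoint.py: load_weight, load_pretrain_weight and save_model. The sketch below shows how they fit together after this change. It mirrors the calls in tools/train.py and tools/eval.py, but the config path, the save_dir/save_name values and the 'LearningRate' creation line are illustrative assumptions rather than content taken from this patch.

    # Minimal sketch of the checkpoint API introduced by this patch.
    from ppdet.core.workspace import load_config, create
    from ppdet.utils.checkpoint import load_weight, load_pretrain_weight, save_model

    cfg = load_config('configs/yolov3_darknet.yml')  # hypothetical config file
    model = create(cfg.architecture)
    lr = create('LearningRate')()  # assumed helper, as in the ppdet dygraph train script
    optimizer = create('OptimizerBuilder')(lr, model.parameters())

    # Fresh training: load backbone ('pretrain') or full-model ('finetune') weights.
    load_pretrain_weight(model, cfg.pretrain_weights,
                         load_static_weights=cfg.get('load_static_weights', False),
                         weight_type='pretrain')

    # Resuming instead: restores <path>.pdparams and, if present, <path>.pdopt.
    # load_weight(model, cfg.pretrain_weights, optimizer)

    # After an epoch: writes <save_dir>/<save_name>.pdparams and .pdopt.
    save_model(model, optimizer, 'output/yolov3_darknet', 'model_final')

    # Evaluation or inference only needs the parameters.
    load_weight(model, cfg.weights)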