Unverified commit 7a65af0c, authored by W wangguanzhong, committed by GitHub

update save load (#1702)

Parent 48e21f3c
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -11,10 +25,20 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from .download import get_weights_path
+import logging
+logger = logging.getLogger(__name__)
+def is_url(path):
+    """
+    Whether path is URL.
+    Args:
+        path (string): URL string or not.
+    """
+    return path.startswith('http://') or path.startswith('https://')
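# Usage sketch (not part of the commit; the paths are hypothetical):
#     is_url('https://example.com/resnet50_pretrained.pdparams')   # -> True
#     is_url('output/yolov3_darknet/model_final')                  # -> False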
-def get_ckpt_path(path):
+def get_weight_path(path):
-    if path.startswith('http://') or path.startswith('https://'):
    env = os.environ
    if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
@@ -46,54 +70,79 @@ def get_ckpt_path(path):
    return path
-def load_dygraph_ckpt(model,
-                      optimizer=None,
-                      pretrain_ckpt=None,
-                      ckpt=None,
-                      ckpt_type=None,
-                      exclude_params=[],
-                      load_static_weights=False):
+def _strip_postfix(path):
+    path, ext = os.path.splitext(path)
+    assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
+        "Unknown postfix {} from weights".format(ext)
+    return path
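# Usage sketch (not part of the commit; the prefix is hypothetical): both
# 'output/model_final' and 'output/model_final.pdparams' reduce to the same
# prefix, while an unexpected extension such as '.tar' trips the assert.
#     _strip_postfix('output/model_final.pdparams')   # -> 'output/model_final'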
+def load_weight(model, weight, optimizer=None):
+    if is_url(weight):
+        weight = get_weight_path(weight)
+    path = _strip_postfix(weight)
+    pdparam_path = path + '.pdparams'
+    if not os.path.exists(pdparam_path):
+        raise ValueError("Model pretrain path {} does not "
+                         "exists.".format(pdparam_path))
+    param_state_dict = paddle.load(pdparam_path)
+    model.set_dict(param_state_dict)
+    if optimizer is not None and os.path.exists(path + '.pdopt'):
+        optim_state_dict = paddle.load(path + '.pdopt')
+        optimizer.set_state_dict(optim_state_dict)
+    return
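# A brief usage note (not from the commit): because load_weight also accepts an
# optimizer, passing one restores the matching '.pdopt' state alongside the
# '.pdparams' file, which is what the 'resume' branch of the training script
# below relies on. Objects and path here are hypothetical:
#     load_weight(model, 'output/yolov3_darknet/model_final', optimizer)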
+def load_pretrain_weight(model,
+                         pretrain_weight,
+                         load_static_weights=False,
+                         weight_type='pretrain'):
+    assert weight_type in ['pretrain', 'finetune']
+    if is_url(pretrain_weight):
+        pretrain_weight = get_weight_path(pretrain_weight)
+    path = _strip_postfix(pretrain_weight)
+    if not (os.path.isdir(path) or os.path.isfile(path) or
+            os.path.exists(path + '.pdparams')):
+        raise ValueError("Model pretrain path {} does not "
+                         "exists.".format(path))
+    model_dict = model.state_dict()
-    assert ckpt_type in ['pretrain', 'resume', 'finetune', None]
-    if ckpt_type == 'pretrain' and ckpt is None:
-        ckpt = pretrain_ckpt
-    ckpt = get_ckpt_path(ckpt)
-    assert os.path.exists(ckpt), "Path {} does not exist.".format(ckpt)
    if load_static_weights:
-        pre_state_dict = fluid.load_program_state(ckpt)
+        pre_state_dict = paddle.static.load_program_state(path)
        param_state_dict = {}
-        model_dict = model.state_dict()
        for key in model_dict.keys():
            weight_name = model_dict[key].name
            if weight_name in pre_state_dict.keys():
-                print('Load weight: {}, shape: {}'.format(
+                logger.info('Load weight: {}, shape: {}'.format(
                    weight_name, pre_state_dict[weight_name].shape))
                param_state_dict[key] = pre_state_dict[weight_name]
            else:
                param_state_dict[key] = model_dict[key]
        model.set_dict(param_state_dict)
-        return model
-    param_state_dict, optim_state_dict = fluid.load_dygraph(ckpt)
-    if len(exclude_params) != 0:
-        for k in exclude_params:
-            param_state_dict.pop(k, None)
-    if ckpt_type == 'pretrain':
+        return
+    param_state_dict = paddle.load(path + '.pdparams')
+    if weight_type == 'pretrain':
        model.backbone.set_dict(param_state_dict)
    else:
+        ignore_set = set()
+        for name, weight in model_dict.items():
+            if name in param_state_dict:
+                if weight.shape != param_state_dict[name].shape:
+                    param_state_dict.pop(name, None)
        model.set_dict(param_state_dict)
-    if ckpt_type == 'resume':
-        assert optim_state_dict, "Can't Resume Last Training's Optimizer State!!!"
-        optimizer.set_dict(optim_state_dict)
-    return model
+    return
-def save_dygraph_ckpt(model, optimizer, save_dir, save_name):
+def save_model(model, optimizer, save_dir, save_name):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, save_name)
-    fluid.dygraph.save_dygraph(model.state_dict(), save_path)
-    fluid.dygraph.save_dygraph(optimizer.state_dict(), save_path)
-    print("Save checkpoint:", save_dir)
+    paddle.save(model.state_dict(), save_path + ".pdparams")
+    paddle.save(optimizer.state_dict(), save_path + ".pdopt")
+    logger.info("Save checkpoint: {}".format(save_dir))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -19,7 +33,7 @@ from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.utils.eval_utils import get_infer_results, eval_results
-from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt
+from ppdet.utils.checkpoint import load_weight
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -51,7 +65,7 @@ def run(FLAGS, cfg, place):
    model = create(cfg.architecture)
    # Init Model
-    model = load_dygraph_ckpt(model, ckpt=cfg.weights)
+    load_weight(model, cfg.weights)
    # Data Reader
    dataset = cfg.EvalDataset
...
@@ -34,7 +34,7 @@ from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.visualizer import visualize_results
from ppdet.utils.cli import ArgsParser
from ppdet.data.reader import create_reader
-from ppdet.utils.checkpoint import load_dygraph_ckpt
+from ppdet.utils.checkpoint import load_weight
from ppdet.utils.eval_utils import get_infer_results
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
@@ -141,7 +141,7 @@ def run(FLAGS, cfg):
        use_default_label)
    # Init Model
-    model = load_dygraph_ckpt(model, ckpt=cfg.weights)
+    load_weight(model, cfg.weights)
    # Data Reader
    test_reader = create_reader(cfg.TestDataset, cfg.TestReader)
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -21,7 +35,7 @@ from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.stats import TrainingStats
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
-from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt
+from ppdet.utils.checkpoint import load_weight, load_pretrain_weight, save_model
from paddle.distributed import ParallelEnv
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
@@ -32,7 +46,7 @@ logger = logging.getLogger(__name__)
def parse_args():
    parser = ArgsParser()
    parser.add_argument(
-        "-ckpt_type",
+        "--weight_type",
        default='pretrain',
        type=str,
        help="Loading Checkpoints only support 'pretrain', 'finetune', 'resume'."
@@ -116,12 +130,12 @@ def run(FLAGS, cfg, place):
    optimizer = create('OptimizerBuilder')(lr, model.parameters())
    # Init Model & Optimzer
-    model = load_dygraph_ckpt(
-        model,
-        optimizer,
-        cfg.pretrain_weights,
-        ckpt_type=FLAGS.ckpt_type,
-        load_static_weights=cfg.get('load_static_weights', False))
+    if FLAGS.weight_type == 'resume':
+        load_weight(model, cfg.pretrain_weights, optimizer)
+    else:
+        load_pretrain_weight(model, cfg.pretrain_weights,
+                             cfg.get('load_static_weights', False),
+                             FLAGS.weight_type)
    # Parallel Model
    if ParallelEnv().nranks > 1:
@@ -132,13 +146,17 @@ def run(FLAGS, cfg, place):
    time_stat = deque(maxlen=cfg.log_iter)
    start_time = time.time()
    end_time = time.time()
+    # Run Train
+    start_epoch = optimizer.state_dict()['LR_Scheduler']['last_epoch']
    for e_id in range(int(cfg.epoch)):
+        cur_eid = e_id + start_epoch
        for iter_id, data in enumerate(train_loader):
            start_time = end_time
            end_time = time.time()
            time_stat.append(end_time - start_time)
            time_cost = np.mean(time_stat)
-            eta_sec = (cfg.epoch * step_per_epoch - iter_id) * time_cost
+            eta_sec = (
+                (cfg.epoch - cur_eid) * step_per_epoch - iter_id) * time_cost
            eta = str(datetime.timedelta(seconds=int(eta_sec)))
            # Model Forward
@@ -162,22 +180,23 @@
            if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
                # Log state
-                if iter_id == 0:
+                if e_id == 0 and iter_id == 0:
                    train_stats = TrainingStats(cfg.log_iter, outputs.keys())
                train_stats.update(outputs)
                logs = train_stats.log()
                if iter_id % cfg.log_iter == 0:
-                    strs = 'Epoch:{}: iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
-                        e_id, iter_id, curr_lr, logs, time_cost, eta)
+                    ips = float(cfg['TrainReader']['batch_size']) / time_cost
+                    strs = 'Epoch:{}: iter: {}, lr: {:.6f}, {}, eta: {}, batch_cost: {:.5f} sec, ips: {:.5f} images/sec'.format(
+                        cur_eid, iter_id, curr_lr, logs, eta, time_cost, ips)
                    logger.info(strs)
        # Save Stage
-        if ParallelEnv().local_rank == 0 and e_id % cfg.snapshot_epoch == 0:
+        if ParallelEnv().local_rank == 0 and cur_eid % cfg.snapshot_epoch == 0:
            cfg_name = os.path.basename(FLAGS.config).split('.')[0]
-            save_name = str(e_id + 1) if e_id + 1 != int(
+            save_name = str(cur_eid) if cur_eid + 1 != int(
                cfg.epoch) else "model_final"
            save_dir = os.path.join(cfg.save_dir, cfg_name)
-            save_dygraph_ckpt(model, optimizer, save_dir, save_name)
+            save_model(model, optimizer, save_dir, save_name)
def main():
...
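The training-loop changes make resuming epoch-aware: the starting epoch is read back from the optimizer's LR scheduler state, the ETA only counts the epochs still to run, and throughput is logged as images per second. A small worked sketch of that arithmetic follows; the concrete numbers (12 total epochs, resume at epoch 4, 500 steps per epoch, batch size 8, 0.4 s per batch) are made up for illustration:

import datetime

# Hypothetical values a resumed run might see.
epoch, start_epoch = 12, 4           # cfg.epoch and LR_Scheduler['last_epoch']
step_per_epoch, iter_id = 500, 100
time_cost = 0.4                       # mean seconds per batch
batch_size = 8

cur_eid = 0 + start_epoch             # first epoch after resuming is epoch 4
eta_sec = ((epoch - cur_eid) * step_per_epoch - iter_id) * time_cost
ips = float(batch_size) / time_cost

print(str(datetime.timedelta(seconds=int(eta_sec))))   # 0:26:00 for these numbers
print(ips)                                              # 20.0 images/sec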