未验证 提交 7a65af0c 编写于 作者: W wangguanzhong 提交者: GitHub

update save load (#1702)

上级 48e21f3c
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......@@ -11,89 +25,124 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from .download import get_weights_path
import logging
logger = logging.getLogger(__name__)
def get_ckpt_path(path):
if path.startswith('http://') or path.startswith('https://'):
env = os.environ
if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
trainer_id = int(env['PADDLE_TRAINER_ID'])
num_trainers = int(env['PADDLE_TRAINERS_NUM'])
if num_trainers <= 1:
path = get_weights_path(path)
else:
from ppdet.utils.download import map_path, WEIGHTS_HOME
weight_path = map_path(path, WEIGHTS_HOME)
lock_path = weight_path + '.lock'
if not os.path.exists(weight_path):
try:
os.makedirs(os.path.dirname(weight_path))
except OSError as e:
if e.errno != errno.EEXIST:
raise
with open(lock_path, 'w'): # touch
os.utime(lock_path, None)
if trainer_id == 0:
get_weights_path(path)
os.remove(lock_path)
else:
while os.path.exists(lock_path):
time.sleep(1)
path = weight_path
else:
def is_url(path):
    """
    Tell whether ``path`` is an HTTP(S) URL.

    Args:
        path (string): candidate string, either a URL or something else.

    Returns:
        bool: True when ``path`` starts with ``http://`` or ``https://``.
    """
    url_schemes = ('http://', 'https://')
    return path.startswith(url_schemes)
def get_weight_path(path):
    """
    Resolve ``path`` to a local weight path via ``get_weights_path``.

    Without the ``PADDLE_TRAINERS_NUM`` / ``PADDLE_TRAINER_ID`` environment
    variables, or with at most one trainer, the result of
    ``get_weights_path(path)`` is returned directly.

    With several trainers, only trainer 0 calls ``get_weights_path``. The
    other trainers wait on a ``<weight_path>.lock`` file until trainer 0
    removes it, then everybody returns ``map_path(path, WEIGHTS_HOME)``.

    Args:
        path (string): weight location understood by ``get_weights_path``
            (callers in this file pass a URL).

    Returns:
        string: local path of the weights.

    Raises:
        OSError: if the target directory cannot be created for a reason
            other than already existing.
    """
    env = os.environ
    has_trainer_env = ('PADDLE_TRAINERS_NUM' in env and
                       'PADDLE_TRAINER_ID' in env)
    if not has_trainer_env:
        return get_weights_path(path)

    trainer_id = int(env['PADDLE_TRAINER_ID'])
    num_trainers = int(env['PADDLE_TRAINERS_NUM'])
    if num_trainers <= 1:
        return get_weights_path(path)

    from ppdet.utils.download import map_path, WEIGHTS_HOME
    weight_path = map_path(path, WEIGHTS_HOME)
    lock_path = weight_path + '.lock'
    if not os.path.exists(weight_path):
        try:
            os.makedirs(os.path.dirname(weight_path))
        except OSError as e:
            # Another trainer may have created the directory first.
            if e.errno != errno.EEXIST:
                raise
        with open(lock_path, 'w'):  # touch
            os.utime(lock_path, None)
        if trainer_id == 0:
            try:
                get_weights_path(path)
            finally:
                # Always release the lock, even when fetching fails;
                # otherwise the other trainers would poll it forever.
                os.remove(lock_path)
        else:
            while os.path.exists(lock_path):
                time.sleep(1)
    return weight_path
def _strip_postfix(path):
path, ext = os.path.splitext(path)
assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
"Unknown postfix {} from weights".format(ext)
return path
def load_dygraph_ckpt(model,
optimizer=None,
pretrain_ckpt=None,
ckpt=None,
ckpt_type=None,
exclude_params=[],
load_static_weights=False):
def load_weight(model, weight, optimizer=None):
    """
    Load parameters (and optionally optimizer state) into ``model``.

    Args:
        model: object exposing ``set_dict``.
        weight (string): URL or local path of the weights; a known
            extension is stripped before ``.pdparams`` / ``.pdopt`` are
            appended.
        optimizer: optional object exposing ``set_state_dict``. Its state
            is restored only when the matching ``.pdopt`` file exists.

    Raises:
        ValueError: if the ``.pdparams`` file does not exist.
    """
    local_weight = get_weight_path(weight) if is_url(weight) else weight
    prefix = _strip_postfix(local_weight)

    params_file = prefix + '.pdparams'
    if not os.path.exists(params_file):
        raise ValueError(
            "Model pretrain path {} does not exists.".format(params_file))
    model.set_dict(paddle.load(params_file))

    optim_file = prefix + '.pdopt'
    if optimizer is None or not os.path.exists(optim_file):
        return
    optimizer.set_state_dict(paddle.load(optim_file))
def load_pretrain_weight(model,
                         pretrain_weight,
                         load_static_weights=False,
                         weight_type='pretrain'):
    """
    Initialise ``model`` from pretrained or finetune weights.

    Args:
        model: object exposing ``state_dict``, ``set_dict`` and, for
            ``weight_type='pretrain'``, a ``backbone`` with ``set_dict``.
        pretrain_weight (string): URL or local path of the weights.
        load_static_weights (bool): when True, read the weights with
            ``paddle.static.load_program_state`` and match them to the
            model by each parameter's ``name`` attribute.
        weight_type (string): ``'pretrain'`` loads the weights into
            ``model.backbone``; ``'finetune'`` loads them into the whole
            model after dropping entries whose shape does not match.

    Raises:
        ValueError: if no directory, file or ``.pdparams`` file exists at
            the resolved path.
    """
    assert weight_type in ['pretrain', 'finetune']
    if is_url(pretrain_weight):
        pretrain_weight = get_weight_path(pretrain_weight)

    path = _strip_postfix(pretrain_weight)
    if not (os.path.isdir(path) or os.path.isfile(path) or
            os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {} does not "
                         "exists.".format(path))

    model_dict = model.state_dict()

    if load_static_weights:
        pre_state_dict = paddle.static.load_program_state(path)
        param_state_dict = {}
        for key, param in model_dict.items():
            weight_name = param.name
            if weight_name in pre_state_dict:
                logger.info('Load weight: {}, shape: {}'.format(
                    weight_name, pre_state_dict[weight_name].shape))
                param_state_dict[key] = pre_state_dict[weight_name]
            else:
                # Keep the model's own value for parameters that the
                # static checkpoint does not provide.
                param_state_dict[key] = param
        model.set_dict(param_state_dict)
        return

    param_state_dict = paddle.load(path + '.pdparams')
    if weight_type == 'pretrain':
        model.backbone.set_dict(param_state_dict)
    else:
        # Iterate over (name, value) pairs; iterating the dict itself
        # would only yield the key strings.
        for name, weight in model_dict.items():
            if name not in param_state_dict:
                continue
            # NOTE(review): shapes are normalised to lists so a list/tuple
            # difference between the two containers is not mistaken for a
            # shape mismatch - confirm against the installed paddle version.
            if list(weight.shape) != list(param_state_dict[name].shape):
                param_state_dict.pop(name, None)
        model.set_dict(param_state_dict)
    return
def save_dygraph_ckpt(model, optimizer, save_dir, save_name):
def save_model(model, optimizer, save_dir, save_name):
    """
    Save model parameters and optimizer state under ``save_dir``.

    Writes ``<save_dir>/<save_name>.pdparams`` from ``model.state_dict()``
    and ``<save_dir>/<save_name>.pdopt`` from ``optimizer.state_dict()``.

    Args:
        model: object exposing ``state_dict``.
        optimizer: object exposing ``state_dict``.
        save_dir (string): directory to save into; created when missing.
        save_name (string): file name without extension.

    Raises:
        OSError: if ``save_dir`` cannot be created for a reason other than
            already existing.
    """
    import errno  # local import keeps this block self-contained

    # Create the directory without a separate existence check, so that
    # two processes saving at the same time cannot race each other.
    try:
        os.makedirs(save_dir)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise
    save_path = os.path.join(save_dir, save_name)
    paddle.save(model.state_dict(), save_path + ".pdparams")
    paddle.save(optimizer.state_dict(), save_path + ".pdopt")
    logger.info("Save checkpoint: {}".format(save_dir))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......@@ -19,7 +33,7 @@ from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.utils.eval_utils import get_infer_results, eval_results
from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt
from ppdet.utils.checkpoint import load_weight
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
......@@ -51,7 +65,7 @@ def run(FLAGS, cfg, place):
model = create(cfg.architecture)
# Init Model
model = load_dygraph_ckpt(model, ckpt=cfg.weights)
load_weight(model, cfg.weights)
# Data Reader
dataset = cfg.EvalDataset
......
......@@ -34,7 +34,7 @@ from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.visualizer import visualize_results
from ppdet.utils.cli import ArgsParser
from ppdet.data.reader import create_reader
from ppdet.utils.checkpoint import load_dygraph_ckpt
from ppdet.utils.checkpoint import load_weight
from ppdet.utils.eval_utils import get_infer_results
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
......@@ -141,7 +141,7 @@ def run(FLAGS, cfg):
use_default_label)
# Init Model
model = load_dygraph_ckpt(model, ckpt=cfg.weights)
load_weight(model, cfg.weights)
# Data Reader
test_reader = create_reader(cfg.TestDataset, cfg.TestReader)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......@@ -21,7 +35,7 @@ from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.stats import TrainingStats
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight, save_model
from paddle.distributed import ParallelEnv
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
......@@ -32,7 +46,7 @@ logger = logging.getLogger(__name__)
def parse_args():
parser = ArgsParser()
parser.add_argument(
"-ckpt_type",
"--weight_type",
default='pretrain',
type=str,
help="Loading Checkpoints only support 'pretrain', 'finetune', 'resume'."
......@@ -116,12 +130,12 @@ def run(FLAGS, cfg, place):
optimizer = create('OptimizerBuilder')(lr, model.parameters())
# Init Model & Optimzer
model = load_dygraph_ckpt(
model,
optimizer,
cfg.pretrain_weights,
ckpt_type=FLAGS.ckpt_type,
load_static_weights=cfg.get('load_static_weights', False))
if FLAGS.weight_type == 'resume':
load_weight(model, cfg.pretrain_weights, optimizer)
else:
load_pretrain_weight(model, cfg.pretrain_weights,
cfg.get('load_static_weights', False),
FLAGS.weight_type)
# Parallel Model
if ParallelEnv().nranks > 1:
......@@ -132,13 +146,17 @@ def run(FLAGS, cfg, place):
time_stat = deque(maxlen=cfg.log_iter)
start_time = time.time()
end_time = time.time()
# Run Train
start_epoch = optimizer.state_dict()['LR_Scheduler']['last_epoch']
for e_id in range(int(cfg.epoch)):
cur_eid = e_id + start_epoch
for iter_id, data in enumerate(train_loader):
start_time = end_time
end_time = time.time()
time_stat.append(end_time - start_time)
time_cost = np.mean(time_stat)
eta_sec = (cfg.epoch * step_per_epoch - iter_id) * time_cost
eta_sec = (
(cfg.epoch - cur_eid) * step_per_epoch - iter_id) * time_cost
eta = str(datetime.timedelta(seconds=int(eta_sec)))
# Model Forward
......@@ -162,22 +180,23 @@ def run(FLAGS, cfg, place):
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
# Log state
if iter_id == 0:
if e_id == 0 and iter_id == 0:
train_stats = TrainingStats(cfg.log_iter, outputs.keys())
train_stats.update(outputs)
logs = train_stats.log()
if iter_id % cfg.log_iter == 0:
strs = 'Epoch:{}: iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
e_id, iter_id, curr_lr, logs, time_cost, eta)
ips = float(cfg['TrainReader']['batch_size']) / time_cost
strs = 'Epoch:{}: iter: {}, lr: {:.6f}, {}, eta: {}, batch_cost: {:.5f} sec, ips: {:.5f} images/sec'.format(
cur_eid, iter_id, curr_lr, logs, eta, time_cost, ips)
logger.info(strs)
# Save Stage
if ParallelEnv().local_rank == 0 and e_id % cfg.snapshot_epoch == 0:
if ParallelEnv().local_rank == 0 and cur_eid % cfg.snapshot_epoch == 0:
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_name = str(e_id + 1) if e_id + 1 != int(
save_name = str(cur_eid) if cur_eid + 1 != int(
cfg.epoch) else "model_final"
save_dir = os.path.join(cfg.save_dir, cfg_name)
save_dygraph_ckpt(model, optimizer, save_dir, save_name)
save_model(model, optimizer, save_dir, save_name)
def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册