未验证 提交 ae4167dc 编写于 作者: Z zhoujun 提交者: GitHub

merge init_model and load_dygraph_params to load_model (#4623)

* merge init_model and load_dygraph_params to load_model
上级 1417a3c2
......@@ -30,7 +30,7 @@ from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
import tools.program as program
......@@ -89,7 +89,7 @@ def main(config, device, logger, vdl_writer):
logger.info(f"FLOPs after pruning: {flops}")
# load pretrain model
pre_best_model_dict = init_model(config, model, logger, None)
load_model(config, model)
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class)
logger.info(f"metric['hmean']: {metric['hmean']}")
......
......@@ -32,7 +32,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
import tools.program as program
dist.get_world_size()
......@@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer)
pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader)))
......
......@@ -28,7 +28,7 @@ from paddle.jit import to_static
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
from ppocr.metrics import build_metric
......@@ -101,7 +101,7 @@ def main():
quanter = QAT(config=quant_config)
quanter.quantize(model)
init_model(config, model)
load_model(config, model)
model.eval()
# build metric
......
......@@ -37,7 +37,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
import tools.program as program
from paddleslim.dygraph.quant import QAT
......@@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer)
pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader)))
......
......@@ -37,7 +37,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
import tools.program as program
import paddleslim
from paddleslim.dygraph.quant import QAT
......
......@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
from .base_model import BaseModel
from ppocr.utils.save_load import init_model, load_pretrained_params
from ppocr.utils.save_load import load_pretrained_params
__all__ = ['DistillationModel']
......
......@@ -25,7 +25,7 @@ import paddle
from ppocr.utils.logging import get_logger
__all__ = ['init_model', 'save_model', 'load_dygraph_params']
__all__ = ['load_model']
def _mkdir_if_not_exist(path, logger):
......@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path))
def init_model(config, model, optimizer=None, lr_scheduler=None):
def load_model(config, model, optimizer=None):
"""
load model from checkpoint or pretrained_model
"""
......@@ -54,15 +54,14 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
pretrained_model = global_config.get('pretrained_model')
best_model_dict = {}
if checkpoints:
assert os.path.exists(checkpoints + ".pdparams"), \
"Given dir {}.pdparams not exist.".format(checkpoints)
if checkpoints.endswith('pdparams'):
checkpoints = checkpoints.replace('.pdparams', '')
assert os.path.exists(checkpoints + ".pdopt"), \
"Given dir {}.pdopt not exist.".format(checkpoints)
para_dict = paddle.load(checkpoints + '.pdparams')
opti_dict = paddle.load(checkpoints + '.pdopt')
model.set_state_dict(para_dict)
f"The {checkpoints}.pdopt does not exists!"
load_pretrained_params(model, checkpoints)
optim_dict = paddle.load(checkpoints + '.pdopt')
if optimizer is not None:
optimizer.set_state_dict(opti_dict)
optimizer.set_state_dict(optim_dict)
if os.path.exists(checkpoints + '.states'):
with open(checkpoints + '.states', 'rb') as f:
......@@ -73,70 +72,31 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model]
for pretrained in pretrained_model:
if not (os.path.isdir(pretrained) or
os.path.exists(pretrained + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(pretrained))
param_state_dict = paddle.load(pretrained + '.pdparams')
model.set_state_dict(param_state_dict)
logger.info("load pretrained model from {}".format(
pretrained_model))
load_pretrained_params(model, pretrained_model)
else:
logger.info('train from scratch')
return best_model_dict
def load_dygraph_params(config, model, logger, optimizer):
ckp = config['Global']['checkpoints']
if ckp and os.path.exists(ckp + ".pdparams"):
pre_best_model_dict = init_model(config, model, optimizer)
return pre_best_model_dict
else:
pm = config['Global']['pretrained_model']
if pm is None:
return {}
if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
logger.info(f"The pretrained_model {pm} does not exists!")
return {}
pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
params = paddle.load(pm)
state_dict = model.state_dict()
new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
model.set_state_dict(new_state_dict)
logger.info(f"loaded pretrained_model successful from {pm}")
return {}
def load_pretrained_params(model, path):
if path is None:
return False
if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
print(f"The pretrained_model {path} does not exists!")
return False
path = path if path.endswith('.pdparams') else path + '.pdparams'
params = paddle.load(path)
logger = get_logger()
if path.endswith('pdparams'):
path = path.replace('.pdparams', '')
assert os.path.exists(path + ".pdparams"), \
f"The {path}.pdparams does not exists!"
params = paddle.load(path + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
print(
logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
model.set_state_dict(new_state_dict)
print(f"load pretrain successful from {path}")
logger.info(f"load pretrain successful from {path}")
return model
......
......@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model, load_dygraph_params
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import print_dict
import tools.program as program
......@@ -60,7 +60,7 @@ def main():
else:
model_type = None
best_model_dict = load_dygraph_params(config, model, logger, None)
best_model_dict = load_model(config, model)
if len(best_model_dict):
logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items():
......
......@@ -27,7 +27,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model, load_dygraph_params
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import print_dict
import tools.program as program
......@@ -57,7 +57,7 @@ def main():
model = build_model(config['Architecture'])
best_model_dict = load_dygraph_params(config, model, logger, None)
best_model_dict = load_model(config, model)
if len(best_model_dict):
logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items():
......
......@@ -26,7 +26,7 @@ from paddle.jit import to_static
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
......@@ -107,7 +107,7 @@ def main():
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"])
init_model(config, model)
load_model(config, model)
model.eval()
save_path = config["Global"]["save_inference_dir"]
......
......@@ -32,7 +32,7 @@ import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
......@@ -47,7 +47,7 @@ def main():
# build model
model = build_model(config['Architecture'])
init_model(config, model)
load_model(config, model)
# create data ops
transforms = []
......
......@@ -34,7 +34,7 @@ import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model, load_dygraph_params
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
......@@ -59,7 +59,7 @@ def main():
# build model
model = build_model(config['Architecture'])
_ = load_dygraph_params(config, model, logger, None)
load_model(config, model)
# build post process
post_process_class = build_post_process(config['PostProcess'])
......
......@@ -34,7 +34,7 @@ import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
......@@ -68,7 +68,7 @@ def main():
# build model
model = build_model(config['Architecture'])
init_model(config, model)
load_model(config, model)
# build post process
post_process_class = build_post_process(config['PostProcess'],
......
......@@ -33,7 +33,7 @@ import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
......@@ -58,7 +58,7 @@ def main():
model = build_model(config['Architecture'])
init_model(config, model)
load_model(config, model)
# create data ops
transforms = []
......@@ -75,9 +75,7 @@ def main():
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
]
elif config['Architecture']['algorithm'] == "SAR":
op[op_name]['keep_keys'] = [
'image', 'valid_ratio'
]
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
else:
op[op_name]['keep_keys'] = ['image']
transforms.append(op)
......
......@@ -34,11 +34,12 @@ from paddle.jit import to_static
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
import cv2
def main(config, device, logger, vdl_writer):
global_config = config['Global']
......@@ -53,7 +54,7 @@ def main(config, device, logger, vdl_writer):
model = build_model(config['Architecture'])
init_model(config, model, logger)
load_model(config, model)
# create data ops
transforms = []
......@@ -104,4 +105,3 @@ def main(config, device, logger, vdl_writer):
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess()
main(config, device, logger, vdl_writer)
......@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model, load_dygraph_params
from ppocr.utils.save_load import load_model
import tools.program as program
dist.get_world_size()
......@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册