未验证 提交 ae4167dc 编写于 作者: Z zhoujun 提交者: GitHub

merge init_model and load_dygraph_params to load_model (#4623)

* merge init_model and load_dygraph_params to load_model
上级 1417a3c2
...@@ -30,7 +30,7 @@ from ppocr.modeling.architectures import build_model ...@@ -30,7 +30,7 @@ from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
import tools.program as program import tools.program as program
...@@ -89,7 +89,7 @@ def main(config, device, logger, vdl_writer): ...@@ -89,7 +89,7 @@ def main(config, device, logger, vdl_writer):
logger.info(f"FLOPs after pruning: {flops}") logger.info(f"FLOPs after pruning: {flops}")
# load pretrain model # load pretrain model
pre_best_model_dict = init_model(config, model, logger, None) load_model(config, model)
metric = program.eval(model, valid_dataloader, post_process_class, metric = program.eval(model, valid_dataloader, post_process_class,
eval_class) eval_class)
logger.info(f"metric['hmean']: {metric['hmean']}") logger.info(f"metric['hmean']: {metric['hmean']}")
......
...@@ -32,7 +32,7 @@ from ppocr.losses import build_loss ...@@ -32,7 +32,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
import tools.program as program import tools.program as program
dist.get_world_size() dist.get_world_size()
...@@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer): ...@@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer) pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'. logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader))) format(len(train_dataloader), len(valid_dataloader)))
......
...@@ -28,7 +28,7 @@ from paddle.jit import to_static ...@@ -28,7 +28,7 @@ from paddle.jit import to_static
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser from tools.program import load_config, merge_config, ArgsParser
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
...@@ -101,7 +101,7 @@ def main(): ...@@ -101,7 +101,7 @@ def main():
quanter = QAT(config=quant_config) quanter = QAT(config=quant_config)
quanter.quantize(model) quanter.quantize(model)
init_model(config, model) load_model(config, model)
model.eval() model.eval()
# build metric # build metric
......
...@@ -37,7 +37,7 @@ from ppocr.losses import build_loss ...@@ -37,7 +37,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
import tools.program as program import tools.program as program
from paddleslim.dygraph.quant import QAT from paddleslim.dygraph.quant import QAT
...@@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer): ...@@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer) pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'. logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader))) format(len(train_dataloader), len(valid_dataloader)))
......
...@@ -37,7 +37,7 @@ from ppocr.losses import build_loss ...@@ -37,7 +37,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
import tools.program as program import tools.program as program
import paddleslim import paddleslim
from paddleslim.dygraph.quant import QAT from paddleslim.dygraph.quant import QAT
......
...@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone ...@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head from ppocr.modeling.heads import build_head
from .base_model import BaseModel from .base_model import BaseModel
from ppocr.utils.save_load import init_model, load_pretrained_params from ppocr.utils.save_load import load_pretrained_params
__all__ = ['DistillationModel'] __all__ = ['DistillationModel']
......
...@@ -25,7 +25,7 @@ import paddle ...@@ -25,7 +25,7 @@ import paddle
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
__all__ = ['init_model', 'save_model', 'load_dygraph_params'] __all__ = ['load_model']
def _mkdir_if_not_exist(path, logger): def _mkdir_if_not_exist(path, logger):
...@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger): ...@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path)) raise OSError('Failed to mkdir {}'.format(path))
def init_model(config, model, optimizer=None, lr_scheduler=None): def load_model(config, model, optimizer=None):
""" """
load model from checkpoint or pretrained_model load model from checkpoint or pretrained_model
""" """
...@@ -54,15 +54,14 @@ def init_model(config, model, optimizer=None, lr_scheduler=None): ...@@ -54,15 +54,14 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
pretrained_model = global_config.get('pretrained_model') pretrained_model = global_config.get('pretrained_model')
best_model_dict = {} best_model_dict = {}
if checkpoints: if checkpoints:
assert os.path.exists(checkpoints + ".pdparams"), \ if checkpoints.endswith('pdparams'):
"Given dir {}.pdparams not exist.".format(checkpoints) checkpoints = checkpoints.replace('.pdparams', '')
assert os.path.exists(checkpoints + ".pdopt"), \ assert os.path.exists(checkpoints + ".pdopt"), \
"Given dir {}.pdopt not exist.".format(checkpoints) f"The {checkpoints}.pdopt does not exists!"
para_dict = paddle.load(checkpoints + '.pdparams') load_pretrained_params(model, checkpoints)
opti_dict = paddle.load(checkpoints + '.pdopt') optim_dict = paddle.load(checkpoints + '.pdopt')
model.set_state_dict(para_dict)
if optimizer is not None: if optimizer is not None:
optimizer.set_state_dict(opti_dict) optimizer.set_state_dict(optim_dict)
if os.path.exists(checkpoints + '.states'): if os.path.exists(checkpoints + '.states'):
with open(checkpoints + '.states', 'rb') as f: with open(checkpoints + '.states', 'rb') as f:
...@@ -73,70 +72,31 @@ def init_model(config, model, optimizer=None, lr_scheduler=None): ...@@ -73,70 +72,31 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
best_model_dict['start_epoch'] = states_dict['epoch'] + 1 best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints)) logger.info("resume from {}".format(checkpoints))
elif pretrained_model: elif pretrained_model:
if not isinstance(pretrained_model, list): load_pretrained_params(model, pretrained_model)
pretrained_model = [pretrained_model]
for pretrained in pretrained_model:
if not (os.path.isdir(pretrained) or
os.path.exists(pretrained + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(pretrained))
param_state_dict = paddle.load(pretrained + '.pdparams')
model.set_state_dict(param_state_dict)
logger.info("load pretrained model from {}".format(
pretrained_model))
else: else:
logger.info('train from scratch') logger.info('train from scratch')
return best_model_dict return best_model_dict
def load_dygraph_params(config, model, logger, optimizer):
ckp = config['Global']['checkpoints']
if ckp and os.path.exists(ckp + ".pdparams"):
pre_best_model_dict = init_model(config, model, optimizer)
return pre_best_model_dict
else:
pm = config['Global']['pretrained_model']
if pm is None:
return {}
if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
logger.info(f"The pretrained_model {pm} does not exists!")
return {}
pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
params = paddle.load(pm)
state_dict = model.state_dict()
new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
model.set_state_dict(new_state_dict)
logger.info(f"loaded pretrained_model successful from {pm}")
return {}
def load_pretrained_params(model, path): def load_pretrained_params(model, path):
if path is None: logger = get_logger()
return False if path.endswith('pdparams'):
if not os.path.exists(path) and not os.path.exists(path + ".pdparams"): path = path.replace('.pdparams', '')
print(f"The pretrained_model {path} does not exists!") assert os.path.exists(path + ".pdparams"), \
return False f"The {path}.pdparams does not exists!"
path = path if path.endswith('.pdparams') else path + '.pdparams' params = paddle.load(path + '.pdparams')
params = paddle.load(path)
state_dict = model.state_dict() state_dict = model.state_dict()
new_state_dict = {} new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()): for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape): if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2] new_state_dict[k1] = params[k2]
else: else:
print( logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
) )
model.set_state_dict(new_state_dict) model.set_state_dict(new_state_dict)
print(f"load pretrain successful from {path}") logger.info(f"load pretrain successful from {path}")
return model return model
......
...@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader ...@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model, load_dygraph_params from ppocr.utils.save_load import load_model
from ppocr.utils.utility import print_dict from ppocr.utils.utility import print_dict
import tools.program as program import tools.program as program
...@@ -60,7 +60,7 @@ def main(): ...@@ -60,7 +60,7 @@ def main():
else: else:
model_type = None model_type = None
best_model_dict = load_dygraph_params(config, model, logger, None) best_model_dict = load_model(config, model)
if len(best_model_dict): if len(best_model_dict):
logger.info('metric in ckpt ***************') logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items(): for k, v in best_model_dict.items():
......
...@@ -27,7 +27,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) ...@@ -27,7 +27,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model, load_dygraph_params from ppocr.utils.save_load import load_model
from ppocr.utils.utility import print_dict from ppocr.utils.utility import print_dict
import tools.program as program import tools.program as program
...@@ -57,7 +57,7 @@ def main(): ...@@ -57,7 +57,7 @@ def main():
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
best_model_dict = load_dygraph_params(config, model, logger, None) best_model_dict = load_model(config, model)
if len(best_model_dict): if len(best_model_dict):
logger.info('metric in ckpt ***************') logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items(): for k, v in best_model_dict.items():
......
...@@ -26,7 +26,7 @@ from paddle.jit import to_static ...@@ -26,7 +26,7 @@ from paddle.jit import to_static
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser from tools.program import load_config, merge_config, ArgsParser
...@@ -107,7 +107,7 @@ def main(): ...@@ -107,7 +107,7 @@ def main():
else: # base rec model else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"]) model = build_model(config["Architecture"])
init_model(config, model) load_model(config, model)
model.eval() model.eval()
save_path = config["Global"]["save_inference_dir"] save_path = config["Global"]["save_inference_dir"]
......
...@@ -32,7 +32,7 @@ import paddle ...@@ -32,7 +32,7 @@ import paddle
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
import tools.program as program import tools.program as program
...@@ -47,7 +47,7 @@ def main(): ...@@ -47,7 +47,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model) load_model(config, model)
# create data ops # create data ops
transforms = [] transforms = []
......
...@@ -34,7 +34,7 @@ import paddle ...@@ -34,7 +34,7 @@ import paddle
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model, load_dygraph_params from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
import tools.program as program import tools.program as program
...@@ -59,7 +59,7 @@ def main(): ...@@ -59,7 +59,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
_ = load_dygraph_params(config, model, logger, None) load_model(config, model)
# build post process # build post process
post_process_class = build_post_process(config['PostProcess']) post_process_class = build_post_process(config['PostProcess'])
......
...@@ -34,7 +34,7 @@ import paddle ...@@ -34,7 +34,7 @@ import paddle
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
import tools.program as program import tools.program as program
...@@ -68,7 +68,7 @@ def main(): ...@@ -68,7 +68,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model) load_model(config, model)
# build post process # build post process
post_process_class = build_post_process(config['PostProcess'], post_process_class = build_post_process(config['PostProcess'],
......
...@@ -33,7 +33,7 @@ import paddle ...@@ -33,7 +33,7 @@ import paddle
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
import tools.program as program import tools.program as program
...@@ -58,7 +58,7 @@ def main(): ...@@ -58,7 +58,7 @@ def main():
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model) load_model(config, model)
# create data ops # create data ops
transforms = [] transforms = []
...@@ -75,9 +75,7 @@ def main(): ...@@ -75,9 +75,7 @@ def main():
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2' 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
] ]
elif config['Architecture']['algorithm'] == "SAR": elif config['Architecture']['algorithm'] == "SAR":
op[op_name]['keep_keys'] = [ op[op_name]['keep_keys'] = ['image', 'valid_ratio']
'image', 'valid_ratio'
]
else: else:
op[op_name]['keep_keys'] = ['image'] op[op_name]['keep_keys'] = ['image']
transforms.append(op) transforms.append(op)
......
...@@ -34,11 +34,12 @@ from paddle.jit import to_static ...@@ -34,11 +34,12 @@ from paddle.jit import to_static
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
import tools.program as program import tools.program as program
import cv2 import cv2
def main(config, device, logger, vdl_writer): def main(config, device, logger, vdl_writer):
global_config = config['Global'] global_config = config['Global']
...@@ -53,7 +54,7 @@ def main(config, device, logger, vdl_writer): ...@@ -53,7 +54,7 @@ def main(config, device, logger, vdl_writer):
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) load_model(config, model)
# create data ops # create data ops
transforms = [] transforms = []
...@@ -104,4 +105,3 @@ def main(config, device, logger, vdl_writer): ...@@ -104,4 +105,3 @@ def main(config, device, logger, vdl_writer):
if __name__ == '__main__': if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess() config, device, logger, vdl_writer = program.preprocess()
main(config, device, logger, vdl_writer) main(config, device, logger, vdl_writer)
...@@ -35,7 +35,7 @@ from ppocr.losses import build_loss ...@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model, load_dygraph_params from ppocr.utils.save_load import load_model
import tools.program as program import tools.program as program
dist.get_world_size() dist.get_world_size()
...@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer): ...@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer) pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters'.format(len(train_dataloader))) logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None: if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format( logger.info('valid dataloader has {} iters'.format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册