提交 48d85379 编写于 作者: littletomatodonkey's avatar littletomatodonkey

rm load_dyg_pretrain

上级 bd1820b7
...@@ -8,9 +8,9 @@ Global: ...@@ -8,9 +8,9 @@ Global:
save_epoch_step: 3 save_epoch_step: 3
eval_batch_step: [0, 2000] eval_batch_step: [0, 2000]
cal_metric_during_train: true cal_metric_during_train: true
pretrained_model: null pretrained_model:
checkpoints: null checkpoints:
save_inference_dir: null save_inference_dir:
use_visualdl: false use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt character_dict_path: ppocr/utils/ppocr_keys_v1.txt
...@@ -38,7 +38,7 @@ Architecture: ...@@ -38,7 +38,7 @@ Architecture:
algorithm: Distillation algorithm: Distillation
Models: Models:
Student: Student:
pretrained: null pretrained:
freeze_params: false freeze_params: false
return_all_feats: true return_all_feats: true
model_type: rec model_type: rec
...@@ -57,7 +57,7 @@ Architecture: ...@@ -57,7 +57,7 @@ Architecture:
name: CTCHead name: CTCHead
fc_decay: 0.00001 fc_decay: 0.00001
Teacher: Teacher:
pretrained: null pretrained:
freeze_params: false freeze_params: false
return_all_feats: true return_all_feats: true
model_type: rec model_type: rec
...@@ -118,8 +118,8 @@ Train: ...@@ -118,8 +118,8 @@ Train:
- DecodeImage: - DecodeImage:
img_mode: BGR img_mode: BGR
channel_first: false channel_first: false
- RecAug: null - RecAug:
- CTCLabelEncode: null - CTCLabelEncode:
- RecResizeImg: - RecResizeImg:
image_shape: [3, 32, 320] image_shape: [3, 32, 320]
- KeepKeys: - KeepKeys:
...@@ -143,7 +143,7 @@ Eval: ...@@ -143,7 +143,7 @@ Eval:
- DecodeImage: - DecodeImage:
img_mode: BGR img_mode: BGR
channel_first: false channel_first: false
- CTCLabelEncode: null - CTCLabelEncode:
- RecResizeImg: - RecResizeImg:
image_shape: [3, 32, 320] image_shape: [3, 32, 320]
- KeepKeys: - KeepKeys:
......
...@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone ...@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head from ppocr.modeling.heads import build_head
from .base_model import BaseModel from .base_model import BaseModel
from ppocr.utils.save_load import load_dygraph_pretrain from ppocr.utils.save_load import init_model
__all__ = ['DistillationModel'] __all__ = ['DistillationModel']
...@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer): ...@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
pretrained = model_config.pop("pretrained") pretrained = model_config.pop("pretrained")
model = BaseModel(model_config) model = BaseModel(model_config)
if pretrained is not None: if pretrained is not None:
load_dygraph_pretrain(model, path=pretrained) init_model(model, path=pretrained)
if freeze_params: if freeze_params:
for param in model.parameters(): for param in model.parameters():
param.trainable = False param.trainable = False
......
...@@ -23,6 +23,8 @@ import six ...@@ -23,6 +23,8 @@ import six
import paddle import paddle
from ppocr.utils.logging import get_logger
__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain'] __all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
...@@ -42,19 +44,11 @@ def _mkdir_if_not_exist(path, logger): ...@@ -42,19 +44,11 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path)) raise OSError('Failed to mkdir {}'.format(path))
def load_dygraph_pretrain(model, logger=None, path=None): def init_model(config, model, optimizer=None, lr_scheduler=None):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
param_state_dict = paddle.load(path + '.pdparams')
model.set_state_dict(param_state_dict)
return
def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
""" """
load model from checkpoint or pretrained_model load model from checkpoint or pretrained_model
""" """
logger = get_logger()
global_config = config['Global'] global_config = config['Global']
checkpoints = global_config.get('checkpoints') checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model') pretrained_model = global_config.get('pretrained_model')
...@@ -77,13 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None): ...@@ -77,13 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
best_model_dict = states_dict.get('best_model_dict', {}) best_model_dict = states_dict.get('best_model_dict', {})
if 'epoch' in states_dict: if 'epoch' in states_dict:
best_model_dict['start_epoch'] = states_dict['epoch'] + 1 best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints)) logger.info("resume from {}".format(checkpoints))
elif pretrained_model: elif pretrained_model:
if not isinstance(pretrained_model, list): if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model] pretrained_model = [pretrained_model]
for pretrained in pretrained_model: for pretrained in pretrained_model:
load_dygraph_pretrain(model, logger, path=pretrained) if not (os.path.isdir(pretrained) or
os.path.exists(pretrained + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(pretrained))
param_state_dict = paddle.load(pretrained + '.pdparams')
model.set_state_dict(param_state_dict)
logger.info("load pretrained model from {}".format( logger.info("load pretrained model from {}".format(
pretrained_model)) pretrained_model))
else: else:
......
...@@ -49,7 +49,7 @@ def main(): ...@@ -49,7 +49,7 @@ def main():
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
best_model_dict = init_model(config, model, logger) best_model_dict = init_model(config, model)
if len(best_model_dict): if len(best_model_dict):
logger.info('metric in ckpt ***************') logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items(): for k, v in best_model_dict.items():
......
...@@ -95,7 +95,7 @@ def main(): ...@@ -95,7 +95,7 @@ def main():
else: # base rec model else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"]) model = build_model(config["Architecture"])
init_model(config, model, logger) init_model(config, model)
model.eval() model.eval()
save_path = config["Global"]["save_inference_dir"] save_path = config["Global"]["save_inference_dir"]
......
...@@ -47,7 +47,7 @@ def main(): ...@@ -47,7 +47,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) init_model(config, model)
# create data ops # create data ops
transforms = [] transforms = []
......
...@@ -61,7 +61,7 @@ def main(): ...@@ -61,7 +61,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) init_model(config, model)
# build post process # build post process
post_process_class = build_post_process(config['PostProcess']) post_process_class = build_post_process(config['PostProcess'])
......
...@@ -68,7 +68,7 @@ def main(): ...@@ -68,7 +68,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) init_model(config, model)
# build post process # build post process
post_process_class = build_post_process(config['PostProcess'], post_process_class = build_post_process(config['PostProcess'],
......
...@@ -58,7 +58,7 @@ def main(): ...@@ -58,7 +58,7 @@ def main():
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) init_model(config, model)
# create data ops # create data ops
transforms = [] transforms = []
......
...@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer): ...@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer) pre_best_model_dict = init_model(config, model, optimizer)
logger.info('train dataloader has {} iters'.format(len(train_dataloader))) logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None: if valid_dataloader is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册