diff --git a/dygraph/utils/utils.py b/dygraph/utils/utils.py index 46e204dd2e91f319c788eb43ca50602308ce1954..fa995d27af3f78e97bc06d586fa7bb2ecf439f83 100644 --- a/dygraph/utils/utils.py +++ b/dygraph/utils/utils.py @@ -52,7 +52,11 @@ def load_pretrained_model(model, pretrained_model): logging.info('Load pretrained model from {}'.format(pretrained_model)) if os.path.exists(pretrained_model): ckpt_path = os.path.join(pretrained_model, 'model') - para_state_dict, _ = fluid.load_dygraph(ckpt_path) + try: + para_state_dict, _ = fluid.load_dygraph(ckpt_path) + except: + para_state_dict = fluid.load_program_state(pretrained_model) + model_state_dict = model.state_dict() keys = model_state_dict.keys() num_params_loaded = 0