From 738e3688218200fcdeab441c9f3435447c8072f6 Mon Sep 17 00:00:00 2001
From: Meiyim
Date: Thu, 7 Jan 2021 10:52:54 +0800
Subject: [PATCH] [fix] cannot download model (#610)

* [fix] cannot download model

* [fix] proper logging

* [fix] checkpoint loading

Co-authored-by: chenxuyi
---
 demo/finetune_classifier_distributed.py |  2 +-
 demo/finetune_sentiment_analysis.py     |  8 ++++++--
 ernie/modeling_ernie.py                 |  8 +++++---
 ernie/tokenizing_ernie.py               | 16 +++++++++-------
 4 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/demo/finetune_classifier_distributed.py b/demo/finetune_classifier_distributed.py
index 3dfa2f6..d1df867 100644
--- a/demo/finetune_classifier_distributed.py
+++ b/demo/finetune_classifier_distributed.py
@@ -128,7 +128,7 @@ model = ErnieModelForSequenceClassification.from_pretrained(
 
 if args.init_checkpoint is not None:
     log.info('loading checkpoint from %s' % args.init_checkpoint)
-    sd, _ = P.load(args.init_checkpoint)
+    sd = P.load(args.init_checkpoint)
     model.set_state_dict(sd)
 
 model = P.DataParallel(model)
diff --git a/demo/finetune_sentiment_analysis.py b/demo/finetune_sentiment_analysis.py
index 7946198..015d29d 100644
--- a/demo/finetune_sentiment_analysis.py
+++ b/demo/finetune_sentiment_analysis.py
@@ -34,7 +34,6 @@ import propeller.paddle as propeller
 
 log.setLevel(logging.DEBUG)
 logging.getLogger().setLevel(logging.DEBUG)
-log = logging.getLogger()
 
 #from model.bert import BertConfig, BertModelLayer
 from ernie.modeling_ernie import ErnieModel, ErnieModelForSequenceClassification
@@ -70,6 +69,11 @@ parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
 parser.add_argument('--eval', action='store_true')
 parser.add_argument(
     '--save_dir', type=Path, required=True, help='model output directory')
+parser.add_argument(
+    '--init_checkpoint',
+    type=str,
+    default=None,
+    help='checkpoint to warm start from')
 parser.add_argument(
     '--wd', type=float, default=0.01, help='weight decay, aka L2 regularizer')
 parser.add_argument(
@@ -185,7 +189,7 @@ else:
             tokenizer=tokenizer.tokenize),
     ])
 
-    sd, _ = P.load(args.save_dir / 'ckpt.bin')
+    sd = P.load(args.init_checkpoint)
     model.set_dict(sd)
     model.eval()
 
diff --git a/ernie/modeling_ernie.py b/ernie/modeling_ernie.py
index b8b5e8e..2d0637c 100644
--- a/ernie/modeling_ernie.py
+++ b/ernie/modeling_ernie.py
@@ -230,12 +230,14 @@ class PretrainedModel(object):
                         pretrain_dir_or_url,
                         force_download=False,
                         **kwargs):
-        if not Path(pretrain_dir_or_url).exists() and pretrain_dir_or_url in cls.resource_map:
-            url = cls.resource_map[pretrain_dir_or_url]
+        if not Path(pretrain_dir_or_url).exists() and str(
+                pretrain_dir_or_url) in cls.resource_map:
+            url = cls.resource_map[str(pretrain_dir_or_url)]
             log.info('get pretrain dir from %s' % url)
             pretrain_dir = _fetch_from_remote(url, force_download)
         else:
-            log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
+            log.info('pretrain dir %s not in %s, read from local' %
+                     (pretrain_dir_or_url, repr(cls.resource_map)))
             pretrain_dir = Path(pretrain_dir_or_url)
 
         if not pretrain_dir.exists():
diff --git a/ernie/tokenizing_ernie.py b/ernie/tokenizing_ernie.py
index ffcefd2..b4cc247 100644
--- a/ernie/tokenizing_ernie.py
+++ b/ernie/tokenizing_ernie.py
@@ -94,16 +94,16 @@ class ErnieTokenizer(object):
                         pretrain_dir_or_url,
                         force_download=False,
                         **kwargs):
-        if pretrain_dir_or_url in cls.resource_map:
-            url = cls.resource_map[pretrain_dir_or_url]
+        if not Path(pretrain_dir_or_url).exists() and str(
+                pretrain_dir_or_url) in cls.resource_map:
+            url = cls.resource_map[str(pretrain_dir_or_url)]
             log.info('get pretrain dir from %s' % url)
             pretrain_dir = _fetch_from_remote(
                 url, force_download=force_download)
         else:
             log.info('pretrain dir %s not in %s, read from local' %
                      (pretrain_dir_or_url, repr(cls.resource_map)))
-            pretrain_dir = pretrain_dir_or_url
-        pretrain_dir = Path(pretrain_dir)
+            pretrain_dir = Path(pretrain_dir_or_url)
         if not pretrain_dir.exists():
             raise ValueError('pretrain dir not found: %s' % pretrain_dir)
         vocab_path = pretrain_dir / 'vocab.txt'
@@ -235,12 +235,14 @@ class ErnieTinyTokenizer(ErnieTokenizer):
                         pretrain_dir_or_url,
                         force_download=False,
                         **kwargs):
-        if pretrain_dir_or_url in cls.resource_map:
-            url = cls.resource_map[pretrain_dir_or_url]
+        if not Path(pretrain_dir_or_url).exists() and str(
+                pretrain_dir_or_url) in cls.resource_map:
+            url = cls.resource_map[str(pretrain_dir_or_url)]
             log.info('get pretrain dir from %s' % url)
             pretrain_dir = _fetch_from_remote(url, force_download)
         else:
-            log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
+            log.info('pretrain dir %s not in %s, read from local' %
+                     (pretrain_dir_or_url, repr(cls.resource_map)))
             pretrain_dir = Path(pretrain_dir_or_url)
         if not pretrain_dir.exists():
             raise ValueError('pretrain dir not found: %s' % pretrain_dir)
--
GitLab
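
Why the str(...) cast in this patch matters: the demos pass pretrain_dir_or_url as a
pathlib.Path (their CLI arguments use type=Path), while resource_map is a dict keyed by
plain strings, so the old membership test never matched and the pretrained archive was
never downloaded. A minimal sketch of the lookup miss, using a placeholder map rather
than the repository's real resource_map contents:

    from pathlib import Path

    # Placeholder stand-in for cls.resource_map; the real map lives on the
    # model/tokenizer classes and points at hosted pretrained archives.
    resource_map = {'ernie-1.0': 'https://example.com/ernie-1.0.tar.gz'}

    name = Path('ernie-1.0')           # the demos pass a pathlib.Path, not a str
    print(name in resource_map)        # False: a Path never equals a str dict key
    print(str(name) in resource_map)   # True: the patched code casts to str first

    # The patched guard, as in the diff above: prefer an existing local
    # directory of the same name, otherwise fetch by resource name.
    if not Path(name).exists() and str(name) in resource_map:
        print('would fetch from %s' % resource_map[str(name)])

The companion change from `sd, _ = P.load(...)` to `sd = P.load(...)` tracks the Paddle 2.0
API, where paddle.load returns the saved state dict as a single object rather than the
(parameters, optimizer state) pair returned by the older dygraph loader.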