diff --git a/demo/finetune_classifier_distributed.py b/demo/finetune_classifier_distributed.py
index 3dfa2f680c1a82a032495911d30bbbb7a52e562b..d1df8675369125baf3c1605347cfa0848b0903de 100644
--- a/demo/finetune_classifier_distributed.py
+++ b/demo/finetune_classifier_distributed.py
@@ -128,7 +128,7 @@ model = ErnieModelForSequenceClassification.from_pretrained(
 
 if args.init_checkpoint is not None:
     log.info('loading checkpoint from %s' % args.init_checkpoint)
-    sd, _ = P.load(args.init_checkpoint)
+    sd = P.load(args.init_checkpoint)
     model.set_state_dict(sd)
 
 model = P.DataParallel(model)
diff --git a/demo/finetune_sentiment_analysis.py b/demo/finetune_sentiment_analysis.py
index 794619840926305ba8de55342a77e00cb2cb821d..015d29d4d2fe0acdf34d0032ebdffd3946515e62 100644
--- a/demo/finetune_sentiment_analysis.py
+++ b/demo/finetune_sentiment_analysis.py
@@ -34,7 +34,6 @@ import propeller.paddle as propeller
 
 log.setLevel(logging.DEBUG)
 logging.getLogger().setLevel(logging.DEBUG)
-log = logging.getLogger()
 
 #from model.bert import BertConfig, BertModelLayer
 from ernie.modeling_ernie import ErnieModel, ErnieModelForSequenceClassification
@@ -70,6 +69,11 @@ parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
 parser.add_argument('--eval', action='store_true')
 parser.add_argument(
     '--save_dir', type=Path, required=True, help='model output directory')
+parser.add_argument(
+    '--init_checkpoint',
+    type=str,
+    default=None,
+    help='checkpoint to warm start from')
 parser.add_argument(
     '--wd', type=float, default=0.01, help='weight decay, aka L2 regularizer')
 parser.add_argument(
@@ -185,7 +189,7 @@ else:
             tokenizer=tokenizer.tokenize),
     ])
 
-    sd, _ = P.load(args.save_dir / 'ckpt.bin')
+    sd = P.load(args.init_checkpoint)
     model.set_dict(sd)
     model.eval()
 
diff --git a/ernie/modeling_ernie.py b/ernie/modeling_ernie.py
index b8b5e8ebce75c6e49e97cccf05fadd08bae2d90d..2d0637c88fb147868fa7dc460d540bf4825c8389 100644
--- a/ernie/modeling_ernie.py
+++ b/ernie/modeling_ernie.py
@@ -230,12 +230,14 @@ class PretrainedModel(object):
                         pretrain_dir_or_url,
                         force_download=False,
                         **kwargs):
-        if not Path(pretrain_dir_or_url).exists() and pretrain_dir_or_url in cls.resource_map:
-            url = cls.resource_map[pretrain_dir_or_url]
+        if not Path(pretrain_dir_or_url).exists() and str(
+                pretrain_dir_or_url) in cls.resource_map:
+            url = cls.resource_map[str(pretrain_dir_or_url)]
             log.info('get pretrain dir from %s' % url)
             pretrain_dir = _fetch_from_remote(url, force_download)
         else:
-            log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
+            log.info('pretrain dir %s not in %s, read from local' %
+                     (pretrain_dir_or_url, repr(cls.resource_map)))
             pretrain_dir = Path(pretrain_dir_or_url)
 
         if not pretrain_dir.exists():
diff --git a/ernie/tokenizing_ernie.py b/ernie/tokenizing_ernie.py
index ffcefd2eb943d451c68069e372f0148ceb188515..b4cc247eeba113fc8d62bd63197c3d398922ac90 100644
--- a/ernie/tokenizing_ernie.py
+++ b/ernie/tokenizing_ernie.py
@@ -94,16 +94,16 @@ class ErnieTokenizer(object):
                         pretrain_dir_or_url,
                         force_download=False,
                         **kwargs):
-        if pretrain_dir_or_url in cls.resource_map:
-            url = cls.resource_map[pretrain_dir_or_url]
+        if not Path(pretrain_dir_or_url).exists() and str(
+                pretrain_dir_or_url) in cls.resource_map:
+            url = cls.resource_map[str(pretrain_dir_or_url)]
             log.info('get pretrain dir from %s' % url)
             pretrain_dir = _fetch_from_remote(
                 url, force_download=force_download)
         else:
             log.info('pretrain dir %s not in %s, read from local' %
                      (pretrain_dir_or_url, repr(cls.resource_map)))
-            pretrain_dir = pretrain_dir_or_url
-            pretrain_dir = Path(pretrain_dir)
+            pretrain_dir = Path(pretrain_dir_or_url)
         if not pretrain_dir.exists():
             raise ValueError('pretrain dir not found: %s' % pretrain_dir)
         vocab_path = pretrain_dir / 'vocab.txt'
@@ -235,12 +235,14 @@ class ErnieTinyTokenizer(ErnieTokenizer):
                         pretrain_dir_or_url,
                         force_download=False,
                         **kwargs):
-        if pretrain_dir_or_url in cls.resource_map:
-            url = cls.resource_map[pretrain_dir_or_url]
+        if not Path(pretrain_dir_or_url).exists() and str(
+                pretrain_dir_or_url) in cls.resource_map:
+            url = cls.resource_map[str(pretrain_dir_or_url)]
             log.info('get pretrain dir from %s' % url)
             pretrain_dir = _fetch_from_remote(url, force_download)
         else:
-            log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
+            log.info('pretrain dir %s not in %s, read from local' %
+                     (pretrain_dir_or_url, repr(cls.resource_map)))
             pretrain_dir = Path(pretrain_dir_or_url)
 
         if not pretrain_dir.exists():
             raise ValueError('pretrain dir not found: %s' % pretrain_dir)
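
Note: the same resolution change recurs in PretrainedModel.from_pretrained, ErnieTokenizer.from_pretrained, and ErnieTinyTokenizer.from_pretrained: the argument is normalized with str() before the resource_map lookup, so a pathlib.Path argument (the demos parse path flags with type=Path, as --save_dir above) still matches the map's string keys, and an existing local directory takes precedence over a named download. A minimal standalone sketch of that lookup logic, assuming a placeholder resource_map and a fetch callback standing in for the class attribute and _fetch_from_remote:

    from pathlib import Path

    # Placeholder map; the real one is a class attribute (cls.resource_map)
    # keyed by pretrained-model name. URL here is illustrative only.
    resource_map = {'ernie-1.0': 'https://example.invalid/ernie-1.0.tar.gz'}

    def resolve_pretrain_dir(pretrain_dir_or_url, fetch):
        # str() lets a pathlib.Path argument match the map's string keys;
        # an existing local directory always wins over a named download.
        if not Path(pretrain_dir_or_url).exists() and str(
                pretrain_dir_or_url) in resource_map:
            return fetch(resource_map[str(pretrain_dir_or_url)])
        pretrain_dir = Path(pretrain_dir_or_url)
        if not pretrain_dir.exists():
            raise ValueError('pretrain dir not found: %s' % pretrain_dir)
        return pretrain_dir

With this, resolve_pretrain_dir('ernie-1.0', fetch) and resolve_pretrain_dir(Path('ernie-1.0'), fetch) behave identically, whereas the old `pretrain_dir_or_url in cls.resource_map` test silently missed Path inputs and fell through to the local-directory branch.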