未验证 提交 738e3688 编写于 作者: Meiyim 提交者: GitHub

[fix] cannot download model (#610)

* [fix] cannot download model

* [fix] proper logging

* [fix] checkpoint loading
Co-authored-by: chenxuyi <work@yq01-qianmo-com-255-129-11.yq01.baidu.com>
上级 de4063b5
...@@ -128,7 +128,7 @@ model = ErnieModelForSequenceClassification.from_pretrained( ...@@ -128,7 +128,7 @@ model = ErnieModelForSequenceClassification.from_pretrained(
if args.init_checkpoint is not None: if args.init_checkpoint is not None:
log.info('loading checkpoint from %s' % args.init_checkpoint) log.info('loading checkpoint from %s' % args.init_checkpoint)
sd, _ = P.load(args.init_checkpoint) sd = P.load(args.init_checkpoint)
model.set_state_dict(sd) model.set_state_dict(sd)
model = P.DataParallel(model) model = P.DataParallel(model)
......
...@@ -34,7 +34,6 @@ import propeller.paddle as propeller ...@@ -34,7 +34,6 @@ import propeller.paddle as propeller
log.setLevel(logging.DEBUG) log.setLevel(logging.DEBUG)
logging.getLogger().setLevel(logging.DEBUG) logging.getLogger().setLevel(logging.DEBUG)
log = logging.getLogger()
#from model.bert import BertConfig, BertModelLayer #from model.bert import BertConfig, BertModelLayer
from ernie.modeling_ernie import ErnieModel, ErnieModelForSequenceClassification from ernie.modeling_ernie import ErnieModel, ErnieModelForSequenceClassification
...@@ -70,6 +69,11 @@ parser.add_argument('--lr', type=float, default=5e-5, help='learning rate') ...@@ -70,6 +69,11 @@ parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
parser.add_argument('--eval', action='store_true') parser.add_argument('--eval', action='store_true')
parser.add_argument( parser.add_argument(
'--save_dir', type=Path, required=True, help='model output directory') '--save_dir', type=Path, required=True, help='model output directory')
parser.add_argument(
'--init_checkpoint',
type=str,
default=None,
help='checkpoint to warm start from')
parser.add_argument( parser.add_argument(
'--wd', type=float, default=0.01, help='weight decay, aka L2 regularizer') '--wd', type=float, default=0.01, help='weight decay, aka L2 regularizer')
parser.add_argument( parser.add_argument(
...@@ -185,7 +189,7 @@ else: ...@@ -185,7 +189,7 @@ else:
tokenizer=tokenizer.tokenize), tokenizer=tokenizer.tokenize),
]) ])
sd, _ = P.load(args.save_dir / 'ckpt.bin') sd = P.load(args.init_checkpoint)
model.set_dict(sd) model.set_dict(sd)
model.eval() model.eval()
......
...@@ -230,12 +230,14 @@ class PretrainedModel(object): ...@@ -230,12 +230,14 @@ class PretrainedModel(object):
pretrain_dir_or_url, pretrain_dir_or_url,
force_download=False, force_download=False,
**kwargs): **kwargs):
if not Path(pretrain_dir_or_url).exists() and pretrain_dir_or_url in cls.resource_map: if not Path(pretrain_dir_or_url).exists() and str(
url = cls.resource_map[pretrain_dir_or_url] pretrain_dir_or_url) in cls.resource_map:
url = cls.resource_map[str(pretrain_dir_or_url)]
log.info('get pretrain dir from %s' % url) log.info('get pretrain dir from %s' % url)
pretrain_dir = _fetch_from_remote(url, force_download) pretrain_dir = _fetch_from_remote(url, force_download)
else: else:
log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map))) log.info('pretrain dir %s not in %s, read from local' %
(pretrain_dir_or_url, repr(cls.resource_map)))
pretrain_dir = Path(pretrain_dir_or_url) pretrain_dir = Path(pretrain_dir_or_url)
if not pretrain_dir.exists(): if not pretrain_dir.exists():
......
...@@ -94,16 +94,16 @@ class ErnieTokenizer(object): ...@@ -94,16 +94,16 @@ class ErnieTokenizer(object):
pretrain_dir_or_url, pretrain_dir_or_url,
force_download=False, force_download=False,
**kwargs): **kwargs):
if pretrain_dir_or_url in cls.resource_map: if not Path(pretrain_dir_or_url).exists() and str(
url = cls.resource_map[pretrain_dir_or_url] pretrain_dir_or_url) in cls.resource_map:
url = cls.resource_map[str(pretrain_dir_or_url)]
log.info('get pretrain dir from %s' % url) log.info('get pretrain dir from %s' % url)
pretrain_dir = _fetch_from_remote( pretrain_dir = _fetch_from_remote(
url, force_download=force_download) url, force_download=force_download)
else: else:
log.info('pretrain dir %s not in %s, read from local' % log.info('pretrain dir %s not in %s, read from local' %
(pretrain_dir_or_url, repr(cls.resource_map))) (pretrain_dir_or_url, repr(cls.resource_map)))
pretrain_dir = pretrain_dir_or_url pretrain_dir = Path(pretrain_dir_or_url)
pretrain_dir = Path(pretrain_dir)
if not pretrain_dir.exists(): if not pretrain_dir.exists():
raise ValueError('pretrain dir not found: %s' % pretrain_dir) raise ValueError('pretrain dir not found: %s' % pretrain_dir)
vocab_path = pretrain_dir / 'vocab.txt' vocab_path = pretrain_dir / 'vocab.txt'
...@@ -235,12 +235,14 @@ class ErnieTinyTokenizer(ErnieTokenizer): ...@@ -235,12 +235,14 @@ class ErnieTinyTokenizer(ErnieTokenizer):
pretrain_dir_or_url, pretrain_dir_or_url,
force_download=False, force_download=False,
**kwargs): **kwargs):
if pretrain_dir_or_url in cls.resource_map: if not Path(pretrain_dir_or_url).exists() and str(
url = cls.resource_map[pretrain_dir_or_url] pretrain_dir_or_url) in cls.resource_map:
url = cls.resource_map[str(pretrain_dir_or_url)]
log.info('get pretrain dir from %s' % url) log.info('get pretrain dir from %s' % url)
pretrain_dir = _fetch_from_remote(url, force_download) pretrain_dir = _fetch_from_remote(url, force_download)
else: else:
log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map))) log.info('pretrain dir %s not in %s, read from local' %
(pretrain_dir_or_url, repr(cls.resource_map)))
pretrain_dir = Path(pretrain_dir_or_url) pretrain_dir = Path(pretrain_dir_or_url)
if not pretrain_dir.exists(): if not pretrain_dir.exists():
raise ValueError('pretrain dir not found: %s' % pretrain_dir) raise ValueError('pretrain dir not found: %s' % pretrain_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册