Unverified commit 738e3688, authored by Meiyim, committed by GitHub

[fix] cannot download model (#610)

* [fix] cannot download model

* [fix] proper logging

* [fix] checkpoint loading
Co-authored-by: chenxuyi <work@yq01-qianmo-com-255-129-11.yq01.baidu.com>
Parent de4063b5
@@ -128,7 +128,7 @@ model = ErnieModelForSequenceClassification.from_pretrained(
 if args.init_checkpoint is not None:
     log.info('loading checkpoint from %s' % args.init_checkpoint)
-    sd, _ = P.load(args.init_checkpoint)
+    sd = P.load(args.init_checkpoint)
     model.set_state_dict(sd)
 model = P.DataParallel(model)
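Note on the checkpoint-loading change above: with the Paddle 2.x API (the diff aliases it as `P`, presumably `import paddle as P`), `paddle.load` returns the deserialized object itself rather than a `(params, optimizer)` tuple, so the old tuple unpacking no longer works. A minimal sketch of the new pattern, using a made-up stand-in model and checkpoint path:

```python
import paddle as P

# Stand-in network; the demo builds ErnieModelForSequenceClassification instead.
model = P.nn.Linear(4, 2)

# Round-trip a checkpoint. paddle.load returns the saved object directly
# (here a single parameter state dict), so there is nothing to unpack.
P.save(model.state_dict(), './demo_ckpt.pdparams')
sd = P.load('./demo_ckpt.pdparams')
model.set_state_dict(sd)
```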
@@ -34,7 +34,6 @@ import propeller.paddle as propeller
 log.setLevel(logging.DEBUG)
 logging.getLogger().setLevel(logging.DEBUG)
-log = logging.getLogger()
 #from model.bert import BertConfig, BertModelLayer
 from ernie.modeling_ernie import ErnieModel, ErnieModelForSequenceClassification
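The `-34,7 +34,6` hunk drops exactly one line; given the "proper logging" note in the commit message, the dropped line appears to be the `log = logging.getLogger()` rebinding, which replaces the propeller logger imported earlier with the bare root logger, whose INFO messages are easily lost. A hypothetical illustration of that pitfall:

```python
import logging

# Configured named logger, standing in for `from propeller import log`.
log = logging.getLogger('propeller-demo')
log.setLevel(logging.DEBUG)
log.addHandler(logging.StreamHandler())

log.info('emitted: goes through the handler configured above')

# The rebinding the commit removes: `log` now points at the root logger,
# which keeps its default WARNING level and has no handler here, so the
# INFO record below is silently dropped.
log = logging.getLogger()
log.info('dropped: never shows up')
```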
@@ -70,6 +69,11 @@ parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
 parser.add_argument('--eval', action='store_true')
 parser.add_argument(
     '--save_dir', type=Path, required=True, help='model output directory')
+parser.add_argument(
+    '--init_checkpoint',
+    type=str,
+    default=None,
+    help='checkpoint to warm start from')
 parser.add_argument(
     '--wd', type=float, default=0.01, help='weight decay, aka L2 regularizer')
 parser.add_argument(
@@ -185,7 +189,7 @@ else:
             tokenizer=tokenizer.tokenize),
     ])
-    sd, _ = P.load(args.save_dir / 'ckpt.bin')
+    sd = P.load(args.init_checkpoint)
     model.set_dict(sd)
     model.eval()
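The evaluation path now reads the checkpoint named by the new `--init_checkpoint` flag instead of the hard-coded `args.save_dir / 'ckpt.bin'`. A reduced, hypothetical version of the argument handling (flag names follow the diff; the argument values are invented):

```python
import argparse
from pathlib import Path

# Reduced parser mirroring the flags added/used in the diff.
parser = argparse.ArgumentParser()
parser.add_argument('--eval', action='store_true')
parser.add_argument(
    '--save_dir', type=Path, required=True, help='model output directory')
parser.add_argument(
    '--init_checkpoint',
    type=str,
    default=None,
    help='checkpoint to warm start from')

args = parser.parse_args(
    ['--eval', '--save_dir', './output', '--init_checkpoint', './output/ckpt.bin'])

# Evaluation now loads whatever checkpoint the caller names, rather than
# assuming it lives at args.save_dir / 'ckpt.bin'.
assert args.eval
print(args.init_checkpoint)  # ./output/ckpt.bin
```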
@@ -230,12 +230,14 @@ class PretrainedModel(object):
             pretrain_dir_or_url,
             force_download=False,
             **kwargs):
-        if not Path(pretrain_dir_or_url).exists() and pretrain_dir_or_url in cls.resource_map:
-            url = cls.resource_map[pretrain_dir_or_url]
+        if not Path(pretrain_dir_or_url).exists() and str(
+                pretrain_dir_or_url) in cls.resource_map:
+            url = cls.resource_map[str(pretrain_dir_or_url)]
             log.info('get pretrain dir from %s' % url)
             pretrain_dir = _fetch_from_remote(url, force_download)
         else:
-            log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
+            log.info('pretrain dir %s not in %s, read from local' %
+                     (pretrain_dir_or_url, repr(cls.resource_map)))
             pretrain_dir = Path(pretrain_dir_or_url)
         if not pretrain_dir.exists():
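The hunk above is the core of the download fix: `from_pretrained` only falls back to `resource_map` when the argument does not name an existing local directory, and it coerces the lookup key with `str()` so that a `pathlib.Path` argument still matches the map's string keys. A standalone sketch of that lookup (the function name, map contents, and fetch callback are made up for illustration):

```python
from pathlib import Path

def resolve_pretrain_dir(pretrain_dir_or_url, resource_map, fetch):
    """Prefer an existing local directory; otherwise look the name up in the map."""
    if not Path(pretrain_dir_or_url).exists() and str(
            pretrain_dir_or_url) in resource_map:
        url = resource_map[str(pretrain_dir_or_url)]
        return fetch(url)  # plays the role of _fetch_from_remote(url, force_download)
    pretrain_dir = Path(pretrain_dir_or_url)
    if not pretrain_dir.exists():
        raise ValueError('pretrain dir not found: %s' % pretrain_dir)
    return pretrain_dir

# A model name that is not a local path resolves through the map.
demo_map = {'ernie-1.0': 'https://example.com/ernie-1.0.tar.gz'}
print(resolve_pretrain_dir('ernie-1.0', demo_map,
                           fetch=lambda url: Path('/tmp/ernie-1.0')))
```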
@@ -94,16 +94,16 @@ class ErnieTokenizer(object):
             pretrain_dir_or_url,
             force_download=False,
             **kwargs):
-        if pretrain_dir_or_url in cls.resource_map:
-            url = cls.resource_map[pretrain_dir_or_url]
+        if not Path(pretrain_dir_or_url).exists() and str(
+                pretrain_dir_or_url) in cls.resource_map:
+            url = cls.resource_map[str(pretrain_dir_or_url)]
             log.info('get pretrain dir from %s' % url)
             pretrain_dir = _fetch_from_remote(
                 url, force_download=force_download)
         else:
             log.info('pretrain dir %s not in %s, read from local' %
                      (pretrain_dir_or_url, repr(cls.resource_map)))
-            pretrain_dir = pretrain_dir_or_url
-        pretrain_dir = Path(pretrain_dir)
+            pretrain_dir = Path(pretrain_dir_or_url)
         if not pretrain_dir.exists():
             raise ValueError('pretrain dir not found: %s' % pretrain_dir)
         vocab_path = pretrain_dir / 'vocab.txt'
@@ -235,12 +235,14 @@ class ErnieTinyTokenizer(ErnieTokenizer):
             pretrain_dir_or_url,
             force_download=False,
             **kwargs):
-        if pretrain_dir_or_url in cls.resource_map:
-            url = cls.resource_map[pretrain_dir_or_url]
+        if not Path(pretrain_dir_or_url).exists() and str(
+                pretrain_dir_or_url) in cls.resource_map:
+            url = cls.resource_map[str(pretrain_dir_or_url)]
             log.info('get pretrain dir from %s' % url)
             pretrain_dir = _fetch_from_remote(url, force_download)
         else:
-            log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
+            log.info('pretrain dir %s not in %s, read from local' %
+                     (pretrain_dir_or_url, repr(cls.resource_map)))
             pretrain_dir = Path(pretrain_dir_or_url)
         if not pretrain_dir.exists():
             raise ValueError('pretrain dir not found: %s' % pretrain_dir)
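Both tokenizer classes now use the same lookup as `PretrainedModel`: check for an existing local directory first, and coerce the key with `str()`. The coercion matters because a `pathlib.Path` never compares equal to a plain string, so a `Path` argument would silently miss `resource_map` and be treated as a (non-existent) local directory. A minimal demonstration (the map entry is made up):

```python
from pathlib import Path

resource_map = {'ernie-tiny': 'https://example.com/ernie-tiny.tar.gz'}

name = Path('ernie-tiny')            # callers may pass a Path rather than a str
print(name in resource_map)          # False: a Path does not compare equal to the str key
print(str(name) in resource_map)     # True: the str() coercion restores the match
```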