未验证 提交 3ca0e9f6 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #35 from wangxiao1021/downloader

fix bugs
......@@ -114,16 +114,16 @@ paddlepalm框架的运行原理图如图所示
我们提供了BERT、ERNIE等主干网络的相关预训练模型。为了加速模型收敛,获得更佳的测试集表现,我们强烈建议用户在多任务学习时尽量在预训练模型的基础上进行(而不是从参数随机初始化开始)。用户可以查看可供下载的预训练模型:
```shell
python download_models.py ls pretrain
python download_models.py -l
```
用户可通过运行`python download_models.py download <model_name>`下载需要的预训练模型,例如,下载预训练BERT模型(uncased large)的命令如下:
用户可通过运行`python download_models.py -d <model_name>`下载需要的预训练模型,例如,下载预训练BERT模型(uncased large)的命令如下:
```shell
python download_models.py download bert-en-uncased-large
python download_models.py -d bert-en-uncased-large
```
此外,用户也可通过运行`python download_models.py download all`下载已提供的所有预训练模型。
此外,用户也可通过运行`python download_models.py -d all`下载已提供的所有预训练模型。
脚本会自动在**当前文件夹**中创建一个pretrain目录(注:运行DEMO时,需保证pretrain文件夹在PALM项目目录下),并在其中创建bert子目录,里面存放预训练模型(`params`文件夹内)、相关的网络参数(`bert_config.json`)和字典(`vocab.txt`)。除了BERT模型,脚本还提供了ERNIE预训练模型(uncased large)的一键下载,将`<model_name>`改成`ernie-en-uncased-large`即可。全部可用的预训练模型列表见[paddlenlp/lark](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/PaddleLARK)
......
......@@ -15,7 +15,15 @@
import paddlepalm as palm
import sys
if(sys.argv[1] == 'ls'):
palm.downloader.ls(sys.argv[2])
if(sys.argv[1] == 'download'):
palm.downloader.download('pretrain', sys.argv[2])
import argparse
# create parser
parser = argparse.ArgumentParser(description = 'Download pretrain models for initializing params of backbones. ')
parser.add_argument("-l", "--list", action = 'store_true', help = 'show the list of pretrain models')
parser.add_argument("-d", "--download", action = 'store', help = 'download pretrain models')
args = parser.parse_args()
if(args.list):
palm.downloader.ls('pretrain')
if(args.download):
palm.downloader.download('pretrain', args.download)
......@@ -38,8 +38,6 @@ _items = {
'backbone': {'utils': None},
'tasktype': {'utils': None},
}
def lll():
pass
def _download(item, scope, path, silent=False):
data_url = _items[item][scope]
......@@ -137,7 +135,7 @@ def download(item, scope='all', path='.'):
_download(item, 'utils', path, silent=True)
if scope != 'all':
assert scope in _items[item], '{} is not found. Support items: {}'.format(item, list(_items.keys()))
assert scope in _items[item], '{} is not found. Support scopes: {}'.format(scope, list(_items[item].keys()))
_download(item, scope, path)
else:
for s in _items[item].keys():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册