提交 df3425ba 编写于 作者: W wangxiao

update download_models.py & README.md

上级 7060eafe
......@@ -110,28 +110,22 @@ paddlepalm框架的运行原理图如图所示
### 预训练模型
#### 下载
我们提供了BERT、ERNIE等主干网络的相关预训练模型。为了加速模型收敛,获得更佳的测试集表现,我们强烈建议用户在多任务学习时尽量在预训练模型的基础上进行(而不是从参数随机初始化开始)。用户可通过运行`script/download_pretrain_models <model_name>`下载需要的预训练模型,例如,下载预训练BERT模型(uncased large)的命令如下
我们提供了BERT、ERNIE等主干网络的相关预训练模型。为了加速模型收敛,获得更佳的测试集表现,我们强烈建议用户在多任务学习时尽量在预训练模型的基础上进行(而不是从参数随机初始化开始)。用户可以查看可供下载的预训练模型:
```shell
bash script/download_pretrain_backbone.sh bert
python download_models.py
```
脚本会自动在**当前文件夹**中创建一个pretrain_model目录(注:运行DEMO时,需保证pretrain_model文件夹在PALM项目目录下),并在其中创建bert子目录,里面存放预训练模型(`params`文件夹内)、相关的网络参数(`bert_config.json`)和字典(`vocab.txt`)。除了BERT模型,脚本还提供了ERNIE预训练模型(uncased large)的一键下载,将`<model_name>`改成`ernie`即可。全部可用的预训练模型列表见[paddlenlp/lark](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/PaddleLARK)
#### 转换
注意,预训练模型不能直接被框架使用。我们提供了转换脚本可以将其转换成paddlepalm的模型格式。如下,通过运行`script/convert_params.sh`可将预训练模型bert转换成框架的模型格式。
用户可通过运行`python download_models.py download <model_name>`下载需要的预训练模型,例如,下载预训练BERT模型(uncased large)的命令如下:
```shell
bash script/convert_params.sh pretrain_model/bert/params
python download_models.py download bert-en-uncased-large
```
注意,以下恢复操作在执行后述DEMO流程中**无需执行**
若用户需将转换成的paddlepalm模型恢复为原始的预训练模型,可以运行`script/recover_params.sh`进行恢复。
此外,用户也可通过运行`python download_models.py download all`下载已提供的所有预训练模型。
脚本会自动在**当前文件夹**中创建一个pretrain目录(注:运行DEMO时,需保证pretrain_model文件夹在PALM项目目录下),并在其中创建bert子目录,里面存放预训练模型(`params`文件夹内)、相关的网络参数(`bert_config.json`)和字典(`vocab.txt`)。除了BERT模型,脚本还提供了ERNIE预训练模型(uncased large)的一键下载,将`<model_name>`改成`ernie-en-uncased-large`即可。全部可用的预训练模型列表见[paddlenlp/lark](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/PaddleLARK)
```shell
bash script/recover_params.sh pretrain_model/bert/params
```
## 三个DEMO入门PaddlePALM
......
......@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddlepalm as palm
palm.download('pretrain')
import paddlepalm as palm
import sys
if(sys.argv[1] == 'ls'):
palm.ls(sys.argv[1], sys.argv[2])
if(sys.argv[1] == 'download'):
palm.download(sys.argv[1], sys.argv[2])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册