提交 bf081f05 编写于 作者: X xixiaoyao

Merge branch 'master' of https://github.com/PaddlePaddle/PALM

...@@ -19,8 +19,9 @@ Beyond the research scope, PaddlePALM has been applied on **Baidu Search Engine* ...@@ -19,8 +19,9 @@ Beyond the research scope, PaddlePALM has been applied on **Baidu Search Engine*
- **Easy-to-use:** with PALM, *8 steps* to achieve a typical NLP task. Moreover, all basic components (e.g., the model backbone, dataset reader, task output head, optimizer...) have been decoupled, which allows the replacement of any component to other candidates with quite minor changes of your code. - **Easy-to-use:** with PALM, *8 steps* to achieve a typical NLP task. Moreover, all basic components (e.g., the model backbone, dataset reader, task output head, optimizer...) have been decoupled, which allows the replacement of any component to other candidates with quite minor changes of your code.
- **Built-in Popular NLP Backbones and Pre-trained models:** multiple state-of-the-art general purpose model architectures and pretrained models (e.g., BERT,ERNIE,RoBERTa,...) are built-in. - **Built-in Popular NLP Backbones and Pre-trained models:** multiple state-of-the-art general purpose model architectures and pretrained models (e.g., BERT,ERNIE,RoBERTa,...) are built-in.
- **Multi-task Learning friendly:** *6 steps* to achieve multi-task learning for prepared tasks. - **Easy to play Multi-task Learning:** only one API is needed for jointly training of several tasks with parameters reusement.
- **Large Scale and Pre-training friendly:** automatically utilize multi-gpus (if exists) to accelerate training and inference. Minor codes is required for distributed training on clusters. - **Support train/eval with Multi-GPUs:** automatically recognize and adapt to multiple gpus mode to accelerate training and inference.
- **Pre-training friendly:** self-supervised tasks (e.g., mask language model) are built-in to facilitate pre-training. Easy to train from scratch.
- **Easy to Customize:** support customized development of any component (e.g, backbone, task head, reader and optimizer) with reusement of pre-defined ones, which gives developers high flexibility and effeciency to adapt for diverse NLP scenes. - **Easy to Customize:** support customized development of any component (e.g, backbone, task head, reader and optimizer) with reusement of pre-defined ones, which gives developers high flexibility and effeciency to adapt for diverse NLP scenes.
You can easily re-produce following competitive results with minor codes, which covers most of NLP tasks such as classification, matching, sequence labeling, reading comprehension, dialogue understanding and so on. More details can be found in `examples`. You can easily re-produce following competitive results with minor codes, which covers most of NLP tasks such as classification, matching, sequence labeling, reading comprehension, dialogue understanding and so on. More details can be found in `examples`.
......
...@@ -117,15 +117,15 @@ PaddlePALM是一个设计良好的高级NLP框架。基于PaddlePALM的轻量级 ...@@ -117,15 +117,15 @@ PaddlePALM是一个设计良好的高级NLP框架。基于PaddlePALM的轻量级
| 模块 | 描述 | | 模块 | 描述 |
| - | - | | - | - |
| **paddlepalm** | 一个开源的NLP预训练和多任务学习框架,建立在paddlepaddle框架上。 | | **paddlepalm** | 基于PaddlePaddle框架的high-level NLP预训练和多任务学习框架。 |
| **paddlepalm.reader** | 特定于任务的数据集读取工具的集合。| | **paddlepalm.reader** | 预置的任务数据集读取与预处理工具。|
| **paddlepalm.backbone** | 一系列经典的NLP表示模型,如BERT, ERNIE, RoBERTa。| | **paddlepalm.backbone** | 预置的主干网络,如BERT, ERNIE, RoBERTa。|
| **paddlepalm.head** | 任务特定输出层的集合。| | **paddlepalm.head** | 预置的任务输出层。|
| **paddlepalm.lr_sched** | 一个学习率时间表的集合。| | **paddlepalm.lr_sched** | 预置的学习率规划策略。|
| **paddlepalm.optimizer** | 优化器的集合。| | **paddlepalm.optimizer** | 预置的优化器。|
| **paddlepalm.downloader** | 预训练模型与配置和vocab文件的下载模块。| | **paddlepalm.downloader** | 预训练模型管理与下载模块。|
| **paddlepalm.Trainer** | 单一任务训练/预测。一个训练器是建立计算图,管理训练和评估过程,实现模型/检查点保存和pretrain_model/检查点加载。| | **paddlepalm.Trainer** | 任务训练/预测单元。训练器用于建立计算图,管理训练和评估过程,实现模型/检查点保存和pretrain_model/检查点加载等。|
| **paddlepalm.MultiHeadTrainer** | 进行多任务训练/预测的核心模块。一个MultiHeadTrainer建立在几个Trainer的基础上。在继承Trainer的基础上,实现了模型主干网络跨任务复用,采用多任务学习,多任务推理,来保证更有效的评估和预测。| | **paddlepalm.MultiHeadTrainer** | 完成多任务训练/预测的模块。一个MultiHeadTrainer建立在几个Trainer的基础上。实现了模型主干网络跨任务复用、多任务学习、多任务推理等。|
## 安装 ## 安装
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import print_function
import os import os
import requests
import tarfile import tarfile
import shutil import shutil
from tqdm import tqdm import sys
import urllib
URLLIB=urllib
if sys.version_info >= (3, 0):
import urllib.request
URLLIB=urllib.request
def download(src, url): def download(src, url):
file_size = int(requests.head(url).headers['Content-Length']) def _reporthook(count, chunk_size, total_size):
bytes_so_far = count * chunk_size
header = { percent = float(bytes_so_far) / float(total_size)
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/' if percent > 1:
'70.0.3538.67 Safari/537.36' percent = 1
} print('\r>> Downloading... {:.1%}'.format(percent), end="")
pbar = tqdm(total=file_size)
resp = requests.get(url, headers=header, stream=True)
with open(src, 'ab') as f:
for chunk in resp.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
pbar.update(1024)
pbar.close()
return file_size
URLLIB.urlretrieve(url, src, reporthook=_reporthook)
abs_path = os.path.abspath(__file__) abs_path = os.path.abspath(__file__)
download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz" download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz"
...@@ -46,5 +39,4 @@ for file in os.listdir(os.path.join(target_dir, 'task_data', 'chnsenticorp')): ...@@ -46,5 +39,4 @@ for file in os.listdir(os.path.join(target_dir, 'task_data', 'chnsenticorp')):
shutil.move(os.path.join(target_dir, 'task_data', 'chnsenticorp', file), dst_dir) shutil.move(os.path.join(target_dir, 'task_data', 'chnsenticorp', file), dst_dir)
shutil.rmtree(os.path.join(target_dir, 'task_data')) shutil.rmtree(os.path.join(target_dir, 'task_data'))
print(" done!")
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import print_function
import os import os
import requests import sys
from tqdm import tqdm import urllib
URLLIB=urllib
if sys.version_info >= (3, 0):
import urllib.request
URLLIB=urllib.request
def download(src, url): def download(src, url):
file_size = int(requests.head(url).headers['Content-Length']) def _reporthook(count, chunk_size, total_size):
header = { bytes_so_far = count * chunk_size
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/' percent = float(bytes_so_far) / float(total_size)
'70.0.3538.67 Safari/537.36' if percent > 1:
} percent = 1
pbar = tqdm(total=file_size) print('\r>> Downloading... {:.1%}'.format(percent), end="")
resp = requests.get(url, headers=header, stream=True)
with open(src, 'ab') as f:
for chunk in resp.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
pbar.update(1024)
pbar.close() URLLIB.urlretrieve(url, src, reporthook=_reporthook)
return file_size
abs_path = os.path.abspath(__file__) abs_path = os.path.abspath(__file__)
...@@ -32,3 +27,4 @@ if not os.path.exists(data_dir) or not os.path.isdir(data_dir): ...@@ -32,3 +27,4 @@ if not os.path.exists(data_dir) or not os.path.isdir(data_dir):
download_url = "http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv" download_url = "http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv"
downlaod_path = os.path.join(data_dir, "quora_duplicate_questions.tsv") downlaod_path = os.path.join(data_dir, "quora_duplicate_questions.tsv")
download(downlaod_path, download_url) download(downlaod_path, download_url)
print(" done!")
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import print_function
import os import os
import requests
import tarfile import tarfile
import shutil import shutil
from tqdm import tqdm import sys
import urllib
URLLIB=urllib
if sys.version_info >= (3, 0):
import urllib.request
URLLIB=urllib.request
def download(src, url): def download(src, url):
file_size = int(requests.head(url).headers['Content-Length']) def _reporthook(count, chunk_size, total_size):
bytes_so_far = count * chunk_size
header = { percent = float(bytes_so_far) / float(total_size)
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/' if percent > 1:
'70.0.3538.67 Safari/537.36' percent = 1
} print('\r>> Downloading... {:.1%}'.format(percent), end="")
pbar = tqdm(total=file_size)
resp = requests.get(url, headers=header, stream=True)
with open(src, 'ab') as f:
for chunk in resp.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
pbar.update(1024)
pbar.close()
return file_size
URLLIB.urlretrieve(url, src, reporthook=_reporthook)
abs_path = os.path.abspath(__file__) abs_path = os.path.abspath(__file__)
download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz" download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz"
...@@ -46,5 +39,5 @@ for file in os.listdir(os.path.join(target_dir, 'task_data', 'cmrc2018')): ...@@ -46,5 +39,5 @@ for file in os.listdir(os.path.join(target_dir, 'task_data', 'cmrc2018')):
shutil.move(os.path.join(target_dir, 'task_data', 'cmrc2018', file), dst_dir) shutil.move(os.path.join(target_dir, 'task_data', 'cmrc2018', file), dst_dir)
shutil.rmtree(os.path.join(target_dir, 'task_data')) shutil.rmtree(os.path.join(target_dir, 'task_data'))
print(" done!")
## Example 6: Joint Training in Dialogue ## Example 6: Joint Training of Dialogue Intent Recognition and Slot Filling
This task is a slot filling task. During training, the task uses intent determination task to assist in training slot filling model. The following sections detail model preparation, dataset preparation, and how to run the task. This example achieves the joint training ofg Dialogue Intent Recognition and Slot Filling. The intent recognition can be regared as a text classification task, and slot filling as sequence labeling task. Both classification and sequence labeling have been built-in in PaddlePALM.
### Step 1: Prepare Pre-trained Models & Datasets ### Step 1: Prepare Pre-trained Models & Datasets
#### Pre-trianed Model #### Pre-trained Model
The pre-training model of this mission is: [ERNIE-v2-en-base](https://github.com/PaddlePaddle/PALM/tree/r0.3-api). We prepare [ERNIE-v2-en-base](https://github.com/PaddlePaddle/PALM/tree/r0.3-api) as our pre-trained model for this example.
Make sure you have downloaded the required pre-training model in the current folder.
Make sure you have downloaded `ERNIE` to current folder.
#### Dataset #### Dataset
This task uses the `Airline Travel Information System` dataset. Here we use `Airline Travel Information System` dataset as our testbed.
Download dataset: Download dataset:
```shell ```shell
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import print_function
import os import os
import requests
import tarfile import tarfile
import shutil import shutil
from tqdm import tqdm import sys
import urllib
URLLIB=urllib
if sys.version_info >= (3, 0):
import urllib.request
URLLIB=urllib.request
def download(src, url): def download(src, url):
file_size = int(requests.head(url).headers['Content-Length']) def _reporthook(count, chunk_size, total_size):
bytes_so_far = count * chunk_size
header = { percent = float(bytes_so_far) / float(total_size)
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/' if percent > 1:
'70.0.3538.67 Safari/537.36' percent = 1
} print('\r>> Downloading... {:.1%}'.format(percent), end="")
pbar = tqdm(total=file_size)
resp = requests.get(url, headers=header, stream=True)
with open(src, 'ab') as f:
for chunk in resp.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
pbar.update(1024)
pbar.close()
return file_size
URLLIB.urlretrieve(url, src, reporthook=_reporthook)
abs_path = os.path.abspath(__file__) abs_path = os.path.abspath(__file__)
download_url = "https://baidu-nlp.bj.bcebos.com/dmtk_data_1.0.0.tar.gz" download_url = "https://baidu-nlp.bj.bcebos.com/dmtk_data_1.0.0.tar.gz"
...@@ -42,4 +35,4 @@ shutil.rmtree(os.path.join(target_dir, 'data/mrda/')) ...@@ -42,4 +35,4 @@ shutil.rmtree(os.path.join(target_dir, 'data/mrda/'))
shutil.rmtree(os.path.join(target_dir, 'data/multi-woz/')) shutil.rmtree(os.path.join(target_dir, 'data/multi-woz/'))
shutil.rmtree(os.path.join(target_dir, 'data/swda/')) shutil.rmtree(os.path.join(target_dir, 'data/swda/'))
shutil.rmtree(os.path.join(target_dir, 'data/udc/')) shutil.rmtree(os.path.join(target_dir, 'data/udc/'))
print(" done!")
...@@ -21,29 +21,24 @@ if __name__ == '__main__': ...@@ -21,29 +21,24 @@ if __name__ == '__main__':
train_slot = './data/atis/atis_slot/train.tsv' train_slot = './data/atis/atis_slot/train.tsv'
train_intent = './data/atis/atis_intent/train.tsv' train_intent = './data/atis/atis_intent/train.tsv'
predict_file = './data/atis/atis_slot/test.tsv'
save_path = './outputs/'
pred_output = './outputs/predict/'
save_type = 'ckpt'
pre_params = './pretrain/ERNIE-v2-en-base/params'
config = json.load(open('./pretrain/ERNIE-v2-en-base/ernie_config.json')) config = json.load(open('./pretrain/ERNIE-v2-en-base/ernie_config.json'))
input_dim = config['hidden_size'] input_dim = config['hidden_size']
# ----------------------- for training ----------------------- # ----------------------- for training -----------------------
# step 1-1: create readers for training # step 1-1: create readers
seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed) seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed) cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed)
# step 1-2: load the training data # step 1-2: load train data
seq_label_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size) seq_label_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size)
cls_reader.load_data(train_intent, batch_size=batch_size, num_epochs=None) cls_reader.load_data(train_intent, batch_size=batch_size, num_epochs=None)
# step 2: create a backbone of the model to extract text features # step 2: create a backbone of the model to extract text features
ernie = palm.backbone.ERNIE.from_config(config) ernie = palm.backbone.ERNIE.from_config(config)
# step 3: register the backbone in readers # step 3: register readers with ernie backbone
seq_label_reader.register_with(ernie) seq_label_reader.register_with(ernie)
cls_reader.register_with(ernie) cls_reader.register_with(ernie)
...@@ -51,7 +46,7 @@ if __name__ == '__main__': ...@@ -51,7 +46,7 @@ if __name__ == '__main__':
seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob) seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)
cls_head = palm.head.Classify(num_classes_intent, input_dim, dropout_prob) cls_head = palm.head.Classify(num_classes_intent, input_dim, dropout_prob)
# step 5-1: create a task trainer # step 5-1: create task trainers and multiHeadTrainer
trainer_seq_label = palm.Trainer("slot", mix_ratio=1.0) trainer_seq_label = palm.Trainer("slot", mix_ratio=1.0)
trainer_cls = palm.Trainer("intent", mix_ratio=1.0) trainer_cls = palm.Trainer("intent", mix_ratio=1.0)
trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_cls]) trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_cls])
...@@ -60,23 +55,21 @@ if __name__ == '__main__': ...@@ -60,23 +55,21 @@ if __name__ == '__main__':
loss2 = trainer_seq_label.build_forward(ernie, seq_label_head) loss2 = trainer_seq_label.build_forward(ernie, seq_label_head)
loss_var = trainer.build_forward() loss_var = trainer.build_forward()
# step 6-1*: use warmup # step 6-1*: enable warmup for better fine-tuning
n_steps = seq_label_reader.num_examples * 1.5 * num_epochs // batch_size n_steps = seq_label_reader.num_examples * 1.5 * num_epochs // batch_size
warmup_steps = int(0.1 * n_steps) warmup_steps = int(0.1 * n_steps)
sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps) sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
# step 6-2: create a optimizer # step 6-2: build a optimizer
adam = palm.optimizer.Adam(loss_var, lr, sched) adam = palm.optimizer.Adam(loss_var, lr, sched)
# step 6-3: build backward # step 6-3: build backward graph
trainer.build_backward(optimizer=adam, weight_decay=weight_decay) trainer.build_backward(optimizer=adam, weight_decay=weight_decay)
# step 7: fit prepared reader and data # step 7: fit readers to trainer
trainer.fit_readers_with_mixratio([seq_label_reader, cls_reader], "slot", num_epochs) trainer.fit_readers_with_mixratio([seq_label_reader, cls_reader], "slot", num_epochs)
# step 8-1*: load pretrained parameters # step 8-1*: load pretrained model
trainer.load_pretrain(pre_params) trainer.load_pretrain('./pretrain/ERNIE-v2-en-base')
# step 8-2*: set saver to save model # step 8-2*: set saver to save models during training
save_steps = int(n_steps-batch_size) // 2 trainer.set_saver(save_path='./outputs/', save_steps=300)
# save_steps = 10
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# step 8-3: start training # step 8-3: start training
trainer.train(print_steps=print_steps) trainer.train(print_steps=10)
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import print_function
import os import os
import requests
import tarfile import tarfile
import shutil import shutil
from tqdm import tqdm import sys
import urllib
URLLIB=urllib
if sys.version_info >= (3, 0):
import urllib.request
URLLIB=urllib.request
def download(src, url): def download(src, url):
file_size = int(requests.head(url).headers['Content-Length']) def _reporthook(count, chunk_size, total_size):
bytes_so_far = count * chunk_size
header = { percent = float(bytes_so_far) / float(total_size)
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/' if percent > 1:
'70.0.3538.67 Safari/537.36' percent = 1
} print('\r>> Downloading... {:.1%}'.format(percent), end="")
pbar = tqdm(total=file_size)
resp = requests.get(url, headers=header, stream=True)
with open(src, 'ab') as f:
for chunk in resp.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
pbar.update(1024)
pbar.close()
return file_size
URLLIB.urlretrieve(url, src, reporthook=_reporthook)
abs_path = os.path.abspath(__file__) abs_path = os.path.abspath(__file__)
download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz" download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz"
...@@ -46,5 +39,4 @@ for file in os.listdir(os.path.join(target_dir, 'task_data', 'chnsenticorp')): ...@@ -46,5 +39,4 @@ for file in os.listdir(os.path.join(target_dir, 'task_data', 'chnsenticorp')):
shutil.move(os.path.join(target_dir, 'task_data', 'chnsenticorp', file), dst_dir) shutil.move(os.path.join(target_dir, 'task_data', 'chnsenticorp', file), dst_dir)
shutil.rmtree(os.path.join(target_dir, 'task_data')) shutil.rmtree(os.path.join(target_dir, 'task_data'))
print(" done!")
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import print_function
import os import os
import requests
import tarfile import tarfile
import shutil import shutil
from tqdm import tqdm import sys
import urllib
URLLIB=urllib
if sys.version_info >= (3, 0):
import urllib.request
URLLIB=urllib.request
def download(src, url): def download(src, url):
file_size = int(requests.head(url).headers['Content-Length']) def _reporthook(count, chunk_size, total_size):
bytes_so_far = count * chunk_size
header = { percent = float(bytes_so_far) / float(total_size)
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/' if percent > 1:
'70.0.3538.67 Safari/537.36' percent = 1
} print('\r>> Downloading... {:.1%}'.format(percent), end="")
pbar = tqdm(total=file_size)
resp = requests.get(url, headers=header, stream=True)
with open(src, 'ab') as f:
for chunk in resp.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
pbar.update(1024)
pbar.close()
return file_size
URLLIB.urlretrieve(url, src, reporthook=_reporthook)
abs_path = os.path.abspath(__file__) abs_path = os.path.abspath(__file__)
download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz" download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz"
...@@ -46,5 +39,4 @@ for file in os.listdir(os.path.join(target_dir, 'task_data', 'msra_ner')): ...@@ -46,5 +39,4 @@ for file in os.listdir(os.path.join(target_dir, 'task_data', 'msra_ner')):
shutil.move(os.path.join(target_dir, 'task_data', 'msra_ner', file), dst_dir) shutil.move(os.path.join(target_dir, 'task_data', 'msra_ner', file), dst_dir)
shutil.rmtree(os.path.join(target_dir, 'task_data')) shutil.rmtree(os.path.join(target_dir, 'task_data'))
print(" done!")
...@@ -15,23 +15,18 @@ ...@@ -15,23 +15,18 @@
from __future__ import print_function from __future__ import print_function
import os import os
import requests
import tarfile import tarfile
import shutil import shutil
try:
from urllib.request import urlopen # Python 3
except ImportError:
from urllib2 import urlopen # Python 2
from collections import OrderedDict from collections import OrderedDict
import ssl import sys
import urllib
URLLIB=urllib
if sys.version_info >= (3, 0):
import urllib.request
URLLIB=urllib.request
__all__ = ["download", "ls"] __all__ = ["download", "ls"]
# for https
ssl._create_default_https_context = ssl._create_unverified_context
_pretrain = (('RoBERTa-zh-base', 'https://bert-models.bj.bcebos.com/chinese_roberta_wwm_ext_L-12_H-768_A-12.tar.gz'), _pretrain = (('RoBERTa-zh-base', 'https://bert-models.bj.bcebos.com/chinese_roberta_wwm_ext_L-12_H-768_A-12.tar.gz'),
('RoBERTa-zh-large', 'https://bert-models.bj.bcebos.com/chinese_roberta_wwm_large_ext_L-24_H-1024_A-16.tar.gz'), ('RoBERTa-zh-large', 'https://bert-models.bj.bcebos.com/chinese_roberta_wwm_large_ext_L-24_H-1024_A-16.tar.gz'),
('ERNIE-v2-en-base', 'https://ernie.bj.bcebos.com/ERNIE_Base_en_stable-2.0.0.tar.gz'), ('ERNIE-v2-en-base', 'https://ernie.bj.bcebos.com/ERNIE_Base_en_stable-2.0.0.tar.gz'),
...@@ -76,32 +71,15 @@ def _download(item, scope, path, silent=False, convert=False): ...@@ -76,32 +71,15 @@ def _download(item, scope, path, silent=False, convert=False):
filename = data_dir + '/' + data_name filename = data_dir + '/' + data_name
# print process # print process
def _chunk_report(bytes_so_far, total_size): def _reporthook(count, chunk_size, total_size):
bytes_so_far = count * chunk_size
percent = float(bytes_so_far) / float(total_size) percent = float(bytes_so_far) / float(total_size)
if percent > 1: if percent > 1:
percent = 1 percent = 1
if not silent: if not silent:
print('\r>> Downloading... {:.1%}'.format(percent), end = "") print('\r>> Downloading... {:.1%}'.format(percent), end = "")
# copy to local URLLIB.urlretrieve(data_url, filename, reporthook=_reporthook)
def _chunk_read(response, url, chunk_size = 16 * 1024, report_hook = None):
total_size = int(requests.head(url).headers['Content-Length'])
bytes_so_far = 0
with open("%s" % filename, "wb") as f:
while 1:
chunk = response.read(chunk_size)
f.write(chunk)
f.flush()
bytes_so_far += len(chunk)
if not chunk:
break
if report_hook:
report_hook(bytes_so_far, total_size)
return bytes_so_far
response = urlopen(data_url)
_chunk_read(response, data_url, report_hook=_chunk_report)
if not silent: if not silent:
print(' done!') print(' done!')
......
...@@ -42,8 +42,8 @@ class BERT(Backbone): ...@@ -42,8 +42,8 @@ class BERT(Backbone):
self._hidden_act = hidden_act self._hidden_act = hidden_act
self._prepostprocess_dropout = hidden_dropout_prob self._prepostprocess_dropout = 0. if phase == 'predict' else hidden_dropout_prob
self._attention_dropout = attention_probs_dropout_prob self._attention_dropout = 0. if phase == 'predict' else attention_probs_dropout_prob
self._word_emb_name = "word_embedding" self._word_emb_name = "word_embedding"
self._pos_emb_name = "pos_embedding" self._pos_emb_name = "pos_embedding"
......
...@@ -45,8 +45,8 @@ class ERNIE(Backbone): ...@@ -45,8 +45,8 @@ class ERNIE(Backbone):
self._task_types = task_type_vocab_size self._task_types = task_type_vocab_size
self._hidden_act = hidden_act self._hidden_act = hidden_act
self._prepostprocess_dropout = hidden_dropout_prob self._prepostprocess_dropout = 0. if phase == 'predict' else hidden_dropout_prob
self._attention_dropout = attention_probs_dropout_prob self._attention_dropout = 0. if phase == 'predict' else attention_probs_dropout_prob
self._word_emb_name = "word_embedding" self._word_emb_name = "word_embedding"
self._pos_emb_name = "pos_embedding" self._pos_emb_name = "pos_embedding"
......
...@@ -125,6 +125,7 @@ def decode_fake(nums, mask, bs): ...@@ -125,6 +125,7 @@ def decode_fake(nums, mask, bs):
n_f = len(mask) - n_t n_f = len(mask) - n_t
p1 = nums - (n_t-1) * bs p1 = nums - (n_t-1) * bs
each_f = p1 / (n_f+1) assert p1 % (n_f+1) == 0
each_f = p1 // (n_f+1)
return each_f * n_f return each_f * n_f
...@@ -94,14 +94,17 @@ class Classify(Head): ...@@ -94,14 +94,17 @@ class Classify(Head):
def epoch_postprocess(self, post_inputs, output_dir=None): def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training: if not self._is_training:
if output_dir is None: results = []
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.')
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)): for i in range(len(self._preds)):
label = int(np.argmax(np.array(self._preds[i]))) label = int(np.argmax(np.array(self._preds[i])))
result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]} result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]}
results.append(result)
if output_dir is not None:
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for result in results:
result = json.dumps(result) result = json.dumps(result)
writer.write(result+'\n') writer.write(result+'\n')
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
return results
...@@ -174,15 +174,18 @@ class Match(Head): ...@@ -174,15 +174,18 @@ class Match(Head):
def epoch_postprocess(self, post_inputs, output_dir=None): def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training: if not self._is_training:
if output_dir is None: results = []
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.')
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)): for i in range(len(self._preds)):
if self._learning_strategy == 'pointwise': if self._learning_strategy == 'pointwise':
label = int(np.argmax(np.array(self._preds[i]))) label = int(np.argmax(np.array(self._preds[i])))
result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]} result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]}
elif self._learning_strategy == 'pairwise': elif self._learning_strategy == 'pairwise':
result = {'index': i, 'probs': self._preds[i][0]} result = {'index': i, 'probs': self._preds[i][0]}
results.append(result)
if output_dir is not None:
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for result in results:
result = json.dumps(result, ensure_ascii=False) result = json.dumps(result, ensure_ascii=False)
writer.write(result+'\n') writer.write(result+'\n')
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
return results
...@@ -128,13 +128,15 @@ class MaskLM(Head): ...@@ -128,13 +128,15 @@ class MaskLM(Head):
def epoch_postprocess(self, post_inputs, output_dir=None): def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training: if not self._is_training:
if output_dir is None: results = []
for p in self._preds: for i in range(len(self._preds)):
print(p) result = {'index': i, 'word_id': self._preds[i]}
else: results.append(result)
if output_dir is not None:
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for p in self._preds: for result in results:
writer.write(str(p)+'\n') result = json.dumps(result)
writer.write(result+'\n')
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
return results
...@@ -154,8 +154,7 @@ class MRC(Head): ...@@ -154,8 +154,7 @@ class MRC(Head):
"""(optional interface) this func will be called after evaluation/predicting process and each epoch during training process.""" """(optional interface) this func will be called after evaluation/predicting process and each epoch during training process."""
if not self._is_training: if not self._is_training:
if output_dir is None: if output_dir is not None:
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.')
examples = post_inputs['reader']['examples'] examples = post_inputs['reader']['examples']
features = post_inputs['reader']['features'] features = post_inputs['reader']['features']
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
...@@ -169,6 +168,7 @@ class MRC(Head): ...@@ -169,6 +168,7 @@ class MRC(Head):
output_nbest_file, output_null_log_odds_file, output_nbest_file, output_null_log_odds_file,
self._with_negative, self._with_negative,
self._null_score_diff_threshold, self._verbose) self._null_score_diff_threshold, self._verbose)
return self._pred_results
def _write_predictions(all_examples, all_features, all_results, n_best_size, def _write_predictions(all_examples, all_features, all_results, n_best_size,
......
...@@ -118,9 +118,9 @@ class SequenceLabel(Head): ...@@ -118,9 +118,9 @@ class SequenceLabel(Head):
def epoch_postprocess(self, post_inputs, output_dir=None): def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training: if not self._is_training:
if output_dir is None: if output_dir is not None:
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.')
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for p in self._preds: for p in self._preds:
writer.write(str(p)+'\n') writer.write(str(p)+'\n')
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
return self._preds
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册