Commit bf081f05 authored by xixiaoyao

Merge branch 'master' of https://github.com/PaddlePaddle/PALM

......@@ -19,8 +19,9 @@ Beyond the research scope, PaddlePALM has been applied on **Baidu Search Engine*
- **Easy-to-use:** with PALM, a typical NLP task can be achieved in *8 steps*. Moreover, all basic components (e.g., the model backbone, dataset reader, task output head, optimizer...) have been decoupled, so any component can be swapped for another candidate with only minor changes to your code.
- **Built-in Popular NLP Backbones and Pre-trained models:** multiple state-of-the-art general-purpose model architectures and pretrained models (e.g., BERT, ERNIE, RoBERTa, ...) are built in.
- **Easy to play Multi-task Learning:** only one API is needed to jointly train several tasks with parameter reuse.
- **Support train/eval with Multi-GPUs:** automatically recognize and adapt to multi-GPU mode to accelerate training and inference.
- **Pre-training friendly:** self-supervised tasks (e.g., masked language model) are built in to facilitate pre-training. Easy to train from scratch.
- **Easy to Customize:** support customized development of any component (e.g., backbone, task head, reader and optimizer) with reuse of pre-defined ones, which gives developers high flexibility and efficiency to adapt to diverse NLP scenarios.
You can easily reproduce the following competitive results with minimal code, covering most mainstream NLP tasks such as classification, matching, sequence labeling, reading comprehension, dialogue understanding and so on. More details can be found in `examples`.
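To make the *8 steps* above concrete, here is a minimal single-task sketch assembled from the APIs that appear in the example scripts of this repository. It is an illustration rather than a verbatim excerpt: the data and vocab paths are placeholders, and the single-task `fit_reader` call is assumed as the counterpart of the `fit_readers_with_mixratio` call used in the multi-task example below.

```python
import json
import paddlepalm as palm

# config shipped with the downloaded pre-trained model (placeholder path)
config = json.load(open('./pretrain/ERNIE-v2-en-base/ernie_config.json'))

# step 1: create a dataset reader and load the training data (placeholder paths)
cls_reader = palm.reader.ClassifyReader('./pretrain/ERNIE-v2-en-base/vocab.txt', 128)
cls_reader.load_data('./data/train.tsv', batch_size=32, num_epochs=3)

# steps 2-3: create the backbone and register the reader with it
ernie = palm.backbone.ERNIE.from_config(config)
cls_reader.register_with(ernie)

# steps 4-5: create a task head and a trainer, then build the forward graph
cls_head = palm.head.Classify(2, config['hidden_size'], 0.1)
trainer = palm.Trainer('senti_cls')
loss_var = trainer.build_forward(ernie, cls_head)

# step 6: learning rate schedule, optimizer and backward graph
sched = palm.lr_sched.TriangularSchedualer(100, 1000)
adam = palm.optimizer.Adam(loss_var, 5e-5, sched)
trainer.build_backward(optimizer=adam)

# step 7: fit the reader (assumed single-task counterpart of fit_readers_with_mixratio)
trainer.fit_reader(cls_reader)

# step 8: load pretrained parameters, set the saver and start training
trainer.load_pretrain('./pretrain/ERNIE-v2-en-base')
trainer.set_saver(save_path='./outputs/', save_steps=300)
trainer.train(print_steps=10)
```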
......
......@@ -117,15 +117,15 @@ PaddlePALM is a well-designed high-level NLP framework. The PaddlePALM-based lightweight
| Module | Description |
| - | - |
| **paddlepalm** | A high-level NLP pre-training and multi-task learning framework built on the PaddlePaddle framework. |
| **paddlepalm.reader** | Built-in dataset readers and pre-processing tools for each task. |
| **paddlepalm.backbone** | Built-in backbone networks, such as BERT, ERNIE and RoBERTa. |
| **paddlepalm.head** | Built-in task output heads. |
| **paddlepalm.lr_sched** | Built-in learning rate schedules. |
| **paddlepalm.optimizer** | Built-in optimizers. |
| **paddlepalm.downloader** | The management and download module for pre-trained models. |
| **paddlepalm.Trainer** | The single-task training/prediction unit. A trainer builds the computation graph, manages training and evaluation, and handles model/checkpoint saving as well as pretrained-model/checkpoint loading. |
| **paddlepalm.MultiHeadTrainer** | The module for multi-task training/prediction. A MultiHeadTrainer is built on several Trainers; it reuses the backbone network across tasks and implements multi-task learning and multi-task inference for more efficient evaluation and prediction. |
## Installation
......
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import requests
import tarfile
import shutil
from tqdm import tqdm


def download(src, url):
    # stream the file to `src` in 1 KB chunks while showing a byte-level progress bar
    file_size = int(requests.head(url).headers['Content-Length'])
    header = {
        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/'
                      '70.0.3538.67 Safari/537.36'
    }
    pbar = tqdm(total=file_size)
    resp = requests.get(url, headers=header, stream=True)
    with open(src, 'ab') as f:
        for chunk in resp.iter_content(chunk_size=1024):
            if chunk:
                f.write(chunk)
                pbar.update(1024)
    pbar.close()
    return file_size


abs_path = os.path.abspath(__file__)
download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz"
......@@ -46,5 +39,4 @@ for file in os.listdir(os.path.join(target_dir, 'task_data', 'chnsenticorp')):
    shutil.move(os.path.join(target_dir, 'task_data', 'chnsenticorp', file), dst_dir)
shutil.rmtree(os.path.join(target_dir, 'task_data'))
print(" done!")
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import requests
from tqdm import tqdm


def download(src, url):
    # stream the file to `src` in 1 KB chunks while showing a byte-level progress bar
    file_size = int(requests.head(url).headers['Content-Length'])
    header = {
        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/'
                      '70.0.3538.67 Safari/537.36'
    }
    pbar = tqdm(total=file_size)
    resp = requests.get(url, headers=header, stream=True)
    with open(src, 'ab') as f:
        for chunk in resp.iter_content(chunk_size=1024):
            if chunk:
                f.write(chunk)
                pbar.update(1024)
    pbar.close()
    return file_size


abs_path = os.path.abspath(__file__)
......@@ -32,3 +27,4 @@ if not os.path.exists(data_dir) or not os.path.isdir(data_dir):
download_url = "http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv"
download_path = os.path.join(data_dir, "quora_duplicate_questions.tsv")
download(download_path, download_url)
print(" done!")
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import requests
import tarfile
import shutil
from tqdm import tqdm


def download(src, url):
    # stream the file to `src` in 1 KB chunks while showing a byte-level progress bar
    file_size = int(requests.head(url).headers['Content-Length'])
    header = {
        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/'
                      '70.0.3538.67 Safari/537.36'
    }
    pbar = tqdm(total=file_size)
    resp = requests.get(url, headers=header, stream=True)
    with open(src, 'ab') as f:
        for chunk in resp.iter_content(chunk_size=1024):
            if chunk:
                f.write(chunk)
                pbar.update(1024)
    pbar.close()
    return file_size


abs_path = os.path.abspath(__file__)
download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz"
......@@ -46,5 +39,5 @@ for file in os.listdir(os.path.join(target_dir, 'task_data', 'cmrc2018')):
    shutil.move(os.path.join(target_dir, 'task_data', 'cmrc2018', file), dst_dir)
shutil.rmtree(os.path.join(target_dir, 'task_data'))
print(" done!")
## Example 6: Joint Training of Dialogue Intent Recognition and Slot Filling
This example achieves joint training of dialogue intent recognition and slot filling. Intent recognition can be regarded as a text classification task, and slot filling as a sequence labeling task. Heads for both classification and sequence labeling are built into PaddlePALM. The following sections detail model preparation, dataset preparation, and how to run the task.
### Step 1: Prepare Pre-trained Models & Datasets
#### Pre-trained Model
We use [ERNIE-v2-en-base](https://github.com/PaddlePaddle/PALM/tree/r0.3-api) as the pre-trained model for this example.
Make sure you have downloaded `ERNIE` to the current folder.
#### Dataset
Here we use the `Airline Travel Information System` (ATIS) dataset as our testbed.
Download dataset:
```shell
python download.py
......
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import requests
import tarfile
import shutil
from tqdm import tqdm


def download(src, url):
    # stream the file to `src` in 1 KB chunks while showing a byte-level progress bar
    file_size = int(requests.head(url).headers['Content-Length'])
    header = {
        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/'
                      '70.0.3538.67 Safari/537.36'
    }
    pbar = tqdm(total=file_size)
    resp = requests.get(url, headers=header, stream=True)
    with open(src, 'ab') as f:
        for chunk in resp.iter_content(chunk_size=1024):
            if chunk:
                f.write(chunk)
                pbar.update(1024)
    pbar.close()
    return file_size


abs_path = os.path.abspath(__file__)
download_url = "https://baidu-nlp.bj.bcebos.com/dmtk_data_1.0.0.tar.gz"
......@@ -42,4 +35,4 @@ shutil.rmtree(os.path.join(target_dir, 'data/mrda/'))
shutil.rmtree(os.path.join(target_dir, 'data/multi-woz/'))
shutil.rmtree(os.path.join(target_dir, 'data/swda/'))
shutil.rmtree(os.path.join(target_dir, 'data/udc/'))
print(" done!")
......@@ -21,29 +21,24 @@ if __name__ == '__main__':
train_slot = './data/atis/atis_slot/train.tsv'
train_intent = './data/atis/atis_intent/train.tsv'
predict_file = './data/atis/atis_slot/test.tsv'
save_path = './outputs/'
pred_output = './outputs/predict/'
save_type = 'ckpt'
pre_params = './pretrain/ERNIE-v2-en-base/params'
config = json.load(open('./pretrain/ERNIE-v2-en-base/ernie_config.json'))
input_dim = config['hidden_size']
# ----------------------- for training -----------------------
# step 1-1: create readers
seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed)
# step 1-2: load train data
seq_label_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size)
cls_reader.load_data(train_intent, batch_size=batch_size, num_epochs=None)
# step 2: create a backbone of the model to extract text features
ernie = palm.backbone.ERNIE.from_config(config)
# step 3: register readers with ernie backbone
seq_label_reader.register_with(ernie)
cls_reader.register_with(ernie)
......@@ -51,7 +46,7 @@ if __name__ == '__main__':
seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)
cls_head = palm.head.Classify(num_classes_intent, input_dim, dropout_prob)
# step 5-1: create task trainers and multiHeadTrainer
trainer_seq_label = palm.Trainer("slot", mix_ratio=1.0)
trainer_cls = palm.Trainer("intent", mix_ratio=1.0)
trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_cls])
......@@ -60,23 +55,21 @@ if __name__ == '__main__':
loss2 = trainer_seq_label.build_forward(ernie, seq_label_head)
loss_var = trainer.build_forward()
# step 6-1*: enable warmup for better fine-tuning
n_steps = seq_label_reader.num_examples * 1.5 * num_epochs // batch_size
warmup_steps = int(0.1 * n_steps)
sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
# step 6-2: build an optimizer
adam = palm.optimizer.Adam(loss_var, lr, sched)
# step 6-3: build backward graph
trainer.build_backward(optimizer=adam, weight_decay=weight_decay)
# step 7: fit readers to trainer
trainer.fit_readers_with_mixratio([seq_label_reader, cls_reader], "slot", num_epochs)
# step 8-1*: load pretrained model
trainer.load_pretrain('./pretrain/ERNIE-v2-en-base')
# step 8-2*: set saver to save models during training
trainer.set_saver(save_path='./outputs/', save_steps=300)
# step 8-3: start training
trainer.train(print_steps=10)
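The `mix_ratio` arguments above control how often each task is sampled during joint training. The snippet below is a self-contained conceptual illustration of proportional task sampling, not PALM's actual implementation:

```python
import random

# Conceptual sketch of mix_ratio-style task sampling (not PALM internals):
# each task is drawn with probability proportional to its mix_ratio.
mix_ratios = {'slot': 1.0, 'intent': 1.0}
total = sum(mix_ratios.values())

def sample_task():
    r = random.uniform(0, total)
    acc = 0.0
    for name, ratio in mix_ratios.items():
        acc += ratio
        if r <= acc:
            return name
    return name  # guard against floating-point rounding

counts = {name: 0 for name in mix_ratios}
for _ in range(10000):
    counts[sample_task()] += 1
print(counts)  # roughly 5000 draws per task with equal mix ratios
```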
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import requests
import tarfile
import shutil
from tqdm import tqdm


def download(src, url):
    # stream the file to `src` in 1 KB chunks while showing a byte-level progress bar
    file_size = int(requests.head(url).headers['Content-Length'])
    header = {
        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/'
                      '70.0.3538.67 Safari/537.36'
    }
    pbar = tqdm(total=file_size)
    resp = requests.get(url, headers=header, stream=True)
    with open(src, 'ab') as f:
        for chunk in resp.iter_content(chunk_size=1024):
            if chunk:
                f.write(chunk)
                pbar.update(1024)
    pbar.close()
    return file_size


abs_path = os.path.abspath(__file__)
download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz"
......@@ -46,5 +39,4 @@ for file in os.listdir(os.path.join(target_dir, 'task_data', 'chnsenticorp')):
    shutil.move(os.path.join(target_dir, 'task_data', 'chnsenticorp', file), dst_dir)
shutil.rmtree(os.path.join(target_dir, 'task_data'))
print(" done!")
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import requests
import tarfile
import shutil
from tqdm import tqdm


def download(src, url):
    # stream the file to `src` in 1 KB chunks while showing a byte-level progress bar
    file_size = int(requests.head(url).headers['Content-Length'])
    header = {
        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/'
                      '70.0.3538.67 Safari/537.36'
    }
    pbar = tqdm(total=file_size)
    resp = requests.get(url, headers=header, stream=True)
    with open(src, 'ab') as f:
        for chunk in resp.iter_content(chunk_size=1024):
            if chunk:
                f.write(chunk)
                pbar.update(1024)
    pbar.close()
    return file_size


abs_path = os.path.abspath(__file__)
download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz"
......@@ -46,5 +39,4 @@ for file in os.listdir(os.path.join(target_dir, 'task_data', 'msra_ner')):
    shutil.move(os.path.join(target_dir, 'task_data', 'msra_ner', file), dst_dir)
shutil.rmtree(os.path.join(target_dir, 'task_data'))
print(" done!")
......@@ -15,23 +15,18 @@
from __future__ import print_function
import os
import requests
import tarfile
import shutil
from collections import OrderedDict
import ssl
import sys
import urllib
URLLIB = urllib
if sys.version_info >= (3, 0):
    import urllib.request
    URLLIB = urllib.request
__all__ = ["download", "ls"]
# for https
ssl._create_default_https_context = ssl._create_unverified_context
_pretrain = (('RoBERTa-zh-base', 'https://bert-models.bj.bcebos.com/chinese_roberta_wwm_ext_L-12_H-768_A-12.tar.gz'),
('RoBERTa-zh-large', 'https://bert-models.bj.bcebos.com/chinese_roberta_wwm_large_ext_L-24_H-1024_A-16.tar.gz'),
('ERNIE-v2-en-base', 'https://ernie.bj.bcebos.com/ERNIE_Base_en_stable-2.0.0.tar.gz'),
......@@ -76,32 +71,15 @@ def _download(item, scope, path, silent=False, convert=False):
filename = data_dir + '/' + data_name

# print download progress
def _reporthook(count, chunk_size, total_size):
    bytes_so_far = count * chunk_size
    percent = float(bytes_so_far) / float(total_size)
    if percent > 1:
        percent = 1
    if not silent:
        print('\r>> Downloading... {:.1%}'.format(percent), end="")

# copy to local
URLLIB.urlretrieve(data_url, filename, reporthook=_reporthook)

if not silent:
    print(' done!')
......
......@@ -42,8 +42,8 @@ class BERT(Backbone):
self._hidden_act = hidden_act
self._prepostprocess_dropout = 0. if phase == 'predict' else hidden_dropout_prob
self._attention_dropout = 0. if phase == 'predict' else attention_probs_dropout_prob
self._word_emb_name = "word_embedding"
self._pos_emb_name = "pos_embedding"
......
......@@ -45,8 +45,8 @@ class ERNIE(Backbone):
self._task_types = task_type_vocab_size
self._hidden_act = hidden_act
self._prepostprocess_dropout = 0. if phase == 'predict' else hidden_dropout_prob
self._attention_dropout = 0. if phase == 'predict' else attention_probs_dropout_prob
self._word_emb_name = "word_embedding"
self._pos_emb_name = "pos_embedding"
......
......@@ -125,6 +125,7 @@ def decode_fake(nums, mask, bs):
n_f = len(mask) - n_t
p1 = nums - (n_t-1) * bs
assert p1 % (n_f+1) == 0
each_f = p1 // (n_f+1)
return each_f * n_f
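To see what the switch to floor division (plus the divisibility assert) protects against, here is a self-contained check of the arithmetic; the values of `nums`, `mask` and `bs` are invented purely to exercise the formula above:

```python
# Invented example values: 3 devices, 2 real batches (mask), batch size 4, 10 outputs in total.
nums, mask, bs = 10, [True, True, False], 4

n_t = sum(mask)
n_f = len(mask) - n_t          # 1
p1 = nums - (n_t - 1) * bs     # 10 - 4 = 6

# Under Python 3, the old `p1 / (n_f + 1)` returns a float (3.0) and would silently
# yield a non-integer if p1 were not divisible; the fixed code asserts divisibility
# and keeps the result an int with floor division.
assert p1 % (n_f + 1) == 0
each_f = p1 // (n_f + 1)       # 3
print(each_f * n_f)            # 3 fake entries accounted for
```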
......@@ -94,14 +94,17 @@ class Classify(Head):
def epoch_postprocess(self, post_inputs, output_dir=None):
    # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
    if not self._is_training:
        results = []
        for i in range(len(self._preds)):
            label = int(np.argmax(np.array(self._preds[i])))
            result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]}
            results.append(result)
        if output_dir is not None:
            with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
                for result in results:
                    result = json.dumps(result)
                    writer.write(result+'\n')
            print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
        return results
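With this change, `epoch_postprocess` always returns the list of result dicts and only writes `predictions.json` (one JSON object per line) when an `output_dir` is supplied. A minimal sketch for consuming that file, assuming the `./outputs/predict/` directory used in the example script:

```python
import json

# Parse the JSON-lines predictions written by the Classify head
# (the path reuses ./outputs/predict/ from the example script; adjust to your output_dir).
with open('./outputs/predict/predictions.json') as f:
    predictions = [json.loads(line) for line in f]

for p in predictions[:3]:
    # each record carries: index, predicted label, raw logits and class probabilities
    print(p['index'], p['label'], max(p['probs']))
```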
......@@ -174,15 +174,18 @@ class Match(Head):
def epoch_postprocess(self, post_inputs, output_dir=None):
    # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
    if not self._is_training:
        results = []
        for i in range(len(self._preds)):
            if self._learning_strategy == 'pointwise':
                label = int(np.argmax(np.array(self._preds[i])))
                result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]}
            elif self._learning_strategy == 'pairwise':
                result = {'index': i, 'probs': self._preds[i][0]}
            results.append(result)
        if output_dir is not None:
            with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
                for result in results:
                    result = json.dumps(result, ensure_ascii=False)
                    writer.write(result+'\n')
            print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
        return results
......@@ -128,13 +128,15 @@ class MaskLM(Head):
def epoch_postprocess(self, post_inputs, output_dir=None):
    # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
    if not self._is_training:
        results = []
        for i in range(len(self._preds)):
            result = {'index': i, 'word_id': self._preds[i]}
            results.append(result)
        if output_dir is not None:
            with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
                for result in results:
                    result = json.dumps(result)
                    writer.write(result+'\n')
            print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
        return results
......@@ -154,21 +154,21 @@ class MRC(Head):
"""(optional interface) this func will be called after evaluation/predicting process and each epoch during training process."""
if not self._is_training:
if output_dir is None:
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.')
examples = post_inputs['reader']['examples']
features = post_inputs['reader']['features']
if not os.path.exists(output_dir):
os.makedirs(output_dir)
output_prediction_file = os.path.join(output_dir, "predictions.json")
output_nbest_file = os.path.join(output_dir, "nbest_predictions.json")
output_null_log_odds_file = os.path.join(output_dir, "null_odds.json")
_write_predictions(examples, features, self._pred_results,
self._n_best_size, self._max_answer_length,
self._do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file,
self._with_negative,
self._null_score_diff_threshold, self._verbose)
if output_dir is not None:
examples = post_inputs['reader']['examples']
features = post_inputs['reader']['features']
if not os.path.exists(output_dir):
os.makedirs(output_dir)
output_prediction_file = os.path.join(output_dir, "predictions.json")
output_nbest_file = os.path.join(output_dir, "nbest_predictions.json")
output_null_log_odds_file = os.path.join(output_dir, "null_odds.json")
_write_predictions(examples, features, self._pred_results,
self._n_best_size, self._max_answer_length,
self._do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file,
self._with_negative,
self._null_score_diff_threshold, self._verbose)
return self._pred_results
def _write_predictions(all_examples, all_features, all_results, n_best_size,
......
......@@ -118,9 +118,9 @@ class SequenceLabel(Head):
def epoch_postprocess(self, post_inputs, output_dir=None):
    # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
    if not self._is_training:
        if output_dir is not None:
            with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
                for p in self._preds:
                    writer.write(str(p)+'\n')
            print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
        return self._preds