Unverified commit 11ee20fb, authored by LiuChiachi, committed by GitHub

Add couplet examples (#5007)

* add couplet

* simplify model code

* simplify code

* update couplet README

* add pad_token to TranslationDataset, update CoupletDataset

* update couplet url, add couplet generation example

* update TranslationDataset

* upadte classname to self in __init__

* update README.md
Parent 9becf7bb
# Automatic Couplet Generation with a Seq2Seq Model
The following is a brief overview of this example's directory structure:
```
.
├── README.md # Documentation (this file)
├── args.py # Training, prediction, and model configuration arguments
├── data.py # Data loading
├── train.py # Main training script
├── predict.py # Main prediction script
└── model.py # Attention-based couplet generation model
```
## Introduction
Sequence to Sequence (Seq2Seq) models use an encoder-decoder architecture: the encoder encodes the source sequence into a vector, and the decoder decodes that vector into the target sequence. Seq2Seq is widely used in machine translation, dialogue systems, automatic document summarization, image captioning, and other tasks.
This directory contains a classic Seq2Seq example: automatic couplet generation with an attention-based text generation model.
Running this example requires PaddlePaddle 2.0-rc0. If your installed version of PaddlePaddle is older than this, please follow the [installation guide](https://www.paddlepaddle.org.cn/#quick-start) to upgrade.
## Model Overview
In this model, the encoder is a multi-layer LSTM-based RNN encoder and the decoder is an RNN decoder with an attention mechanism; at prediction time, beam search is used to generate the second line of the couplet.
## Data
This example uses the [couplet dataset](https://bj.bcebos.com/paddlehub-dataset/couplet.tar.gz) as its corpus: train.tsv is the training set, dev.tsv the development set, and test.tsv the test set.
The dataset is downloaded automatically when `CoupletDataset` is initialized.
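As a quick check that the download works, here is a minimal sketch (not part of the example scripts, assuming this directory is on the Python path) that loads one training pair through the `CoupletDataset` defined in `data.py`:

```python
# Minimal sketch: load the couplet training data via this example's CoupletDataset.
# The archive is downloaded automatically on first use.
from data import CoupletDataset

train_ds = CoupletDataset(mode='train')
vocab, _ = CoupletDataset.get_vocab()

src_ids, tgt_ids = train_ds.data[0]   # token ids, already wrapped in <s> ... </s>
print(len(train_ds.data), len(vocab))
print(vocab.to_tokens(src_ids), vocab.to_tokens(tgt_ids))
```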
## Training
Run the following command to train the attention-based Seq2Seq model:
```sh
python train.py \
--num_layers 2 \
--hidden_size 512 \
--batch_size 128 \
--use_gpu True \
--model_path ./couplet_models \
--max_epoch 20
```
See `args.py` for a description of each argument. The training program saves a checkpoint at the end of every epoch.
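With the default `--model_path ./couplet_models`, the checkpoint directory should look roughly like the sketch below (one pair of parameter/optimizer files per epoch plus a final one); the epoch-19 files are what the prediction command in the next section loads:

```
couplet_models/
├── 0.pdparams
├── 0.pdopt
├── ...
├── 19.pdparams
├── 19.pdopt
├── final.pdparams
└── final.pdopt
```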
## Prediction
After training, a saved checkpoint (specified via `--init_from_ckpt`) can be used to run beam-search decoding on the test set:
```sh
python predict.py \
--num_layers 2 \
--hidden_size 512 \
--batch_size 128 \
--init_from_ckpt couplet_models/19 \
--infer_output_file infer_output.txt \
--beam_size 10 \
--use_gpu True
```
See `args.py` for a description of each argument. Note that the model hyperparameters used at prediction time must match those used during training.
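`predict.py` writes each decoded result as one line of `--infer_output_file`, with tokens joined by the `"\x02"` separator used in the raw data. A small sketch (assuming the `infer_output.txt` produced by the command above) restores readable text:

```python
# Sketch: join the "\x02"-separated tokens written by predict.py back into plain text.
with open('infer_output.txt', encoding='utf-8') as f:
    for line in f:
        print(''.join(line.rstrip('\n').split('\x02')))
```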
## Sample Generated Couplets
崖悬风雨骤 月落水云寒
约春章柳下 邀月醉花间
箬笠红尘外 扁舟明月中
书香醉倒窗前月 烛影摇红梦里人
踏雪寻梅求雅趣 临风把酒觅知音
未出南阳天下论 先登北斗汉中书
朱联妙语千秋颂 赤胆忠心万代传
月半举杯圆月下 花间对酒醉花间
挥笔如剑倚麓山豪气干云揽月去 落笔似龙飞沧海龙吟破浪乘风来
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--use_gpu',
type=eval,
default=False,
        help='Whether to use gpu [True|False]')
parser.add_argument(
"--learning_rate",
type=float,
default=0.001,
help="learning rate for optimizer")
parser.add_argument(
"--num_layers",
type=int,
default=1,
help="layers number of encoder and decoder")
parser.add_argument(
"--hidden_size",
type=int,
default=100,
help="hidden size of encoder and decoder")
parser.add_argument(
"--batch_size", type=int, default=128, help="Batch size of each step")
parser.add_argument(
"--max_epoch", type=int, default=50, help="max epoch for the training")
parser.add_argument(
"--max_len",
type=int,
default=50,
help="max length for source and target sentence")
parser.add_argument(
"--max_grad_norm",
type=float,
default=5.0,
help="max grad norm for global norm clip")
parser.add_argument(
"--log_freq",
type=int,
default=200,
help="The frequency to print training logs")
parser.add_argument(
"--model_path",
type=str,
default='model',
help="model path for model to save")
parser.add_argument(
"--init_from_ckpt",
type=str,
default=None,
help="The path of checkpoint to be loaded.")
parser.add_argument(
"--infer_file", type=str, help="file name for inference")
parser.add_argument(
"--infer_output_file",
type=str,
default='infer_output',
help="file name for inference output")
parser.add_argument(
"--beam_size", type=int, default=10, help="file name for inference")
args = parser.parse_args()
return args
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from functools import partial
import numpy as np
import paddle
from paddle.utils.download import get_path_from_url
from paddlenlp.data import Vocab, Pad
from paddlenlp.data import SamplerHelper
from paddlenlp.utils.env import DATA_HOME
from paddlenlp.datasets import TranslationDataset
def create_train_loader(batch_size=128):
train_ds = CoupletDataset.get_datasets(["train"])
vocab, _ = CoupletDataset.get_vocab()
pad_id = vocab[CoupletDataset.EOS_TOKEN]
train_batch_sampler = SamplerHelper(train_ds).shuffle().batch(
batch_size=batch_size).shard()
train_loader = paddle.io.DataLoader(
train_ds,
batch_sampler=train_batch_sampler,
collate_fn=partial(
prepare_input, pad_id=pad_id))
return train_loader, len(vocab), pad_id
def create_infer_loader(batch_size=128):
test_ds = CoupletDataset.get_datasets(["test"])
vocab, _ = CoupletDataset.get_vocab()
pad_id = vocab[CoupletDataset.EOS_TOKEN]
bos_id = vocab[CoupletDataset.BOS_TOKEN]
eos_id = vocab[CoupletDataset.EOS_TOKEN]
test_batch_sampler = SamplerHelper(test_ds).batch(
batch_size=batch_size).shard()
test_loader = paddle.io.DataLoader(
test_ds,
batch_sampler=test_batch_sampler,
collate_fn=partial(
prepare_input, pad_id=pad_id))
return test_loader, len(vocab), pad_id, bos_id, eos_id
def prepare_input(insts, pad_id):
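    # Pad source and target ids to the max length in the batch. The target is
    # split into the decoder input (all tokens but the last) and the label (all
    # tokens but the first, with a trailing axis added), and a 0/1 mask over
    # non-pad label positions is returned for masking the loss.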
src, src_length = Pad(pad_val=pad_id, ret_length=True)(
[inst[0] for inst in insts])
tgt, tgt_length = Pad(pad_val=pad_id, ret_length=True)(
[inst[1] for inst in insts])
tgt_mask = (tgt[:, :-1] != pad_id).astype(paddle.get_default_dtype())
return src, src_length, tgt[:, :-1], tgt[:, 1:, np.newaxis], tgt_mask
class CoupletDataset(TranslationDataset):
URL = "https://paddlenlp.bj.bcebos.com/datasets/couplet.tar.gz"
SPLITS = {
'train': TranslationDataset.META_INFO(
os.path.join("couplet", "train_src.tsv"),
os.path.join("couplet", "train_tgt.tsv"),
"ad137385ad5e264ac4a54fe8c95d1583",
"daf4dd79dbf26040696eee0d645ef5ad"),
'dev': TranslationDataset.META_INFO(
os.path.join("couplet", "dev_src.tsv"),
os.path.join("couplet", "dev_tgt.tsv"),
"65bf9e72fa8fdf0482751c1fd6b6833c",
"3bc3b300b19d170923edfa8491352951"),
'test': TranslationDataset.META_INFO(
os.path.join("couplet", "test_src.tsv"),
os.path.join("couplet", "test_tgt.tsv"),
"f0a7366dfa0acac884b9f4901aac2cc1",
"56664bff3f2edfd7a751a55a689f90c2")
}
VOCAB_INFO = (os.path.join("couplet", "vocab.txt"), os.path.join(
"couplet", "vocab.txt"), "0bea1445c7c7fb659b856bb07e54a604",
"0bea1445c7c7fb659b856bb07e54a604")
UNK_TOKEN = '<unk>'
BOS_TOKEN = '<s>'
EOS_TOKEN = '</s>'
MD5 = '5c0dcde8eec6a517492227041c2e2d54'
def __init__(self, mode='train', root='./'):
data_select = ('train', 'dev', 'test')
if mode not in data_select:
raise TypeError(
'`train`, `dev` or `test` is supported but `{}` is passed in'.
format(mode))
# Download data
root = self.get_data(root=root)
self.data = self.read_raw_data(root, mode)
self.vocab, _ = self.get_vocab(root)
self.transform()
def transform(self):
eos_id = self.vocab[self.EOS_TOKEN]
bos_id = self.vocab[self.BOS_TOKEN]
self.data = [(
[bos_id] + self.vocab.to_indices(data[0].split("\x02")) + [eos_id],
[bos_id] + self.vocab.to_indices(data[1].split("\x02")) + [eos_id])
for data in self.data]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class CrossEntropyCriterion(nn.Layer):
def __init__(self):
super(CrossEntropyCriterion, self).__init__()
def forward(self, predict, label, trg_mask):
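        # Token-level cross entropy over the vocabulary. Padding positions are
        # zeroed out via trg_mask, then the cost is averaged over the batch
        # dimension and summed over time steps.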
cost = F.softmax_with_cross_entropy(
logits=predict, label=label, soft_label=False)
cost = paddle.squeeze(cost, axis=[2])
masked_cost = cost * trg_mask
batch_mean_cost = paddle.mean(masked_cost, axis=[0])
seq_cost = paddle.sum(batch_mean_cost)
return seq_cost
class Seq2SeqEncoder(nn.Layer):
def __init__(self, vocab_size, embed_dim, hidden_size, num_layers):
super(Seq2SeqEncoder, self).__init__()
self.embedder = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(
input_size=embed_dim,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=0.2 if num_layers > 1 else 0.)
def forward(self, sequence, sequence_length):
inputs = self.embedder(sequence)
encoder_output, encoder_state = self.lstm(
inputs, sequence_length=sequence_length)
return encoder_output, encoder_state
class AttentionLayer(nn.Layer):
def __init__(self, hidden_size):
super(AttentionLayer, self).__init__()
self.input_proj = nn.Linear(hidden_size, hidden_size)
self.output_proj = nn.Linear(hidden_size + hidden_size, hidden_size)
def forward(self, hidden, encoder_output, encoder_padding_mask):
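        # Luong-style attention: project the encoder outputs, score them against
        # the current decoder hidden state with a dot product, mask padding
        # positions, normalize with softmax, take the weighted sum, and project
        # the concatenation of that context vector and the hidden state.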
encoder_output = self.input_proj(encoder_output)
attn_scores = paddle.matmul(
paddle.unsqueeze(hidden, [1]), encoder_output, transpose_y=True)
if encoder_padding_mask is not None:
attn_scores = paddle.add(attn_scores, encoder_padding_mask)
attn_scores = F.softmax(attn_scores)
attn_out = paddle.squeeze(
paddle.matmul(attn_scores, encoder_output), [1])
attn_out = paddle.concat([attn_out, hidden], 1)
attn_out = self.output_proj(attn_out)
return attn_out
class Seq2SeqDecoderCell(nn.RNNCellBase):
def __init__(self, num_layers, input_size, hidden_size):
super(Seq2SeqDecoderCell, self).__init__()
self.dropout = nn.Dropout(0.2)
self.lstm_cells = nn.LayerList([
nn.LSTMCell(
input_size=input_size + hidden_size if i == 0 else hidden_size,
hidden_size=hidden_size) for i in range(num_layers)
])
self.attention_layer = AttentionLayer(hidden_size)
def forward(self,
step_input,
states,
encoder_output,
encoder_padding_mask=None):
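        # Input feeding: concatenate the attentional output of the previous step
        # (input_feed) with the current step input, run the stacked LSTM cells
        # with dropout, then apply attention; the attentional output is both the
        # cell output and part of the new state.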
lstm_states, input_feed = states
new_lstm_states = []
step_input = paddle.concat([step_input, input_feed], 1)
for i, lstm_cell in enumerate(self.lstm_cells):
out, new_lstm_state = lstm_cell(step_input, lstm_states[i])
step_input = self.dropout(out)
new_lstm_states.append(new_lstm_state)
out = self.attention_layer(step_input, encoder_output,
encoder_padding_mask)
return out, [new_lstm_states, out]
class Seq2SeqDecoder(nn.Layer):
def __init__(self, vocab_size, embed_dim, hidden_size, num_layers):
super(Seq2SeqDecoder, self).__init__()
self.embedder = nn.Embedding(vocab_size, embed_dim)
self.lstm_attention = nn.RNN(
Seq2SeqDecoderCell(num_layers, embed_dim, hidden_size))
self.output_layer = nn.Linear(hidden_size, vocab_size)
def forward(self, trg, decoder_initial_states, encoder_output,
encoder_padding_mask):
inputs = self.embedder(trg)
decoder_output, _ = self.lstm_attention(
inputs,
initial_states=decoder_initial_states,
encoder_output=encoder_output,
encoder_padding_mask=encoder_padding_mask)
predict = self.output_layer(decoder_output)
return predict
class Seq2SeqAttnModel(nn.Layer):
def __init__(self, vocab_size, embed_dim, hidden_size, num_layers,
eos_id=1):
super(Seq2SeqAttnModel, self).__init__()
self.hidden_size = hidden_size
self.eos_id = eos_id
self.num_layers = num_layers
self.INF = 1e9
self.encoder = Seq2SeqEncoder(vocab_size, embed_dim, hidden_size,
num_layers)
self.decoder = Seq2SeqDecoder(vocab_size, embed_dim, hidden_size,
num_layers)
def forward(self, src, src_length, trg):
encoder_output, encoder_final_state = self.encoder(src, src_length)
        # Rearrange encoder final states to [num_layers, 2, batch_size, hidden_size],
        # i.e. a list of (h, c) pairs, one per layer
encoder_final_states = [
(encoder_final_state[0][i], encoder_final_state[1][i])
for i in range(self.num_layers)
]
# Construct decoder initial states: use input_feed and the shape is
# [[h,c] * num_layers, input_feed], consistent with Seq2SeqDecoderCell.states
decoder_initial_states = [
encoder_final_states,
self.decoder.lstm_attention.cell.get_initial_states(
batch_ref=encoder_output, shape=[self.hidden_size])
]
        # Build attention mask to avoid paying attention to paddings
src_mask = (src != self.eos_id).astype(paddle.get_default_dtype())
encoder_padding_mask = (src_mask - 1.0) * self.INF
encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1])
predict = self.decoder(trg, decoder_initial_states, encoder_output,
encoder_padding_mask)
return predict
class Seq2SeqAttnInferModel(Seq2SeqAttnModel):
def __init__(self,
vocab_size,
embed_dim,
hidden_size,
num_layers,
bos_id=0,
eos_id=1,
beam_size=4,
max_out_len=256):
self.bos_id = bos_id
self.beam_size = beam_size
self.max_out_len = max_out_len
self.num_layers = num_layers
super(Seq2SeqAttnInferModel, self).__init__(
vocab_size, embed_dim, hidden_size, num_layers, eos_id)
# Dynamic decoder for inference
self.beam_search_decoder = nn.BeamSearchDecoder(
self.decoder.lstm_attention.cell,
start_token=bos_id,
end_token=eos_id,
beam_size=beam_size,
embedding_fn=self.decoder.embedder,
output_fn=self.decoder.output_layer)
def forward(self, src, src_length):
encoder_output, encoder_final_state = self.encoder(src, src_length)
encoder_final_state = [
(encoder_final_state[0][i], encoder_final_state[1][i])
for i in range(self.num_layers)
]
        # Construct decoder initial states, as in Seq2SeqAttnModel.forward
decoder_initial_states = [
encoder_final_state,
self.decoder.lstm_attention.cell.get_initial_states(
batch_ref=encoder_output, shape=[self.hidden_size])
]
        # Build attention mask to avoid paying attention to paddings
src_mask = (src != self.eos_id).astype(paddle.get_default_dtype())
encoder_padding_mask = (src_mask - 1.0) * self.INF
encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1])
# Tile the batch dimension with beam_size
encoder_output = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
encoder_output, self.beam_size)
encoder_padding_mask = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
encoder_padding_mask, self.beam_size)
# Dynamic decoding with beam search
seq_output, _ = nn.dynamic_decode(
decoder=self.beam_search_decoder,
inits=decoder_initial_states,
max_step_num=self.max_out_len,
encoder_output=encoder_output,
encoder_padding_mask=encoder_padding_mask)
return seq_output
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from args import parse_args
from data import create_infer_loader, CoupletDataset
from model import Seq2SeqAttnInferModel
import numpy as np
import paddle
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
"""
Post-process the decoded sequence.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [
idx for idx in seq[:eos_pos + 1]
if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
]
return seq
def do_predict(args):
device = paddle.set_device("gpu" if args.use_gpu else "cpu")
test_loader, vocab_size, pad_id, bos_id, eos_id = create_infer_loader(
args.batch_size)
vocab, _ = CoupletDataset.get_vocab()
trg_idx2word = vocab._idx_to_token
model = paddle.Model(
Seq2SeqAttnInferModel(
vocab_size,
args.hidden_size,
args.hidden_size,
args.num_layers,
bos_id=bos_id,
eos_id=eos_id,
beam_size=args.beam_size,
max_out_len=256))
model.prepare()
# Load the trained model
    assert args.init_from_ckpt, (
        "Please set init_from_ckpt to load the trained model.")
model.load(args.init_from_ckpt)
# TODO(guosheng): use model.predict when support variant length
with io.open(args.infer_output_file, 'w', encoding='utf-8') as f:
for data in test_loader():
inputs = data[:2]
finished_seq = model.predict_batch(inputs=list(inputs))[0]
finished_seq = finished_seq[:, :, np.newaxis] if len(
finished_seq.shape) == 2 else finished_seq
finished_seq = np.transpose(finished_seq, [0, 2, 1])
for ins in finished_seq:
for beam_idx, beam in enumerate(ins):
id_list = post_process_seq(beam, bos_id, eos_id)
word_list = [trg_idx2word[id] for id in id_list]
sequence = "\x02".join(word_list) + "\n"
f.write(sequence)
break
if __name__ == "__main__":
args = parse_args()
do_predict(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from args import parse_args
from data import create_train_loader
from model import Seq2SeqAttnModel, CrossEntropyCriterion
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlenlp.metrics import Perplexity
def do_train(args):
device = paddle.set_device("gpu" if args.use_gpu else "cpu")
# Define dataloader
train_loader, vocab_size, pad_id = create_train_loader(args.batch_size)
model = paddle.Model(
Seq2SeqAttnModel(vocab_size, args.hidden_size, args.hidden_size,
args.num_layers, pad_id))
optimizer = paddle.optimizer.Adam(
learning_rate=args.learning_rate, parameters=model.parameters())
ppl_metric = Perplexity()
model.prepare(optimizer, CrossEntropyCriterion(), ppl_metric)
print(args)
model.fit(train_data=train_loader,
epochs=args.max_epoch,
eval_freq=1,
save_freq=1,
save_dir=args.model_path,
log_freq=args.log_freq,
callbacks=[paddle.callbacks.VisualDL('./log')])
if __name__ == "__main__":
args = parse_args()
do_train(args)
 import os
 import io
 import collections
+import warnings
 from functools import partial
 import numpy as np
@@ -15,15 +16,6 @@ from paddle.dataset.common import md5file
 __all__ = ['TranslationDataset', 'IWSLT15']
-def vocab_func(vocab, unk_token):
-    def func(tok_iter):
-        return [
-            vocab[tok] if tok in vocab else vocab[unk_token] for tok in tok_iter
-        ]
-    return func
 def sequential_transforms(*transforms):
     def func(txt_input):
         for transform in transforms:
@@ -57,6 +49,10 @@ class TranslationDataset(paddle.io.Dataset):
     URL = None
     MD5 = None
     VOCAB_INFO = None
+    UNK_TOKEN = None
+    BOS_TOKEN = None
+    EOS_TOKEN = None
+    PAD_TOKEN = None
     def __init__(self, data):
         self.data = data
@@ -72,6 +68,8 @@ class TranslationDataset(paddle.io.Dataset):
         """
         Download dataset if any data file doesn't exist.
         Args:
+            mode(str, optional): Data mode to download. It could be 'train',
+                'dev' or 'test'. Default: 'train'.
             root (str, optional): data directory to save dataset. If not
                 provided, dataset will be saved in
                 `/root/.paddlenlp/datasets/machine_translation`. Default: None.
@@ -83,7 +81,8 @@ class TranslationDataset(paddle.io.Dataset):
                 from paddlenlp.datasets import IWSLT15
                 data_path = IWSLT15.get_data()
         """
-        default_root = os.path.join(DATA_HOME, 'machine_translation')
+        default_root = os.path.join(DATA_HOME, 'machine_translation',
+                                    cls.__name__)
         src_filename, tgt_filename, src_data_hash, tgt_data_hash = cls.SPLITS[
             mode]
@@ -96,6 +95,7 @@ class TranslationDataset(paddle.io.Dataset):
                 filename) if root is None else os.path.join(
                     os.path.expanduser(root), filename)
             fullname_list.append(fullname)
+            # print(fullname)
         data_hash_list = [
             src_data_hash, tgt_data_hash, cls.VOCAB_INFO[2], cls.VOCAB_INFO[3]
@@ -107,13 +107,14 @@ class TranslationDataset(paddle.io.Dataset):
                 if root is not None:  # not specified, and no need to warn
                     warnings.warn(
                         'md5 check failed for {}, download {} data to {}'.
-                        format(filename, self.__class__.__name__, default_root))
-                path = get_path_from_url(cls.URL, default_root, cls.MD5)
+                        format(filename, cls.__name__, default_root))
+                path = get_path_from_url(cls.URL, root, cls.MD5)
                 break
         return root if root is not None else default_root
     @classmethod
-    def build_vocab(cls, root=None):
+    def get_vocab(cls, root=None):
         """
         Load vocab from vocab files. It vocab files don't exist, the will
         be downloaded.
@@ -128,24 +129,42 @@ class TranslationDataset(paddle.io.Dataset):
         Examples:
             .. code-block:: python
                 from paddlenlp.datasets import IWSLT15
-                (src_vocab, tgt_vocab) = IWSLT15.build_vocab()
+                (src_vocab, tgt_vocab) = IWSLT15.get_vocab()
         """
         root = cls.get_data(root=root)
-        # Get vocab_func
         src_vocab_filename, tgt_vocab_filename, _, _ = cls.VOCAB_INFO
         src_file_path = os.path.join(root, src_vocab_filename)
         tgt_file_path = os.path.join(root, tgt_vocab_filename)
-        src_vocab = Vocab.load_vocabulary(src_file_path, cls.UNK_TOKEN,
-                                          cls.BOS_TOKEN, cls.EOS_TOKEN)
-        tgt_vocab = Vocab.load_vocabulary(tgt_file_path, cls.UNK_TOKEN,
-                                          cls.BOS_TOKEN, cls.EOS_TOKEN)
+        src_vocab = Vocab.load_vocabulary(
+            src_file_path,
+            unk_token=cls.UNK_TOKEN,
+            pad_token=cls.PAD_TOKEN,
+            bos_token=cls.BOS_TOKEN,
+            eos_token=cls.EOS_TOKEN)
+        tgt_vocab = Vocab.load_vocabulary(
+            tgt_file_path,
+            unk_token=cls.UNK_TOKEN,
+            pad_token=cls.PAD_TOKEN,
+            bos_token=cls.BOS_TOKEN,
+            eos_token=cls.EOS_TOKEN)
         return (src_vocab, tgt_vocab)
-    def read_raw_data(self, data_dir, mode):
-        src_filename, tgt_filename, _, _ = self.SPLITS[mode]
+    @classmethod
+    def read_raw_data(cls, root, mode):
+        """Read raw data from data files
+        Args:
+            root(str): Data directory of dataset.
+            mode(str): Indicates the mode to read. It could be 'train', 'dev' or
+                'test'.
+        Returns:
+            list: Raw data list.
+        """
+        # print(root)
+        src_filename, tgt_filename, _, _ = cls.SPLITS[mode]
         def read_raw_files(corpus_path):
             """Read raw files, return raw data"""
@@ -156,8 +175,9 @@ class TranslationDataset(paddle.io.Dataset):
                     data.append(line.strip())
             return data
-        src_path = os.path.join(data_dir, src_filename)
-        tgt_path = os.path.join(data_dir, tgt_filename)
+        src_path = os.path.join(root, src_filename)
+        tgt_path = os.path.join(root, tgt_filename)
+        print(src_path, tgt_path)
         src_data = read_raw_files(src_path)
         tgt_data = read_raw_files(tgt_path)
@@ -182,11 +202,11 @@ class TranslationDataset(paddle.io.Dataset):
         src_text_vocab_transform = sequential_transforms(src_tokenizer)
         tgt_text_vocab_transform = sequential_transforms(tgt_tokenizer)
-        (src_vocab, tgt_vocab) = cls.build_vocab(root)
-        src_text_transform = sequential_transforms(
-            src_text_vocab_transform, vocab_func(src_vocab, cls.UNK_TOKEN))
-        tgt_text_transform = sequential_transforms(
-            tgt_text_vocab_transform, vocab_func(tgt_vocab, cls.UNK_TOKEN))
+        (src_vocab, tgt_vocab) = cls.get_vocab(root)
+        src_text_transform = sequential_transforms(src_text_vocab_transform,
+                                                   src_vocab)
+        tgt_text_transform = sequential_transforms(tgt_text_vocab_transform,
+                                                   tgt_vocab)
         return (src_text_transform, tgt_text_transform)
@@ -195,10 +215,11 @@ class IWSLT15(TranslationDataset):
     IWSLT15 Vietnames to English translation dataset.
     Args:
-        data(list|optional): Raw data. It is a list of tuple, each tuple
-            consists of source and target data. Default: None.
-        vocab(tuple|optional): Tuple of Vocab object or dict. It consists of
-            source and target language vocab. Default: None.
+        mode(str, optional): It could be 'train', 'dev' or 'test'. Default: 'train'.
+        root(str, optional): If None, dataset will be downloaded in
+            `/root/.paddlenlp/datasets/machine_translation`. Default: None.
+        transform_func(callable, optional): If not None, it transforms raw data
+            to index data. Default: None.
     Examples:
         .. code-block:: python
             from paddlenlp.datasets import IWSLT15
@@ -243,7 +264,7 @@ class IWSLT15(TranslationDataset):
             raise ValueError("`transform_func` must have length of two for"
                              "source and target.")
         # Download data
-        root = IWSLT15.get_data(root=root)
+        root = self.get_data(root=root)
         self.data = self.read_raw_data(root, mode)
         if transform_func is not None:
...