Unverified commit 11ee20fb authored by LiuChiachi, committed by GitHub

Add couplet examples (#5007)

* add couplet

* simplify model code

* simplify code

* update couplet README

* add pad_token to TranslationDataset, update CoupletDataset

* update couplet url, add couplet generation example

* update TranslationDataset

* update classname to self in __init__

* update README.md
Parent 9becf7bb
# Automatic Couplet Generation with a Seq2Seq Model
\ No newline at end of file
# Automatic Couplet Generation with a Seq2Seq Model
The directory structure of this example and a brief description of each file:
```
.
├── README.md         # Documentation (this file)
├── args.py           # Argument configuration for training, prediction and the model
├── data.py           # Data loading
├── train.py          # Main training script
├── predict.py        # Main prediction script
└── model.py          # Attention-based couplet generation model
```
## Introduction
Sequence to Sequence (Seq2Seq) models use an encoder-decoder architecture: the encoder encodes the source sequence into a vector, and the decoder then decodes that vector into the target sequence. Seq2Seq is widely used in tasks such as machine translation, dialogue systems, automatic document summarization, and image captioning.
This directory contains a classic Seq2Seq example: automatic couplet generation with an attention-based text generation model.
Running the example in this directory requires PaddlePaddle 2.0-rc0. If your installed PaddlePaddle version is lower than this, please follow the instructions in the [installation guide](https://www.paddlepaddle.org.cn/#quick-start) to upgrade.
## Model Overview
The encoder is a multi-layer LSTM-based RNN encoder, and the decoder is an RNN decoder with an attention mechanism. At prediction time, beam search is used to generate the second line of the couplet.
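For orientation, here is a minimal, hypothetical sketch of how the training and inference networks defined in `model.py` are put together; the vocabulary size is a placeholder, and the other values simply mirror the training command below and the defaults in `model.py`:
```python
# Hypothetical sketch (not one of this example's scripts): assembling the
# networks from model.py. vocab_size is a placeholder; in train.py/predict.py
# it comes from CoupletDataset, and embed_dim is set equal to hidden_size.
from model import Seq2SeqAttnModel, Seq2SeqAttnInferModel

vocab_size, hidden_size, num_layers = 10000, 512, 2

# Training network: takes (src, src_length, trg) and returns per-step logits.
train_net = Seq2SeqAttnModel(
    vocab_size, hidden_size, hidden_size, num_layers, eos_id=1)

# Inference network: same encoder/decoder, plus beam-search decoding.
infer_net = Seq2SeqAttnInferModel(
    vocab_size, hidden_size, hidden_size, num_layers,
    bos_id=0, eos_id=1, beam_size=10, max_out_len=256)
```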
## Data
This example uses the [couplet dataset](https://bj.bcebos.com/paddlehub-dataset/couplet.tar.gz) as the training corpus: train.tsv is the training set, dev.tsv the development set, and test.tsv the test set.
The dataset is downloaded automatically when `CoupletDataset` is initialized.
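As a quick check, a small sketch (assuming it is run from this directory so that `data.py` is importable) of loading the dataset and its vocabulary:
```python
# Sketch: the tarball is fetched on first use; source and target share one vocab.
from data import CoupletDataset

train_ds = CoupletDataset(mode='train')   # triggers the download if data is missing
vocab, _ = CoupletDataset.get_vocab()     # returns (src_vocab, tgt_vocab), identical here
src_ids, tgt_ids = train_ds[0]            # token id lists wrapped with <s> ... </s>
print(len(train_ds), len(vocab))
```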
## Model Training
Run the following command to train the attention-based Seq2Seq model:
```sh
python train.py \
--num_layers 2 \
--hidden_size 512 \
--batch_size 128 \
--use_gpu True \
--model_path ./couplet_models \
--max_epoch 20
```
See `args.py` for a detailed description of each argument. The training program saves a checkpoint at the end of every epoch.
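With `--model_path ./couplet_models` and `--max_epoch 20` as above, and assuming the default saving behaviour of `paddle.Model.fit` with `save_freq=1`, the checkpoint directory should look roughly like this (the `19` checkpoint is the one used in the prediction command below):
```
couplet_models/
├── 0.pdparams
├── 0.pdopt
├── ...
├── 19.pdparams
├── 19.pdopt
└── final.pdparams
```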
## Model Prediction
After training, you can use a saved checkpoint (specified by `--init_from_ckpt`) to run beam-search decoding on the test set:
```sh
python predict.py \
--num_layers 2 \
--hidden_size 512 \
--batch_size 128 \
--init_from_ckpt couplet_models/19 \
--infer_output_file infer_output.txt \
--beam_size 10 \
--use_gpu True
```
See `args.py` for a detailed description of each argument. Note that the model hyperparameters used for prediction must match those used for training.
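`predict.py` writes one line per test example, keeping only the top beam candidate, with characters joined by the `\x02` separator used in the raw data. A small sketch for turning `infer_output.txt` into readable couplets:
```python
# Sketch: strip the '\x02' separators from the beam-search output file.
import io

with io.open('infer_output.txt', encoding='utf-8') as f:
    for line in f:
        print(''.join(line.strip().split('\x02')))
```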
## Sample Generated Couplets
崖悬风雨骤 月落水云寒
约春章柳下 邀月醉花间
箬笠红尘外 扁舟明月中
书香醉倒窗前月 烛影摇红梦里人
踏雪寻梅求雅趣 临风把酒觅知音
未出南阳天下论 先登北斗汉中书
朱联妙语千秋颂 赤胆忠心万代传
月半举杯圆月下 花间对酒醉花间
挥笔如剑倚麓山豪气干云揽月去 落笔似龙飞沧海龙吟破浪乘风来
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--use_gpu',
type=eval,
default=False,
        help='Whether to use GPU [True|False]')
parser.add_argument(
"--learning_rate",
type=float,
default=0.001,
help="learning rate for optimizer")
parser.add_argument(
"--num_layers",
type=int,
default=1,
help="layers number of encoder and decoder")
parser.add_argument(
"--hidden_size",
type=int,
default=100,
help="hidden size of encoder and decoder")
parser.add_argument(
"--batch_size", type=int, default=128, help="Batch size of each step")
parser.add_argument(
"--max_epoch", type=int, default=50, help="max epoch for the training")
parser.add_argument(
"--max_len",
type=int,
default=50,
help="max length for source and target sentence")
parser.add_argument(
"--max_grad_norm",
type=float,
default=5.0,
help="max grad norm for global norm clip")
parser.add_argument(
"--log_freq",
type=int,
default=200,
help="The frequency to print training logs")
parser.add_argument(
"--model_path",
type=str,
default='model',
        help="Directory in which to save the model")
parser.add_argument(
"--init_from_ckpt",
type=str,
default=None,
help="The path of checkpoint to be loaded.")
parser.add_argument(
"--infer_file", type=str, help="file name for inference")
parser.add_argument(
"--infer_output_file",
type=str,
default='infer_output',
help="file name for inference output")
    parser.add_argument(
        "--beam_size", type=int, default=10, help="Beam size for beam search")
args = parser.parse_args()
return args
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from functools import partial
import numpy as np
import paddle
from paddle.utils.download import get_path_from_url
from paddlenlp.data import Vocab, Pad
from paddlenlp.data import SamplerHelper
from paddlenlp.utils.env import DATA_HOME
from paddlenlp.datasets import TranslationDataset
def create_train_loader(batch_size=128):
train_ds = CoupletDataset.get_datasets(["train"])
vocab, _ = CoupletDataset.get_vocab()
pad_id = vocab[CoupletDataset.EOS_TOKEN]
train_batch_sampler = SamplerHelper(train_ds).shuffle().batch(
batch_size=batch_size).shard()
train_loader = paddle.io.DataLoader(
train_ds,
batch_sampler=train_batch_sampler,
collate_fn=partial(
prepare_input, pad_id=pad_id))
return train_loader, len(vocab), pad_id
def create_infer_loader(batch_size=128):
test_ds = CoupletDataset.get_datasets(["test"])
vocab, _ = CoupletDataset.get_vocab()
pad_id = vocab[CoupletDataset.EOS_TOKEN]
bos_id = vocab[CoupletDataset.BOS_TOKEN]
eos_id = vocab[CoupletDataset.EOS_TOKEN]
test_batch_sampler = SamplerHelper(test_ds).batch(
batch_size=batch_size).shard()
test_loader = paddle.io.DataLoader(
test_ds,
batch_sampler=test_batch_sampler,
collate_fn=partial(
prepare_input, pad_id=pad_id))
return test_loader, len(vocab), pad_id, bos_id, eos_id
def prepare_input(insts, pad_id):
src, src_length = Pad(pad_val=pad_id, ret_length=True)(
[inst[0] for inst in insts])
tgt, tgt_length = Pad(pad_val=pad_id, ret_length=True)(
[inst[1] for inst in insts])
tgt_mask = (tgt[:, :-1] != pad_id).astype(paddle.get_default_dtype())
return src, src_length, tgt[:, :-1], tgt[:, 1:, np.newaxis], tgt_mask
class CoupletDataset(TranslationDataset):
URL = "https://paddlenlp.bj.bcebos.com/datasets/couplet.tar.gz"
SPLITS = {
'train': TranslationDataset.META_INFO(
os.path.join("couplet", "train_src.tsv"),
os.path.join("couplet", "train_tgt.tsv"),
"ad137385ad5e264ac4a54fe8c95d1583",
"daf4dd79dbf26040696eee0d645ef5ad"),
'dev': TranslationDataset.META_INFO(
os.path.join("couplet", "dev_src.tsv"),
os.path.join("couplet", "dev_tgt.tsv"),
"65bf9e72fa8fdf0482751c1fd6b6833c",
"3bc3b300b19d170923edfa8491352951"),
'test': TranslationDataset.META_INFO(
os.path.join("couplet", "test_src.tsv"),
os.path.join("couplet", "test_tgt.tsv"),
"f0a7366dfa0acac884b9f4901aac2cc1",
"56664bff3f2edfd7a751a55a689f90c2")
}
VOCAB_INFO = (os.path.join("couplet", "vocab.txt"), os.path.join(
"couplet", "vocab.txt"), "0bea1445c7c7fb659b856bb07e54a604",
"0bea1445c7c7fb659b856bb07e54a604")
UNK_TOKEN = '<unk>'
BOS_TOKEN = '<s>'
EOS_TOKEN = '</s>'
MD5 = '5c0dcde8eec6a517492227041c2e2d54'
def __init__(self, mode='train', root='./'):
data_select = ('train', 'dev', 'test')
if mode not in data_select:
raise TypeError(
'`train`, `dev` or `test` is supported but `{}` is passed in'.
format(mode))
# Download data
root = self.get_data(root=root)
self.data = self.read_raw_data(root, mode)
self.vocab, _ = self.get_vocab(root)
self.transform()
def transform(self):
eos_id = self.vocab[self.EOS_TOKEN]
bos_id = self.vocab[self.BOS_TOKEN]
self.data = [(
[bos_id] + self.vocab.to_indices(data[0].split("\x02")) + [eos_id],
[bos_id] + self.vocab.to_indices(data[1].split("\x02")) + [eos_id])
for data in self.data]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class CrossEntropyCriterion(nn.Layer):
def __init__(self):
super(CrossEntropyCriterion, self).__init__()
def forward(self, predict, label, trg_mask):
cost = F.softmax_with_cross_entropy(
logits=predict, label=label, soft_label=False)
cost = paddle.squeeze(cost, axis=[2])
masked_cost = cost * trg_mask
batch_mean_cost = paddle.mean(masked_cost, axis=[0])
seq_cost = paddle.sum(batch_mean_cost)
return seq_cost
class Seq2SeqEncoder(nn.Layer):
def __init__(self, vocab_size, embed_dim, hidden_size, num_layers):
super(Seq2SeqEncoder, self).__init__()
self.embedder = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(
input_size=embed_dim,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=0.2 if num_layers > 1 else 0.)
def forward(self, sequence, sequence_length):
inputs = self.embedder(sequence)
encoder_output, encoder_state = self.lstm(
inputs, sequence_length=sequence_length)
return encoder_output, encoder_state
class AttentionLayer(nn.Layer):
def __init__(self, hidden_size):
super(AttentionLayer, self).__init__()
self.input_proj = nn.Linear(hidden_size, hidden_size)
self.output_proj = nn.Linear(hidden_size + hidden_size, hidden_size)
def forward(self, hidden, encoder_output, encoder_padding_mask):
encoder_output = self.input_proj(encoder_output)
attn_scores = paddle.matmul(
paddle.unsqueeze(hidden, [1]), encoder_output, transpose_y=True)
if encoder_padding_mask is not None:
attn_scores = paddle.add(attn_scores, encoder_padding_mask)
attn_scores = F.softmax(attn_scores)
attn_out = paddle.squeeze(
paddle.matmul(attn_scores, encoder_output), [1])
attn_out = paddle.concat([attn_out, hidden], 1)
attn_out = self.output_proj(attn_out)
return attn_out
class Seq2SeqDecoderCell(nn.RNNCellBase):
def __init__(self, num_layers, input_size, hidden_size):
super(Seq2SeqDecoderCell, self).__init__()
self.dropout = nn.Dropout(0.2)
self.lstm_cells = nn.LayerList([
nn.LSTMCell(
input_size=input_size + hidden_size if i == 0 else hidden_size,
hidden_size=hidden_size) for i in range(num_layers)
])
self.attention_layer = AttentionLayer(hidden_size)
def forward(self,
step_input,
states,
encoder_output,
encoder_padding_mask=None):
lstm_states, input_feed = states
new_lstm_states = []
step_input = paddle.concat([step_input, input_feed], 1)
for i, lstm_cell in enumerate(self.lstm_cells):
out, new_lstm_state = lstm_cell(step_input, lstm_states[i])
step_input = self.dropout(out)
new_lstm_states.append(new_lstm_state)
out = self.attention_layer(step_input, encoder_output,
encoder_padding_mask)
return out, [new_lstm_states, out]
class Seq2SeqDecoder(nn.Layer):
def __init__(self, vocab_size, embed_dim, hidden_size, num_layers):
super(Seq2SeqDecoder, self).__init__()
self.embedder = nn.Embedding(vocab_size, embed_dim)
self.lstm_attention = nn.RNN(
Seq2SeqDecoderCell(num_layers, embed_dim, hidden_size))
self.output_layer = nn.Linear(hidden_size, vocab_size)
def forward(self, trg, decoder_initial_states, encoder_output,
encoder_padding_mask):
inputs = self.embedder(trg)
decoder_output, _ = self.lstm_attention(
inputs,
initial_states=decoder_initial_states,
encoder_output=encoder_output,
encoder_padding_mask=encoder_padding_mask)
predict = self.output_layer(decoder_output)
return predict
class Seq2SeqAttnModel(nn.Layer):
def __init__(self, vocab_size, embed_dim, hidden_size, num_layers,
eos_id=1):
super(Seq2SeqAttnModel, self).__init__()
self.hidden_size = hidden_size
self.eos_id = eos_id
self.num_layers = num_layers
self.INF = 1e9
self.encoder = Seq2SeqEncoder(vocab_size, embed_dim, hidden_size,
num_layers)
self.decoder = Seq2SeqDecoder(vocab_size, embed_dim, hidden_size,
num_layers)
def forward(self, src, src_length, trg):
encoder_output, encoder_final_state = self.encoder(src, src_length)
        # Transform the shape of encoder_final_states to [num_layers, 2, batch_size, hidden_size]
encoder_final_states = [
(encoder_final_state[0][i], encoder_final_state[1][i])
for i in range(self.num_layers)
]
# Construct decoder initial states: use input_feed and the shape is
# [[h,c] * num_layers, input_feed], consistent with Seq2SeqDecoderCell.states
decoder_initial_states = [
encoder_final_states,
self.decoder.lstm_attention.cell.get_initial_states(
batch_ref=encoder_output, shape=[self.hidden_size])
]
        # Build attention mask to avoid paying attention to paddings
src_mask = (src != self.eos_id).astype(paddle.get_default_dtype())
encoder_padding_mask = (src_mask - 1.0) * self.INF
encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1])
predict = self.decoder(trg, decoder_initial_states, encoder_output,
encoder_padding_mask)
return predict
class Seq2SeqAttnInferModel(Seq2SeqAttnModel):
def __init__(self,
vocab_size,
embed_dim,
hidden_size,
num_layers,
bos_id=0,
eos_id=1,
beam_size=4,
max_out_len=256):
self.bos_id = bos_id
self.beam_size = beam_size
self.max_out_len = max_out_len
self.num_layers = num_layers
super(Seq2SeqAttnInferModel, self).__init__(
vocab_size, embed_dim, hidden_size, num_layers, eos_id)
# Dynamic decoder for inference
self.beam_search_decoder = nn.BeamSearchDecoder(
self.decoder.lstm_attention.cell,
start_token=bos_id,
end_token=eos_id,
beam_size=beam_size,
embedding_fn=self.decoder.embedder,
output_fn=self.decoder.output_layer)
def forward(self, src, src_length):
encoder_output, encoder_final_state = self.encoder(src, src_length)
encoder_final_state = [
(encoder_final_state[0][i], encoder_final_state[1][i])
for i in range(self.num_layers)
]
        # Construct the decoder's initial states
decoder_initial_states = [
encoder_final_state,
self.decoder.lstm_attention.cell.get_initial_states(
batch_ref=encoder_output, shape=[self.hidden_size])
]
        # Build attention mask to avoid paying attention to paddings
src_mask = (src != self.eos_id).astype(paddle.get_default_dtype())
encoder_padding_mask = (src_mask - 1.0) * self.INF
encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1])
# Tile the batch dimension with beam_size
encoder_output = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
encoder_output, self.beam_size)
encoder_padding_mask = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
encoder_padding_mask, self.beam_size)
# Dynamic decoding with beam search
seq_output, _ = nn.dynamic_decode(
decoder=self.beam_search_decoder,
inits=decoder_initial_states,
max_step_num=self.max_out_len,
encoder_output=encoder_output,
encoder_padding_mask=encoder_padding_mask)
return seq_output
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from args import parse_args
from data import create_infer_loader, CoupletDataset
from model import Seq2SeqAttnInferModel
import numpy as np
import paddle
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
"""
Post-process the decoded sequence.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [
idx for idx in seq[:eos_pos + 1]
if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
]
return seq
def do_predict(args):
device = paddle.set_device("gpu" if args.use_gpu else "cpu")
test_loader, vocab_size, pad_id, bos_id, eos_id = create_infer_loader(
args.batch_size)
vocab, _ = CoupletDataset.get_vocab()
trg_idx2word = vocab._idx_to_token
model = paddle.Model(
Seq2SeqAttnInferModel(
vocab_size,
args.hidden_size,
args.hidden_size,
args.num_layers,
bos_id=bos_id,
eos_id=eos_id,
beam_size=args.beam_size,
max_out_len=256))
model.prepare()
# Load the trained model
    assert args.init_from_ckpt, (
        "Please set init_from_ckpt to load the trained model.")
model.load(args.init_from_ckpt)
    # TODO(guosheng): use model.predict when variable-length output is supported
with io.open(args.infer_output_file, 'w', encoding='utf-8') as f:
for data in test_loader():
inputs = data[:2]
finished_seq = model.predict_batch(inputs=list(inputs))[0]
finished_seq = finished_seq[:, :, np.newaxis] if len(
finished_seq.shape) == 2 else finished_seq
finished_seq = np.transpose(finished_seq, [0, 2, 1])
for ins in finished_seq:
for beam_idx, beam in enumerate(ins):
id_list = post_process_seq(beam, bos_id, eos_id)
word_list = [trg_idx2word[id] for id in id_list]
sequence = "\x02".join(word_list) + "\n"
f.write(sequence)
break
if __name__ == "__main__":
args = parse_args()
do_predict(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from args import parse_args
from data import create_train_loader
from model import Seq2SeqAttnModel, CrossEntropyCriterion
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlenlp.metrics import Perplexity
def do_train(args):
device = paddle.set_device("gpu" if args.use_gpu else "cpu")
# Define dataloader
train_loader, vocab_size, pad_id = create_train_loader(args.batch_size)
model = paddle.Model(
Seq2SeqAttnModel(vocab_size, args.hidden_size, args.hidden_size,
args.num_layers, pad_id))
optimizer = paddle.optimizer.Adam(
learning_rate=args.learning_rate, parameters=model.parameters())
ppl_metric = Perplexity()
model.prepare(optimizer, CrossEntropyCriterion(), ppl_metric)
print(args)
model.fit(train_data=train_loader,
epochs=args.max_epoch,
eval_freq=1,
save_freq=1,
save_dir=args.model_path,
log_freq=args.log_freq,
callbacks=[paddle.callbacks.VisualDL('./log')])
if __name__ == "__main__":
args = parse_args()
do_train(args)
import os
import io
import collections
import warnings
from functools import partial
import numpy as np
@@ -15,15 +16,6 @@ from paddle.dataset.common import md5file
__all__ = ['TranslationDataset', 'IWSLT15']
def vocab_func(vocab, unk_token):
def func(tok_iter):
return [
vocab[tok] if tok in vocab else vocab[unk_token] for tok in tok_iter
]
return func
def sequential_transforms(*transforms):
def func(txt_input):
for transform in transforms:
@@ -57,6 +49,10 @@ class TranslationDataset(paddle.io.Dataset):
URL = None
MD5 = None
VOCAB_INFO = None
UNK_TOKEN = None
BOS_TOKEN = None
EOS_TOKEN = None
PAD_TOKEN = None
def __init__(self, data):
self.data = data
@@ -72,6 +68,8 @@ class TranslationDataset(paddle.io.Dataset):
"""
Download dataset if any data file doesn't exist.
Args:
mode(str, optional): Data mode to download. It could be 'train',
'dev' or 'test'. Default: 'train'.
root (str, optional): data directory to save dataset. If not
provided, dataset will be saved in
`/root/.paddlenlp/datasets/machine_translation`. Default: None.
@@ -83,7 +81,8 @@ class TranslationDataset(paddle.io.Dataset):
from paddlenlp.datasets import IWSLT15
data_path = IWSLT15.get_data()
"""
default_root = os.path.join(DATA_HOME, 'machine_translation')
default_root = os.path.join(DATA_HOME, 'machine_translation',
cls.__name__)
src_filename, tgt_filename, src_data_hash, tgt_data_hash = cls.SPLITS[
mode]
@@ -96,6 +95,7 @@ class TranslationDataset(paddle.io.Dataset):
filename) if root is None else os.path.join(
os.path.expanduser(root), filename)
fullname_list.append(fullname)
data_hash_list = [
src_data_hash, tgt_data_hash, cls.VOCAB_INFO[2], cls.VOCAB_INFO[3]
@@ -107,13 +107,14 @@ class TranslationDataset(paddle.io.Dataset):
if root is not None: # not specified, and no need to warn
warnings.warn(
'md5 check failed for {}, download {} data to {}'.
format(filename, self.__class__.__name__, default_root))
path = get_path_from_url(cls.URL, default_root, cls.MD5)
format(filename, cls.__name__, default_root))
path = get_path_from_url(cls.URL, root, cls.MD5)
break
return root if root is not None else default_root
@classmethod
def build_vocab(cls, root=None):
def get_vocab(cls, root=None):
"""
        Load vocab from the vocab files. If the vocab files don't exist, they will
        be downloaded.
@@ -128,24 +129,42 @@ class TranslationDataset(paddle.io.Dataset):
Examples:
.. code-block:: python
from paddlenlp.datasets import IWSLT15
(src_vocab, tgt_vocab) = IWSLT15.build_vocab()
(src_vocab, tgt_vocab) = IWSLT15.get_vocab()
"""
root = cls.get_data(root=root)
# Get vocab_func
src_vocab_filename, tgt_vocab_filename, _, _ = cls.VOCAB_INFO
src_file_path = os.path.join(root, src_vocab_filename)
tgt_file_path = os.path.join(root, tgt_vocab_filename)
src_vocab = Vocab.load_vocabulary(src_file_path, cls.UNK_TOKEN,
cls.BOS_TOKEN, cls.EOS_TOKEN)
tgt_vocab = Vocab.load_vocabulary(tgt_file_path, cls.UNK_TOKEN,
cls.BOS_TOKEN, cls.EOS_TOKEN)
src_vocab = Vocab.load_vocabulary(
src_file_path,
unk_token=cls.UNK_TOKEN,
pad_token=cls.PAD_TOKEN,
bos_token=cls.BOS_TOKEN,
eos_token=cls.EOS_TOKEN)
tgt_vocab = Vocab.load_vocabulary(
tgt_file_path,
unk_token=cls.UNK_TOKEN,
pad_token=cls.PAD_TOKEN,
bos_token=cls.BOS_TOKEN,
eos_token=cls.EOS_TOKEN)
return (src_vocab, tgt_vocab)
def read_raw_data(self, data_dir, mode):
src_filename, tgt_filename, _, _ = self.SPLITS[mode]
@classmethod
def read_raw_data(cls, root, mode):
"""Read raw data from data files
Args:
root(str): Data directory of dataset.
mode(str): Indicates the mode to read. It could be 'train', 'dev' or
'test'.
Returns:
list: Raw data list.
"""
# print(root)
src_filename, tgt_filename, _, _ = cls.SPLITS[mode]
def read_raw_files(corpus_path):
"""Read raw files, return raw data"""
@@ -156,8 +175,9 @@ class TranslationDataset(paddle.io.Dataset):
data.append(line.strip())
return data
src_path = os.path.join(data_dir, src_filename)
tgt_path = os.path.join(data_dir, tgt_filename)
src_path = os.path.join(root, src_filename)
tgt_path = os.path.join(root, tgt_filename)
src_data = read_raw_files(src_path)
tgt_data = read_raw_files(tgt_path)
@@ -182,11 +202,11 @@ class TranslationDataset(paddle.io.Dataset):
src_text_vocab_transform = sequential_transforms(src_tokenizer)
tgt_text_vocab_transform = sequential_transforms(tgt_tokenizer)
(src_vocab, tgt_vocab) = cls.build_vocab(root)
src_text_transform = sequential_transforms(
src_text_vocab_transform, vocab_func(src_vocab, cls.UNK_TOKEN))
tgt_text_transform = sequential_transforms(
tgt_text_vocab_transform, vocab_func(tgt_vocab, cls.UNK_TOKEN))
(src_vocab, tgt_vocab) = cls.get_vocab(root)
src_text_transform = sequential_transforms(src_text_vocab_transform,
src_vocab)
tgt_text_transform = sequential_transforms(tgt_text_vocab_transform,
tgt_vocab)
return (src_text_transform, tgt_text_transform)
@@ -195,10 +215,11 @@ class IWSLT15(TranslationDataset):
    IWSLT15 Vietnamese to English translation dataset.
Args:
data(list|optional): Raw data. It is a list of tuple, each tuple
consists of source and target data. Default: None.
vocab(tuple|optional): Tuple of Vocab object or dict. It consists of
source and target language vocab. Default: None.
mode(str, optional): It could be 'train', 'dev' or 'test'. Default: 'train'.
root(str, optional): If None, dataset will be downloaded in
`/root/.paddlenlp/datasets/machine_translation`. Default: None.
transform_func(callable, optional): If not None, it transforms raw data
to index data. Default: None.
Examples:
.. code-block:: python
from paddlenlp.datasets import IWSLT15
@@ -243,7 +264,7 @@ class IWSLT15(TranslationDataset):
raise ValueError("`transform_func` must have length of two for"
"source and target.")
# Download data
root = IWSLT15.get_data(root=root)
root = self.get_data(root=root)
self.data = self.read_raw_data(root, mode)
if transform_func is not None:
......