Unverified  Commit 30ccfc67  Authored by: liu zhengxi  Committed by: GitHub

[Transformer] Simplify transformer reader and fix TranslationDataset (#5035)

* fix translation dataset and simplify transformer reader
Parent: 8c9d8f56
@@ -10,35 +10,21 @@ init_from_pretrain_model: ""
 init_from_params: "./trained_models/step_final/"
 # The directory for saving model
 save_model: "trained_models"
-# The directory for saving inference model.
-inference_model_dir: "infer_model"
 # Set seed for CE or debug
 random_seed: None
-# The pattern to match training data files.
-training_file: "../gen_data/wmt14_ende_data_bpe/train.tok.clean.bpe.33708.en-de"
-# The pattern to match validation data files.
-validation_file: "../gen_data/wmt14_ende_data_bpe/newstest2013.tok.bpe.33708.en-de"
-# The pattern to match test data files.
-predict_file: "../gen_data/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en-de"
 # The file to output the translation results of predict_file to.
 output_file: "predict.txt"
-# The path of vocabulary file of source language.
-src_vocab_fpath: "../gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
-# The path of vocabulary file of target language.
-trg_vocab_fpath: "../gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
 # The <bos>, <eos> and <unk> tokens in the dictionary.
 special_token: ["<s>", "<e>", "<unk>"]
+# The directory to store data.
+root: None
 # Whether to use cuda
 use_gpu: True
 # Args for reader, see reader.py for details
-token_delimiter: " "
-use_token_batch: True
 pool_size: 200000
 sort_type: "global"
-shuffle: False
-shuffle_batch: False
 batch_size: 4096
 infer_batch_size: 16
......
@@ -52,8 +52,7 @@ def do_predict(args):
     paddle.set_device(place)

     # Define data loader
-    (test_loader,
-     test_steps_fn), trg_idx2word = reader.create_infer_loader(args)
+    test_loader, to_tokens = reader.create_infer_loader(args)

     # Define model
     transformer = InferTransformerModel(
@@ -90,6 +89,7 @@ def do_predict(args):
     transformer.eval()
     f = open(args.output_file, "w")
+    with paddle.no_grad():
         for (src_word, ) in test_loader:
             finished_seq = transformer(src_word=src_word)
             finished_seq = finished_seq.numpy().transpose([0, 2, 1])
@@ -98,7 +98,7 @@ def do_predict(args):
                 if beam_idx >= args.n_best:
                     break
                 id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
-                word_list = [trg_idx2word[id] for id in id_list]
+                word_list = to_tokens(id_list)
                 sequence = " ".join(word_list) + "\n"
                 f.write(sequence)
......
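For reference, a minimal sketch (not part of the diff) of the new inference wiring shown above. `args`, `transformer` and `post_process_seq` stand for the config namespace, model and helper already defined in predict.py; the beam handling of the first sample is illustrative only:

```python
import paddle
import reader  # the simplified reader.py from this change

# create_infer_loader now returns the DataLoader together with the target
# vocabulary's to_tokens callable, replacing the old trg_idx2word dict.
test_loader, to_tokens = reader.create_infer_loader(args)

with paddle.no_grad():
    for (src_word, ) in test_loader:
        finished_seq = transformer(src_word=src_word)
        # [batch, beam, step] after the transpose used in predict.py
        finished_seq = finished_seq.numpy().transpose([0, 2, 1])
        for beam in finished_seq[0][:args.n_best]:  # beams of the first sample
            id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
            print(" ".join(to_tokens(id_list)))
```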
@@ -51,9 +51,7 @@ def do_train(args):
         paddle.seed(random_seed)

     # Define data loader
-    (train_loader, train_steps_fn), (eval_loader,
-     eval_steps_fn) = reader.create_data_loader(
-        args, trainer_count, rank)
+    (train_loader), (eval_loader) = reader.create_data_loader(args)

     # Define model
     transformer = TransformerModel(
@@ -176,7 +174,6 @@ def do_train(args):
         if step_idx % args.save_step == 0 and step_idx != 0:
             # Validation
-            if args.validation_file:
             transformer.eval()
             total_sum_cost = 0
             total_token_num = 0
......
#! /usr/bin/env bash
set -e
OUTPUT_DIR=$PWD/gen_data
###############################################################################
# change these variables for other WMT data
###############################################################################
OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_ende_data"
OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_ende_data_bpe"
LANG1="en"
LANG2="de"
# each of TRAIN_DATA: data_url data_file_lang1 data_file_lang2
TRAIN_DATA=(
'http://statmt.org/wmt13/training-parallel-europarl-v7.tgz'
'europarl-v7.de-en.en' 'europarl-v7.de-en.de'
'http://statmt.org/wmt13/training-parallel-commoncrawl.tgz'
'commoncrawl.de-en.en' 'commoncrawl.de-en.de'
'http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz'
'news-commentary-v12.de-en.en' 'news-commentary-v12.de-en.de'
)
# each of DEV_TEST_DATA: data_url data_file_lang1 data_file_lang2
# source & reference
DEV_TEST_DATA=(
'http://data.statmt.org/wmt17/translation-task/dev.tgz'
'newstest2013-ref.de.sgm' 'newstest2013-src.en.sgm'
'http://statmt.org/wmt14/test-full.tgz'
'newstest2014-deen-ref.en.sgm' 'newstest2014-deen-src.de.sgm'
)
###############################################################################
###############################################################################
# change these variables for other WMT data
###############################################################################
# OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_enfr_data"
# OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_enfr_data_bpe"
# LANG1="en"
# LANG2="fr"
# # each of TRAIN_DATA: data_url data_tgz data_file
# TRAIN_DATA=(
# 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
# 'commoncrawl.fr-en.en' 'commoncrawl.fr-en.fr'
# 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz'
# 'training/europarl-v7.fr-en.en' 'training/europarl-v7.fr-en.fr'
# 'http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz'
# 'training/news-commentary-v9.fr-en.en' 'training/news-commentary-v9.fr-en.fr'
# 'http://www.statmt.org/wmt10/training-giga-fren.tar'
# 'giga-fren.release2.fixed.en.*' 'giga-fren.release2.fixed.fr.*'
# 'http://www.statmt.org/wmt13/training-parallel-un.tgz'
# 'un/undoc.2000.fr-en.en' 'un/undoc.2000.fr-en.fr'
# )
# # each of DEV_TEST_DATA: data_url data_tgz data_file_lang1 data_file_lang2
# DEV_TEST_DATA=(
# 'http://data.statmt.org/wmt16/translation-task/dev.tgz'
# '.*/newstest201[45]-fren-ref.en.sgm' '.*/newstest201[45]-fren-src.fr.sgm'
# 'http://data.statmt.org/wmt16/translation-task/test.tgz'
# '.*/newstest2016-fren-ref.en.sgm' '.*/newstest2016-fren-src.fr.sgm'
# )
###############################################################################
mkdir -p $OUTPUT_DIR_DATA $OUTPUT_DIR_BPE_DATA
# Extract training data
for ((i=0;i<${#TRAIN_DATA[@]};i+=3)); do
data_url=${TRAIN_DATA[i]}
data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
data=${data_tgz%.*} # training-parallel-commoncrawl
data_lang1=${TRAIN_DATA[i+1]}
data_lang2=${TRAIN_DATA[i+2]}
if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
echo "Download "${data_url}
echo "Dir "${OUTPUT_DIR_DATA}/${data_tgz}
wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
fi
if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
echo "Extract "${data_tgz}
mkdir -p ${OUTPUT_DIR_DATA}/${data}
tar_type=${data_tgz:0-3}
if [ ${tar_type} == "tar" ]; then
tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
else
tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
fi
fi
# concatenate all training data
for data_lang in $data_lang1 $data_lang2; do
for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
data_dir=`dirname $f`
data_file=`basename $f`
f_base=${f%.*}
f_ext=${f##*.}
if [ $f_ext == "gz" ]; then
gunzip $f
l=${f_base##*.}
f_base=${f_base%.*}
else
l=${f_ext}
fi
if [ $i -eq 0 ]; then
cat ${f_base}.$l > ${OUTPUT_DIR_DATA}/train.$l
else
cat ${f_base}.$l >> ${OUTPUT_DIR_DATA}/train.$l
fi
done
done
done
# Clone mosesdecoder
if [ ! -d ${OUTPUT_DIR}/mosesdecoder ]; then
echo "Cloning moses for data processing"
git clone https://github.com/moses-smt/mosesdecoder.git ${OUTPUT_DIR}/mosesdecoder
fi
# Extract develop and test data
dev_test_data=""
for ((i=0;i<${#DEV_TEST_DATA[@]};i+=3)); do
data_url=${DEV_TEST_DATA[i]}
data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
data=${data_tgz%.*} # training-parallel-commoncrawl
data_lang1=${DEV_TEST_DATA[i+1]}
data_lang2=${DEV_TEST_DATA[i+2]}
if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
echo "Download "${data_url}
wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
fi
if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
echo "Extract "${data_tgz}
mkdir -p ${OUTPUT_DIR_DATA}/${data}
tar_type=${data_tgz:0-3}
if [ ${tar_type} == "tar" ]; then
tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
else
tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
fi
fi
for data_lang in $data_lang1 $data_lang2; do
for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
echo "input-from-sgm"
data_dir=`dirname $f`
data_file=`basename $f`
data_out=`echo ${data_file} | cut -d '-' -f 1` # newstest2016
l=`echo ${data_file} | cut -d '.' -f 2` # en
dev_test_data="${dev_test_data}\|${data_out}" # to make regexp
if [ ! -e ${OUTPUT_DIR_DATA}/${data_out}.$l ]; then
${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
< $f > ${OUTPUT_DIR_DATA}/${data_out}.$l
fi
done
done
done
# Tokenize data
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train\|newstest2013\)\.$l$"`; do
f_base=${f%.*} # dir/train dir/newstest2013
f_out=$f_base.tok.$l
f_tmp=$f_base.tmp.$l
if [ ! -e $f_out ]; then
echo "Tokenize "$f
cat $f | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/normalize-punctuation.perl $l | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/remove-non-printing-char.perl | \
tee -a $f_tmp | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l $l -threads 8 >> $f_out
echo $f_out
fi
done
done
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(newstest2014\)\.$l$"`; do
f_base=${f%.*} # dir/newstest2014
f_out=$f_base.tok.$l
if [ ! -e $f_out ]; then
echo "Tokenize "$f
cat $f | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l $l -threads 8 >> $f_out
echo $f_out
fi
done
done
# Clean data
for f in ${OUTPUT_DIR_DATA}/train.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.${LANG1}; do
f_base=${f%.*} # dir/train dir/train.tok
f_out=${f_base}.clean
if [ ! -e $f_out.${LANG1} ] && [ ! -e $f_out.${LANG2} ]; then
echo "Clean "${f_base}
${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $f_base ${LANG1} ${LANG2} ${f_out} 1 256
fi
done
python -m pip install subword-nmt
# Generate BPE data and vocabulary
for num_operations in 33708; do
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} ]; then
echo "Learn BPE with ${num_operations} merge operations"
cat ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG2} | \
subword-nmt learn-bpe -s $num_operations > ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations}
fi
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.tok\(\.clean\)\?\.$l$"`; do
f_base=${f%.*} # dir/train.tok dir/train.tok.clean dir/newstest2016.tok
f_base=${f_base##*/} # train.tok train.tok.clean newstest2016.tok
f_out=${OUTPUT_DIR_BPE_DATA}/${f_base}.bpe.${num_operations}.$l
if [ ! -e $f_out ]; then
echo "Apply BPE to "$f
subword-nmt apply-bpe -c ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} < $f > $f_out
fi
done
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} ]; then
echo "Create vocabulary for BPE data"
cat ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG1} ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG2} | \
subword-nmt get-vocab | cut -f1 -d ' ' > ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations}
fi
done
# Adapt to the reader
for f in ${OUTPUT_DIR_BPE_DATA}/*.bpe.${num_operations}.${LANG1}; do
f_base=${f%.*} # dir/train.tok.clean.bpe.32000 dir/newstest2016.tok.bpe.32000
f_out=${f_base}.${LANG1}-${LANG2}
if [ ! -e $f_out ]; then
paste -d '\t' $f_base.${LANG1} $f_base.${LANG2} > $f_out
fi
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations} ]; then
sed '1i\<s>\n<e>\n<unk>' ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} > ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations}
fi
echo "All done."
@@ -22,79 +22,105 @@ from functools import partial
 import numpy as np
 from paddle.io import BatchSampler, DataLoader, Dataset
 from paddlenlp.data import Pad
+from paddlenlp.datasets import WMT14ende
+from paddlenlp.data.sampler import SamplerHelper


-def create_infer_loader(args):
-    dataset = TransformerDataset(
-        fpattern=args.predict_file,
-        src_vocab_fpath=args.src_vocab_fpath,
-        trg_vocab_fpath=args.trg_vocab_fpath,
-        token_delimiter=args.token_delimiter,
-        start_mark=args.special_token[0],
-        end_mark=args.special_token[1],
-        unk_mark=args.special_token[2])
-    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
-        args.unk_idx = dataset.get_vocab_summary()
-    trg_idx2word = TransformerDataset.load_dict(
-        dict_path=args.trg_vocab_fpath, reverse=True)
-    batch_sampler = TransformerBatchSampler(
-        dataset=dataset,
-        use_token_batch=False,
-        batch_size=args.infer_batch_size,
-        max_length=args.max_length)
-    data_loader = DataLoader(
-        dataset=dataset,
-        batch_sampler=batch_sampler,
-        collate_fn=partial(
-            prepare_infer_input,
-            bos_idx=args.bos_idx,
-            eos_idx=args.eos_idx,
-            pad_idx=args.eos_idx),
-        num_workers=0,
-        return_list=True)
-    data_loaders = (data_loader, batch_sampler.__len__)
-    return data_loaders, trg_idx2word
-
-
-def create_data_loader(args, world_size=1, rank=0):
-    data_loaders = [(None, None)] * 2
-    data_files = [args.training_file, args.validation_file
-                  ] if args.validation_file else [args.training_file]
-    for i, data_file in enumerate(data_files):
-        dataset = TransformerDataset(
-            fpattern=data_file,
-            src_vocab_fpath=args.src_vocab_fpath,
-            trg_vocab_fpath=args.trg_vocab_fpath,
-            token_delimiter=args.token_delimiter,
-            start_mark=args.special_token[0],
-            end_mark=args.special_token[1],
-            unk_mark=args.special_token[2])
-        args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
-            args.unk_idx = dataset.get_vocab_summary()
-        batch_sampler = TransformerBatchSampler(
-            dataset=dataset,
-            batch_size=args.batch_size,
-            pool_size=args.pool_size,
-            sort_type=args.sort_type,
-            shuffle=args.shuffle,
-            shuffle_batch=args.shuffle_batch,
-            use_token_batch=args.use_token_batch,
-            max_length=args.max_length,
-            distribute_mode=True if i == 0 else False,
-            world_size=world_size,
-            rank=rank)
-        data_loader = DataLoader(
-            dataset=dataset,
-            batch_sampler=batch_sampler,
-            collate_fn=partial(
-                prepare_train_input,
-                bos_idx=args.bos_idx,
-                eos_idx=args.eos_idx,
-                pad_idx=args.bos_idx),
-            num_workers=0,
-            return_list=True)
-        data_loaders[i] = (data_loader, batch_sampler.__len__)
-    return data_loaders
+def min_max_filer(data, max_len, min_len=0):
+    # 1 for special tokens.
+    data_min_len = min(len(data[0]), len(data[1])) + 1
+    data_max_len = max(len(data[0]), len(data[1])) + 1
+    return (data_min_len >= min_len) and (data_max_len <= max_len)
+
+
+def create_data_loader(args):
+    root = None if args.root == "None" else args.root
+    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
+    args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
+    transform_func = WMT14ende.get_default_transform_func(root=root)
+    datasets = [
+        WMT14ende.get_datasets(
+            mode=m, transform_func=transform_func) for m in ["train", "dev"]
+    ]
+
+    def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
+                      data_source):
+        return max(tokens_sofar,
+                   len(data_source[current_idx][0]) + 1,
+                   len(data_source[current_idx][1]) + 1)
+
+    def _key(size_so_far, minibatch_len):
+        return size_so_far * minibatch_len
+
+    data_loaders = [(None)] * 2
+    for i, dataset in enumerate(datasets):
+        m = dataset.mode
+        dataset = dataset.filter(
+            partial(
+                min_max_filer, max_len=args.max_length))
+        sampler = SamplerHelper(dataset)
+
+        src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
+        if args.sort_type == SortType.GLOBAL:
+            buffer_size = -1
+            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
+            # Sort twice
+            sampler = sampler.sort(
+                key=trg_key, buffer_size=buffer_size).sort(
+                    key=src_key, buffer_size=buffer_size)
+        else:
+            sampler = sampler.shuffle()
+            if args.sort_type == SortType.POOL:
+                buffer_size = args.pool_size
+                sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
+
+        batch_sampler = sampler.batch(
+            batch_size=args.batch_size,
+            drop_last=False,
+            batch_size_fn=_max_token_fn,
+            key=_key)
+
+        if m == "train":
+            batch_sampler = batch_sampler.shard()
+
+        data_loader = DataLoader(
+            dataset=dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=partial(
+                prepare_train_input,
+                bos_idx=args.bos_idx,
+                eos_idx=args.eos_idx,
+                pad_idx=args.bos_idx),
+            num_workers=0,
+            return_list=True)
+        data_loaders[i] = (data_loader)
+    return data_loaders
+
+
+def create_infer_loader(args):
+    root = None if args.root == "None" else args.root
+    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
+    args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
+    transform_func = WMT14ende.get_default_transform_func(root=root)
+    dataset = WMT14ende.get_datasets(
+        mode="test", transform_func=transform_func).filter(
+            partial(
+                min_max_filer, max_len=args.max_length))
+
+    batch_sampler = SamplerHelper(dataset).batch(
+        batch_size=args.infer_batch_size, drop_last=False)
+
+    data_loader = DataLoader(
+        dataset=dataset,
+        batch_sampler=batch_sampler,
+        collate_fn=partial(
+            prepare_infer_input,
+            bos_idx=args.bos_idx,
+            eos_idx=args.eos_idx,
+            pad_idx=args.bos_idx),
+        num_workers=0,
+        return_list=True)
+    return data_loader, trg_vocab.to_tokens


 def prepare_train_input(insts, bos_idx, eos_idx, pad_idx):
@@ -126,301 +152,3 @@ class SortType(object):
     GLOBAL = 'global'
     POOL = 'pool'
     NONE = "none"
class Converter(object):
def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
self._vocab = vocab
self._beg = beg
self._end = end
self._unk = unk
self._delimiter = delimiter
self._add_beg = add_beg
self._add_end = add_end
def __call__(self, sentence):
return ([self._beg] if self._add_beg else []) + [
self._vocab.get(w, self._unk)
for w in sentence.split(self._delimiter)
] + ([self._end] if self._add_end else [])
class ComposedConverter(object):
def __init__(self, converters):
self._converters = converters
def __call__(self, fields):
return [
converter(field)
for field, converter in zip(fields, self._converters)
]
class SentenceBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self._batch_size = batch_size
def append(self, info):
self.batch.append(info)
if len(self.batch) == self._batch_size:
tmp = self.batch
self.batch = []
return tmp
class TokenBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self.max_len = -1
self._batch_size = batch_size
def append(self, info):
cur_len = info.max_len
max_len = max(self.max_len, cur_len)
if max_len * (len(self.batch) + 1) > self._batch_size:
result = self.batch
self.batch = [info]
self.max_len = cur_len
return result
else:
self.max_len = max_len
self.batch.append(info)
class SampleInfo(object):
def __init__(self, i, lens):
self.i = i
# take bos and eos into account
self.min_len = min(lens[0] + 1, lens[1] + 1)
self.max_len = max(lens[0] + 1, lens[1] + 1)
self.src_len = lens[0]
self.trg_len = lens[1]
class MinMaxFilter(object):
def __init__(self, max_len, min_len, underlying_creator):
self._min_len = min_len
self._max_len = max_len
self._creator = underlying_creator
def append(self, info):
if info.max_len > self._max_len or info.min_len < self._min_len:
return
else:
return self._creator.append(info)
@property
def batch(self):
return self._creator.batch
class TransformerDataset(Dataset):
def __init__(self,
src_vocab_fpath,
trg_vocab_fpath,
fpattern,
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
trg_fpattern=None):
self._src_vocab = self.load_dict(src_vocab_fpath)
self._trg_vocab = self.load_dict(trg_vocab_fpath)
self._bos_idx = self._src_vocab[start_mark]
self._eos_idx = self._src_vocab[end_mark]
self._unk_idx = self._src_vocab[unk_mark]
self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter
self.load_src_trg_ids(fpattern, trg_fpattern)
def load_src_trg_ids(self, fpattern, trg_fpattern=None):
src_converter = Converter(
vocab=self._src_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False,
add_end=False)
trg_converter = Converter(
vocab=self._trg_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False,
add_end=False)
converters = ComposedConverter([src_converter, trg_converter])
self._src_seq_ids = []
self._trg_seq_ids = []
self._sample_infos = []
slots = [self._src_seq_ids, self._trg_seq_ids]
for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
lens = []
for field, slot in zip(converters(line), slots):
slot.append(field)
lens.append(len(field))
self._sample_infos.append(SampleInfo(i, lens))
def _load_lines(self, fpattern, trg_fpattern=None):
fpaths = glob.glob(fpattern)
fpaths = sorted(fpaths) # TODO: Add custum sort
assert len(fpaths) > 0, "no matching file to the provided data path"
(f_mode, f_encoding, endl) = ("r", "utf8", "\n")
if trg_fpattern is None:
for fpath in fpaths:
with io.open(fpath, f_mode, encoding=f_encoding) as f:
for line in f:
fields = line.strip(endl).split(self._field_delimiter)
yield fields
else:
# separated source and target language data files
# assume we can get aligned data by sort the two language files
trg_fpaths = glob.glob(trg_fpattern)
trg_fpaths = sorted(trg_fpaths)
assert len(fpaths) == len(
trg_fpaths
), "the number of source language data files must equal \
with that of source language"
for fpath, trg_fpath in zip(fpaths, trg_fpaths):
with io.open(fpath, f_mode, encoding=f_encoding) as f:
with io.open(
trg_fpath, f_mode, encoding=f_encoding) as trg_f:
for line in zip(f, trg_f):
fields = [field.strip(endl) for field in line]
yield fields
@staticmethod
def load_dict(dict_path, reverse=False):
word_dict = {}
(f_mode, f_encoding, endl) = ("r", "utf8", "\n")
with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip(endl)
else:
word_dict[line.strip(endl)] = idx
return word_dict
def get_vocab_summary(self):
return len(self._src_vocab), len(
self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
def __getitem__(self, idx):
return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
) if self._trg_seq_ids else self._src_seq_ids[idx]
def __len__(self):
return len(self._sample_infos)
class TransformerBatchSampler(BatchSampler):
def __init__(self,
dataset,
batch_size,
pool_size=10000,
sort_type=SortType.NONE,
min_length=0,
max_length=100,
shuffle=False,
shuffle_batch=False,
use_token_batch=False,
clip_last_batch=False,
distribute_mode=True,
seed=0,
world_size=1,
rank=0):
for arg, value in locals().items():
if arg != "self":
setattr(self, "_" + arg, value)
self._random = np.random
self._random.seed(seed)
# for multi-devices
self._distribute_mode = distribute_mode
self._nranks = world_size
self._local_rank = rank
def __iter__(self):
# global sort or global shuffle
if self._sort_type == SortType.GLOBAL:
infos = sorted(self._dataset._sample_infos, key=lambda x: x.trg_len)
infos = sorted(infos, key=lambda x: x.src_len)
else:
if self._shuffle:
infos = self._dataset._sample_infos
self._random.shuffle(infos)
else:
infos = self._dataset._sample_infos
if self._sort_type == SortType.POOL:
reverse = True
for i in range(0, len(infos), self._pool_size):
# to avoid placing short next to long sentences
reverse = not reverse
infos[i:i + self._pool_size] = sorted(
infos[i:i + self._pool_size],
key=lambda x: x.max_len,
reverse=reverse)
batches = []
batch_creator = TokenBatchCreator(
self.
_batch_size) if self._use_token_batch else SentenceBatchCreator(
self._batch_size * self._nranks)
batch_creator = MinMaxFilter(self._max_length, self._min_length,
batch_creator)
for info in infos:
batch = batch_creator.append(info)
if batch is not None:
batches.append(batch)
if not self._clip_last_batch and len(batch_creator.batch) != 0:
batches.append(batch_creator.batch)
if self._shuffle_batch:
self._random.shuffle(batches)
if not self._use_token_batch:
# when producing batches according to sequence number, to confirm
# neighbor batches which would be feed and run parallel have similar
# length (thus similar computational cost) after shuffle, we as take
# them as a whole when shuffling and split here
batches = [[
batch[self._batch_size * i:self._batch_size * (i + 1)]
for i in range(self._nranks)
] for batch in batches]
batches = list(itertools.chain.from_iterable(batches))
self.batch_number = (len(batches) + self._nranks - 1) // self._nranks
# for multi-device
for batch_id, batch in enumerate(batches):
if not self._distribute_mode or (
batch_id % self._nranks == self._local_rank):
batch_indices = [info.i for info in batch]
yield batch_indices
if self._distribute_mode and len(batches) % self._nranks != 0:
if self._local_rank >= len(batches) % self._nranks:
# use previous data to pad
yield batch_indices
def __len__(self):
if hasattr(self, "batch_number"): #
return self.batch_number
if not self._use_token_batch:
batch_number = (
len(self._dataset) + self._batch_size * self._nranks - 1) // (
self._batch_size * self._nranks)
else:
# for uncertain batch number, the actual value is self.batch_number
batch_number = sys.maxsize
return batch_number
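A minimal sketch (not part of the change) of how the simplified `create_data_loader` API above is consumed. The `args` namespace is hypothetical and only mirrors the fields that reader.py reads from the YAML configs in this commit; `bos_idx`, `eos_idx` and `max_length` are assumed to come from the full config file:

```python
from types import SimpleNamespace

import reader  # the simplified reader.py above

# Hypothetical config values mirroring configs/transformer.*.yaml.
args = SimpleNamespace(
    root="None", sort_type="pool", pool_size=200000, batch_size=4096,
    infer_batch_size=8, max_length=256, bos_idx=0, eos_idx=1)

(train_loader), (eval_loader) = reader.create_data_loader(args)
for batch in train_loader:
    # Each batch is produced by prepare_train_input via the collate_fn;
    # create_data_loader also fills in args.src_vocab_size / trg_vocab_size.
    break
```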
@@ -63,9 +63,7 @@ def do_train(args):
         paddle.seed(random_seed)

     # Define data loader
-    # NOTE: To guarantee all data is involved, use world_size=1 and rank=0.
-    (train_loader, train_steps_fn), (
-        eval_loader, eval_steps_fn) = reader.create_data_loader(args)
+    (train_loader), (eval_loader) = reader.create_data_loader(args)

     train_program = paddle.static.Program()
     startup_program = paddle.static.Program()
......
@@ -39,7 +39,7 @@
 | Dataset | Description | API |
 | ---- | --------- | ------ |
 | [IWSLT15](https://workshop2015.iwslt.org/) | IWSLT'15 English-Vietnamese translation dataset | `paddlenlp.datasets.IWSLT15` |
-| [WMT14](http://www.statmt.org/wmt14/translation-task.html) | WMT14 EN-DE English-German translation dataset | `paddlenlp.datasets.WMT14` |
+| [WMT14](http://www.statmt.org/wmt14/translation-task.html) | WMT14 EN-DE English-German translation dataset | `paddlenlp.datasets.WMT14ende` |

 ## Time Series Prediction
......
@@ -4,13 +4,12 @@
 ```text
 .
-├── images               # Images used in this README
+├── images/              # Images used in this README
 ├── predict.py           # Prediction script
 ├── reader.py            # Data reading interface
 ├── README.md            # Documentation
 ├── train.py             # Training script
-├── transformer.py       # Model definition
-└── transformer.yaml     # Configuration file
+└── configs/             # Configuration files
 ```

 ## Model Introduction
@@ -46,6 +45,15 @@
 Public dataset: the WMT translation shared task is the most authoritative international evaluation in machine translation. Its English-German task provides a medium-sized dataset that is used in many papers, including the original Transformer paper. We also provide the [WMT'14 EN-DE dataset](http://www.statmt.org/wmt14/translation-task.html) as an example.

+We also provide a preprocessed copy of this dataset. With the following code, the corresponding data will be downloaded automatically and extracted to `~/.paddlenlp/datasets/machine_translation/WMT14ende/`:
+
+``` python
+# Get the default data transformation
+transform_func = WMT14ende.get_default_transform_func(root=root)
+# Download and process the WMT14 en-de translation dataset
+dataset = WMT14ende.get_datasets(mode="train", transform_func=transform_func)
+```
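The bundled vocabularies can be loaded the same way. A small sketch based on the `WMT14ende.get_vocab` call used in `reader.py` above (`root=None` falls back to the default download directory; the ids are illustrative only):

```python
from paddlenlp.datasets import WMT14ende

# Load the shared BPE vocabularies; with root=None they are read from
# ~/.paddlenlp/datasets/machine_translation/WMT14ende/ after download.
src_vocab, trg_vocab = WMT14ende.get_vocab(root=None)
print(len(src_vocab), len(trg_vocab))      # vocabulary sizes fed to the model
print(trg_vocab.to_tokens([10, 11, 12]))   # map ids back to tokens
```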
 ### Single-machine training

 ### Single machine, single GPU
@@ -55,10 +63,10 @@
 ```sh
 # setting visible devices for training
 export CUDA_VISIBLE_DEVICES=0
-python train.py
+python train.py --config ./configs/transformer.base.yaml
 ```

-The relevant parameters, such as `max_iter`, which controls the maximum number of iterations, can be set in the transformer.yaml file.
+The relevant parameters can be set in the `configs/transformer.big.yaml` and `configs/transformer.base.yaml` files.

 ### Single machine, multiple GPUs
@@ -66,7 +74,7 @@ python train.py
 ```sh
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train.py
+python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train.py --config ./configs/transformer.base.yaml
 ```
@@ -80,7 +88,7 @@ export CUDA_VISIBLE_DEVICES=0
 python predict.py
 ```

-The translations of the texts in the file given by `predict_file` are written to the file given by `output_file`. When running prediction, `init_from_params` must point to the directory holding the trained model; the remaining parameters are documented in the comments of `transformer.yaml` and can be changed there. Note that only single-GPU prediction is currently implemented, because the model evaluation that follows translation depends on the order in which predictions are written to the file, and this order cannot currently be guaranteed with multiple GPUs.
+The translations of the texts in the file given by `predict_file` are written to the file given by `output_file`. When running prediction, `init_from_params` must point to the directory holding the trained model; the remaining parameters are documented in the comments of `configs/transformer.big.yaml` and `configs/transformer.base.yaml` and can be changed there. Note that only single-GPU prediction is currently implemented, because the model evaluation that follows translation depends on the order in which predictions are written to the file, and this order cannot currently be guaranteed with multiple GPUs.

 ### Model Evaluation
@@ -91,13 +99,13 @@ python predict.py
 # Restore the predictions in predict.txt to tokenized text
 sed -r 's/(@@ )|(@@ ?$)//g' predict.txt > predict.tok.txt
 # Download the BLEU evaluation tool first if it is not available
-# git clone https://github.com/moses-smt/mosesdecoder.git
+git clone https://github.com/moses-smt/mosesdecoder.git
 # Take the English-German newstest2014 test set as an example
-perl gen_data/mosesdecoder/scripts/generic/multi-bleu.perl gen_data/wmt14_ende_data/newstest2014.tok.de < predict.tok.txt
+perl mosesdecoder/scripts/generic/multi-bleu.perl ~/.paddlenlp/datasets/machine_translation/WMT14ende/WMT14.en-de/wmt14_ende_data/newstest2014.tok.de < predict.tok.txt
 ```

-You should see results similar to the following:
+You should see results similar to the following (these are the results of the big model on newstest2014):
 ```
-BLEU = 26.35, 57.7/32.1/20.0/13.0 (BP=1.000, ratio=1.013, hyp_len=63903, ref_len=63078)
+BLEU = 27.48, 58.6/33.2/21.1/13.9 (BP=1.000, ratio=1.012, hyp_len=65312, ref_len=64506)
 ```

 ## Advanced Usage
......
@@ -10,37 +10,23 @@ init_from_pretrain_model: ""
 init_from_params: "./trained_models/step_final/"
 # The directory for saving model
 save_model: "trained_models"
-# The directory for saving inference model.
-inference_model_dir: "infer_model"
 # Set seed for CE or debug
 random_seed: None
-# The pattern to match training data files.
-training_file: "gen_data/wmt14_ende_data_bpe/train.tok.clean.bpe.33708.en-de"
-# The pattern to match validation data files.
-validation_file: "gen_data/wmt14_ende_data_bpe/newstest2013.tok.bpe.33708.en-de"
-# The pattern to match test data files.
-predict_file: "gen_data/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en-de"
 # The file to output the translation results of predict_file to.
 output_file: "predict.txt"
-# The path of vocabulary file of source language.
-src_vocab_fpath: "gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
-# The path of vocabulary file of target language.
-trg_vocab_fpath: "gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
 # The <bos>, <eos> and <unk> tokens in the dictionary.
 special_token: ["<s>", "<e>", "<unk>"]
+# The directory to store data.
+root: None
 # Whether to use cuda
 use_gpu: True
 # Args for reader, see reader.py for details
-token_delimiter: " "
-use_token_batch: True
 pool_size: 200000
 sort_type: "pool"
-shuffle: True
-shuffle_batch: True
 batch_size: 4096
-infer_batch_size: 32
+infer_batch_size: 8

 # Hyparams for training:
 # The number of epoches for training
......
@@ -10,37 +10,23 @@ init_from_pretrain_model: ""
 init_from_params: "./trained_models/step_final/"
 # The directory for saving model
 save_model: "trained_models"
-# The directory for saving inference model.
-inference_model_dir: "infer_model"
 # Set seed for CE or debug
 random_seed: None
-# The pattern to match training data files.
-training_file: "gen_data/wmt14_ende_data_bpe/train.tok.clean.bpe.33708.en-de"
-# The pattern to match validation data files.
-validation_file: "gen_data/wmt14_ende_data_bpe/newstest2013.tok.bpe.33708.en-de"
-# The pattern to match test data files.
-predict_file: "gen_data/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en-de"
 # The file to output the translation results of predict_file to.
 output_file: "predict.txt"
-# The path of vocabulary file of source language.
-src_vocab_fpath: "gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
-# The path of vocabulary file of target language.
-trg_vocab_fpath: "gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
 # The <bos>, <eos> and <unk> tokens in the dictionary.
 special_token: ["<s>", "<e>", "<unk>"]
+# The directory to store data.
+root: None
 # Whether to use cuda
 use_gpu: True
 # Args for reader, see reader.py for details
-token_delimiter: " "
-use_token_batch: True
 pool_size: 200000
 sort_type: "pool"
-shuffle: True
-shuffle_batch: True
 batch_size: 4096
-infer_batch_size: 16
+infer_batch_size: 8

 # Hyparams for training:
 # The number of epoches for training
......
#! /usr/bin/env bash
set -e
OUTPUT_DIR=$PWD/gen_data
###############################################################################
# change these variables for other WMT data
###############################################################################
OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_ende_data"
OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_ende_data_bpe"
LANG1="en"
LANG2="de"
# each of TRAIN_DATA: data_url data_file_lang1 data_file_lang2
TRAIN_DATA=(
'http://statmt.org/wmt13/training-parallel-europarl-v7.tgz'
'europarl-v7.de-en.en' 'europarl-v7.de-en.de'
'http://statmt.org/wmt13/training-parallel-commoncrawl.tgz'
'commoncrawl.de-en.en' 'commoncrawl.de-en.de'
'http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz'
'news-commentary-v12.de-en.en' 'news-commentary-v12.de-en.de'
)
# each of DEV_TEST_DATA: data_url data_file_lang1 data_file_lang2
# source & reference
DEV_TEST_DATA=(
'http://data.statmt.org/wmt17/translation-task/dev.tgz'
'newstest2013-ref.de.sgm' 'newstest2013-src.en.sgm'
'http://statmt.org/wmt14/test-full.tgz'
'newstest2014-deen-ref.en.sgm' 'newstest2014-deen-src.de.sgm'
)
###############################################################################
###############################################################################
# change these variables for other WMT data
###############################################################################
# OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_enfr_data"
# OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_enfr_data_bpe"
# LANG1="en"
# LANG2="fr"
# # each of TRAIN_DATA: data_url data_tgz data_file
# TRAIN_DATA=(
# 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
# 'commoncrawl.fr-en.en' 'commoncrawl.fr-en.fr'
# 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz'
# 'training/europarl-v7.fr-en.en' 'training/europarl-v7.fr-en.fr'
# 'http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz'
# 'training/news-commentary-v9.fr-en.en' 'training/news-commentary-v9.fr-en.fr'
# 'http://www.statmt.org/wmt10/training-giga-fren.tar'
# 'giga-fren.release2.fixed.en.*' 'giga-fren.release2.fixed.fr.*'
# 'http://www.statmt.org/wmt13/training-parallel-un.tgz'
# 'un/undoc.2000.fr-en.en' 'un/undoc.2000.fr-en.fr'
# )
# # each of DEV_TEST_DATA: data_url data_tgz data_file_lang1 data_file_lang2
# DEV_TEST_DATA=(
# 'http://data.statmt.org/wmt16/translation-task/dev.tgz'
# '.*/newstest201[45]-fren-ref.en.sgm' '.*/newstest201[45]-fren-src.fr.sgm'
# 'http://data.statmt.org/wmt16/translation-task/test.tgz'
# '.*/newstest2016-fren-ref.en.sgm' '.*/newstest2016-fren-src.fr.sgm'
# )
###############################################################################
mkdir -p $OUTPUT_DIR_DATA $OUTPUT_DIR_BPE_DATA
# Extract training data
for ((i=0;i<${#TRAIN_DATA[@]};i+=3)); do
data_url=${TRAIN_DATA[i]}
data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
data=${data_tgz%.*} # training-parallel-commoncrawl
data_lang1=${TRAIN_DATA[i+1]}
data_lang2=${TRAIN_DATA[i+2]}
if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
echo "Download "${data_url}
echo "Dir "${OUTPUT_DIR_DATA}/${data_tgz}
wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
fi
if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
echo "Extract "${data_tgz}
mkdir -p ${OUTPUT_DIR_DATA}/${data}
tar_type=${data_tgz:0-3}
if [ ${tar_type} == "tar" ]; then
tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
else
tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
fi
fi
# concatenate all training data
for data_lang in $data_lang1 $data_lang2; do
for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
data_dir=`dirname $f`
data_file=`basename $f`
f_base=${f%.*}
f_ext=${f##*.}
if [ $f_ext == "gz" ]; then
gunzip $f
l=${f_base##*.}
f_base=${f_base%.*}
else
l=${f_ext}
fi
if [ $i -eq 0 ]; then
cat ${f_base}.$l > ${OUTPUT_DIR_DATA}/train.$l
else
cat ${f_base}.$l >> ${OUTPUT_DIR_DATA}/train.$l
fi
done
done
done
# Clone mosesdecoder
if [ ! -d ${OUTPUT_DIR}/mosesdecoder ]; then
echo "Cloning moses for data processing"
git clone https://github.com/moses-smt/mosesdecoder.git ${OUTPUT_DIR}/mosesdecoder
fi
# Extract develop and test data
dev_test_data=""
for ((i=0;i<${#DEV_TEST_DATA[@]};i+=3)); do
data_url=${DEV_TEST_DATA[i]}
data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
data=${data_tgz%.*} # training-parallel-commoncrawl
data_lang1=${DEV_TEST_DATA[i+1]}
data_lang2=${DEV_TEST_DATA[i+2]}
if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
echo "Download "${data_url}
wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
fi
if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
echo "Extract "${data_tgz}
mkdir -p ${OUTPUT_DIR_DATA}/${data}
tar_type=${data_tgz:0-3}
if [ ${tar_type} == "tar" ]; then
tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
else
tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
fi
fi
for data_lang in $data_lang1 $data_lang2; do
for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
echo "input-from-sgm"
data_dir=`dirname $f`
data_file=`basename $f`
data_out=`echo ${data_file} | cut -d '-' -f 1` # newstest2016
l=`echo ${data_file} | cut -d '.' -f 2` # en
dev_test_data="${dev_test_data}\|${data_out}" # to make regexp
if [ ! -e ${OUTPUT_DIR_DATA}/${data_out}.$l ]; then
${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
< $f > ${OUTPUT_DIR_DATA}/${data_out}.$l
fi
done
done
done
# Tokenize data
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train\|newstest2013\)\.$l$"`; do
f_base=${f%.*} # dir/train dir/newstest2013
f_out=$f_base.tok.$l
f_tmp=$f_base.tmp.$l
if [ ! -e $f_out ]; then
echo "Tokenize "$f
cat $f | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/normalize-punctuation.perl $l | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/remove-non-printing-char.perl | \
tee -a $f_tmp | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l $l -threads 8 >> $f_out
echo $f_out
fi
done
done
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(newstest2014\)\.$l$"`; do
f_base=${f%.*} # dir/newstest2014
f_out=$f_base.tok.$l
if [ ! -e $f_out ]; then
echo "Tokenize "$f
cat $f | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l $l -threads 8 >> $f_out
echo $f_out
fi
done
done
# Clean data
for f in ${OUTPUT_DIR_DATA}/train.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.${LANG1}; do
f_base=${f%.*} # dir/train dir/train.tok
f_out=${f_base}.clean
if [ ! -e $f_out.${LANG1} ] && [ ! -e $f_out.${LANG2} ]; then
echo "Clean "${f_base}
${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $f_base ${LANG1} ${LANG2} ${f_out} 1 256
fi
done
python -m pip install subword-nmt
# Generate BPE data and vocabulary
for num_operations in 33708; do
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} ]; then
echo "Learn BPE with ${num_operations} merge operations"
cat ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG2} | \
subword-nmt learn-bpe -s $num_operations > ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations}
fi
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.tok\(\.clean\)\?\.$l$"`; do
f_base=${f%.*} # dir/train.tok dir/train.tok.clean dir/newstest2016.tok
f_base=${f_base##*/} # train.tok train.tok.clean newstest2016.tok
f_out=${OUTPUT_DIR_BPE_DATA}/${f_base}.bpe.${num_operations}.$l
if [ ! -e $f_out ]; then
echo "Apply BPE to "$f
subword-nmt apply-bpe -c ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} < $f > $f_out
fi
done
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} ]; then
echo "Create vocabulary for BPE data"
cat ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG1} ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG2} | \
subword-nmt get-vocab | cut -f1 -d ' ' > ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations}
fi
done
# Adapt to the reader
for f in ${OUTPUT_DIR_BPE_DATA}/*.bpe.${num_operations}.${LANG1}; do
f_base=${f%.*} # dir/train.tok.clean.bpe.32000 dir/newstest2016.tok.bpe.32000
f_out=${f_base}.${LANG1}-${LANG2}
if [ ! -e $f_out ]; then
paste -d '\t' $f_base.${LANG1} $f_base.${LANG2} > $f_out
fi
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations} ]; then
sed '1i\<s>\n<e>\n<unk>' ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} > ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations}
fi
echo "All done."
@@ -48,8 +48,7 @@ def do_predict(args):
     paddle.set_device(place)

     # Define data loader
-    (test_loader,
-     test_steps_fn), trg_idx2word = reader.create_infer_loader(args)
+    test_loader, to_tokens = reader.create_infer_loader(args)

     # Define model
     transformer = InferTransformerModel(
@@ -95,7 +94,7 @@ def do_predict(args):
                 if beam_idx >= args.n_best:
                     break
                 id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
-                word_list = [trg_idx2word[id] for id in id_list]
+                word_list = to_tokens(id_list)
                 sequence = " ".join(word_list) + "\n"
                 f.write(sequence)
......
@@ -22,79 +22,105 @@ from functools import partial
 import numpy as np
 from paddle.io import BatchSampler, DataLoader, Dataset
 from paddlenlp.data import Pad
+from paddlenlp.datasets import WMT14ende
+from paddlenlp.data.sampler import SamplerHelper


-def create_infer_loader(args):
-    dataset = TransformerDataset(
-        fpattern=args.predict_file,
-        src_vocab_fpath=args.src_vocab_fpath,
-        trg_vocab_fpath=args.trg_vocab_fpath,
-        token_delimiter=args.token_delimiter,
-        start_mark=args.special_token[0],
-        end_mark=args.special_token[1],
-        unk_mark=args.special_token[2])
-    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
-        args.unk_idx = dataset.get_vocab_summary()
-    trg_idx2word = TransformerDataset.load_dict(
-        dict_path=args.trg_vocab_fpath, reverse=True)
-    batch_sampler = TransformerBatchSampler(
-        dataset=dataset,
-        use_token_batch=False,
-        batch_size=args.infer_batch_size,
-        max_length=args.max_length)
-    data_loader = DataLoader(
-        dataset=dataset,
-        batch_sampler=batch_sampler,
-        collate_fn=partial(
-            prepare_infer_input,
-            bos_idx=args.bos_idx,
-            eos_idx=args.eos_idx,
-            pad_idx=args.bos_idx),
-        num_workers=0,
-        return_list=True)
-    data_loaders = (data_loader, batch_sampler.__len__)
-    return data_loaders, trg_idx2word
-
-
-def create_data_loader(args, world_size=1, rank=0):
-    data_loaders = [(None, None)] * 2
-    data_files = [args.training_file, args.validation_file
-                  ] if args.validation_file else [args.training_file]
-    for i, data_file in enumerate(data_files):
-        dataset = TransformerDataset(
-            fpattern=data_file,
-            src_vocab_fpath=args.src_vocab_fpath,
-            trg_vocab_fpath=args.trg_vocab_fpath,
-            token_delimiter=args.token_delimiter,
-            start_mark=args.special_token[0],
-            end_mark=args.special_token[1],
-            unk_mark=args.special_token[2])
-        args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
-            args.unk_idx = dataset.get_vocab_summary()
-        batch_sampler = TransformerBatchSampler(
-            dataset=dataset,
-            batch_size=args.batch_size,
-            pool_size=args.pool_size,
-            sort_type=args.sort_type,
-            shuffle=args.shuffle,
-            shuffle_batch=args.shuffle_batch,
-            use_token_batch=args.use_token_batch,
-            max_length=args.max_length,
-            distribute_mode=True if i == 0 else False,
-            world_size=world_size,
-            rank=rank)
-        data_loader = DataLoader(
-            dataset=dataset,
-            batch_sampler=batch_sampler,
-            collate_fn=partial(
-                prepare_train_input,
-                bos_idx=args.bos_idx,
-                eos_idx=args.eos_idx,
-                pad_idx=args.bos_idx),
-            num_workers=0,
-            return_list=True)
-        data_loaders[i] = (data_loader, batch_sampler.__len__)
-    return data_loaders
+def min_max_filer(data, max_len, min_len=0):
+    # 1 for special tokens.
+    data_min_len = min(len(data[0]), len(data[1])) + 1
+    data_max_len = max(len(data[0]), len(data[1])) + 1
+    return (data_min_len >= min_len) and (data_max_len <= max_len)
+
+
+def create_data_loader(args):
+    root = None if args.root == "None" else args.root
+    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
+    args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
+    transform_func = WMT14ende.get_default_transform_func(root=root)
+    datasets = [
+        WMT14ende.get_datasets(
+            mode=m, transform_func=transform_func) for m in ["train", "dev"]
+    ]
+
+    def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
+                      data_source):
+        return max(tokens_sofar,
+                   len(data_source[current_idx][0]) + 1,
+                   len(data_source[current_idx][1]) + 1)
+
+    def _key(size_so_far, minibatch_len):
+        return size_so_far * minibatch_len
+
+    data_loaders = [(None)] * 2
+    for i, dataset in enumerate(datasets):
+        m = dataset.mode
+        dataset = dataset.filter(
+            partial(
+                min_max_filer, max_len=args.max_length))
+        sampler = SamplerHelper(dataset)
+
+        src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
+        if args.sort_type == SortType.GLOBAL:
+            buffer_size = -1
+            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
+            # Sort twice
+            sampler = sampler.sort(
+                key=trg_key, buffer_size=buffer_size).sort(
+                    key=src_key, buffer_size=buffer_size)
+        else:
+            sampler = sampler.shuffle()
+            if args.sort_type == SortType.POOL:
+                buffer_size = args.pool_size
+                sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
+
+        batch_sampler = sampler.batch(
+            batch_size=args.batch_size,
+            drop_last=False,
+            batch_size_fn=_max_token_fn,
+            key=_key)
+
+        if m == "train":
+            batch_sampler = batch_sampler.shard()
+
+        data_loader = DataLoader(
+            dataset=dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=partial(
+                prepare_train_input,
+                bos_idx=args.bos_idx,
+                eos_idx=args.eos_idx,
+                pad_idx=args.bos_idx),
+            num_workers=0,
+            return_list=True)
+        data_loaders[i] = (data_loader)
+    return data_loaders
+
+
+def create_infer_loader(args):
+    root = None if args.root == "None" else args.root
+    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
+    args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
+    transform_func = WMT14ende.get_default_transform_func(root=root)
+    dataset = WMT14ende.get_datasets(
+        mode="test", transform_func=transform_func).filter(
+            partial(
+                min_max_filer, max_len=args.max_length))
+
+    batch_sampler = SamplerHelper(dataset).batch(
+        batch_size=args.infer_batch_size, drop_last=False)
+
+    data_loader = DataLoader(
+        dataset=dataset,
+        batch_sampler=batch_sampler,
+        collate_fn=partial(
+            prepare_infer_input,
+            bos_idx=args.bos_idx,
+            eos_idx=args.eos_idx,
+            pad_idx=args.bos_idx),
+        num_workers=0,
+        return_list=True)
+    return data_loader, trg_vocab.to_tokens


 def prepare_train_input(insts, bos_idx, eos_idx, pad_idx):
@@ -126,301 +152,3 @@ class SortType(object):
     GLOBAL = 'global'
     POOL = 'pool'
     NONE = "none"
class Converter(object):
def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
self._vocab = vocab
self._beg = beg
self._end = end
self._unk = unk
self._delimiter = delimiter
self._add_beg = add_beg
self._add_end = add_end
def __call__(self, sentence):
return ([self._beg] if self._add_beg else []) + [
self._vocab.get(w, self._unk)
for w in sentence.split(self._delimiter)
] + ([self._end] if self._add_end else [])
class ComposedConverter(object):
def __init__(self, converters):
self._converters = converters
def __call__(self, fields):
return [
converter(field)
for field, converter in zip(fields, self._converters)
]
class SentenceBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self._batch_size = batch_size
def append(self, info):
self.batch.append(info)
if len(self.batch) == self._batch_size:
tmp = self.batch
self.batch = []
return tmp
class TokenBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self.max_len = -1
self._batch_size = batch_size
def append(self, info):
cur_len = info.max_len
max_len = max(self.max_len, cur_len)
if max_len * (len(self.batch) + 1) > self._batch_size:
result = self.batch
self.batch = [info]
self.max_len = cur_len
return result
else:
self.max_len = max_len
self.batch.append(info)
class SampleInfo(object):
def __init__(self, i, lens):
self.i = i
# take bos and eos into account
self.min_len = min(lens[0] + 1, lens[1] + 1)
self.max_len = max(lens[0] + 1, lens[1] + 1)
self.src_len = lens[0]
self.trg_len = lens[1]
class MinMaxFilter(object):
def __init__(self, max_len, min_len, underlying_creator):
self._min_len = min_len
self._max_len = max_len
self._creator = underlying_creator
def append(self, info):
if info.max_len > self._max_len or info.min_len < self._min_len:
return
else:
return self._creator.append(info)
@property
def batch(self):
return self._creator.batch
class TransformerDataset(Dataset):
def __init__(self,
src_vocab_fpath,
trg_vocab_fpath,
fpattern,
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
trg_fpattern=None):
self._src_vocab = self.load_dict(src_vocab_fpath)
self._trg_vocab = self.load_dict(trg_vocab_fpath)
self._bos_idx = self._src_vocab[start_mark]
self._eos_idx = self._src_vocab[end_mark]
self._unk_idx = self._src_vocab[unk_mark]
self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter
self.load_src_trg_ids(fpattern, trg_fpattern)
def load_src_trg_ids(self, fpattern, trg_fpattern=None):
src_converter = Converter(
vocab=self._src_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False,
add_end=False)
trg_converter = Converter(
vocab=self._trg_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False,
add_end=False)
converters = ComposedConverter([src_converter, trg_converter])
self._src_seq_ids = []
self._trg_seq_ids = []
self._sample_infos = []
slots = [self._src_seq_ids, self._trg_seq_ids]
for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
lens = []
for field, slot in zip(converters(line), slots):
slot.append(field)
lens.append(len(field))
self._sample_infos.append(SampleInfo(i, lens))
def _load_lines(self, fpattern, trg_fpattern=None):
fpaths = glob.glob(fpattern)
fpaths = sorted(fpaths) # TODO: Add custum sort
assert len(fpaths) > 0, "no matching file to the provided data path"
(f_mode, f_encoding, endl) = ("r", "utf8", "\n")
if trg_fpattern is None:
for fpath in fpaths:
with io.open(fpath, f_mode, encoding=f_encoding) as f:
for line in f:
fields = line.strip(endl).split(self._field_delimiter)
yield fields
else:
# separated source and target language data files
# assume we can get aligned data by sort the two language files
trg_fpaths = glob.glob(trg_fpattern)
trg_fpaths = sorted(trg_fpaths)
assert len(fpaths) == len(
trg_fpaths
), "the number of source language data files must equal \
with that of source language"
for fpath, trg_fpath in zip(fpaths, trg_fpaths):
with io.open(fpath, f_mode, encoding=f_encoding) as f:
with io.open(
trg_fpath, f_mode, encoding=f_encoding) as trg_f:
for line in zip(f, trg_f):
fields = [field.strip(endl) for field in line]
yield fields
@staticmethod
def load_dict(dict_path, reverse=False):
word_dict = {}
(f_mode, f_encoding, endl) = ("r", "utf8", "\n")
with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip(endl)
else:
word_dict[line.strip(endl)] = idx
return word_dict
def get_vocab_summary(self):
return len(self._src_vocab), len(
self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
def __getitem__(self, idx):
return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
) if self._trg_seq_ids else self._src_seq_ids[idx]
def __len__(self):
return len(self._sample_infos)
class TransformerBatchSampler(BatchSampler):
def __init__(self,
dataset,
batch_size,
pool_size=10000,
sort_type=SortType.NONE,
min_length=0,
max_length=100,
shuffle=False,
shuffle_batch=False,
use_token_batch=False,
clip_last_batch=False,
distribute_mode=True,
seed=0,
world_size=1,
rank=0):
for arg, value in locals().items():
if arg != "self":
setattr(self, "_" + arg, value)
self._random = np.random
self._random.seed(seed)
# for multi-devices
self._distribute_mode = distribute_mode
self._nranks = world_size
self._local_rank = rank
def __iter__(self):
# global sort or global shuffle
if self._sort_type == SortType.GLOBAL:
infos = sorted(self._dataset._sample_infos, key=lambda x: x.trg_len)
infos = sorted(infos, key=lambda x: x.src_len)
else:
if self._shuffle:
infos = self._dataset._sample_infos
self._random.shuffle(infos)
else:
infos = self._dataset._sample_infos
if self._sort_type == SortType.POOL:
reverse = True
for i in range(0, len(infos), self._pool_size):
# to avoid placing short next to long sentences
reverse = not reverse
infos[i:i + self._pool_size] = sorted(
infos[i:i + self._pool_size],
key=lambda x: x.max_len,
reverse=reverse)
batches = []
batch_creator = TokenBatchCreator(
self.
_batch_size) if self._use_token_batch else SentenceBatchCreator(
self._batch_size * self._nranks)
batch_creator = MinMaxFilter(self._max_length, self._min_length,
batch_creator)
for info in infos:
batch = batch_creator.append(info)
if batch is not None:
batches.append(batch)
if not self._clip_last_batch and len(batch_creator.batch) != 0:
batches.append(batch_creator.batch)
if self._shuffle_batch:
self._random.shuffle(batches)
if not self._use_token_batch:
# when producing batches according to sequence number, to confirm
# neighbor batches which would be feed and run parallel have similar
# length (thus similar computational cost) after shuffle, we as take
# them as a whole when shuffling and split here
batches = [[
batch[self._batch_size * i:self._batch_size * (i + 1)]
for i in range(self._nranks)
] for batch in batches]
batches = list(itertools.chain.from_iterable(batches))
self.batch_number = (len(batches) + self._nranks - 1) // self._nranks
# for multi-device
for batch_id, batch in enumerate(batches):
if not self._distribute_mode or (
batch_id % self._nranks == self._local_rank):
batch_indices = [info.i for info in batch]
yield batch_indices
if self._distribute_mode and len(batches) % self._nranks != 0:
if self._local_rank >= len(batches) % self._nranks:
                # reuse the last yielded batch as padding so that every rank
                # receives the same number of batches
yield batch_indices
def __len__(self):
if hasattr(self, "batch_number"): #
return self.batch_number
if not self._use_token_batch:
batch_number = (
len(self._dataset) + self._batch_size * self._nranks - 1) // (
self._batch_size * self._nranks)
else:
            # when batching by token count, the batch number cannot be determined
            # in advance; the actual value is set as self.batch_number in __iter__
batch_number = sys.maxsize
return batch_number
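As a usage note, the following is a minimal sketch of how this batch sampler could be plugged into a paddle.io.DataLoader. The `dataset`, `pad_id`, and the reuse of `prepare_train_input` (defined further below for testing) are illustrative assumptions, not part of the file above.
from functools import partial

import paddle

# Assumed setup: `dataset` is an instance of the dataset class above and
# `pad_id` is the padding id reported by its vocabulary summary.
batch_sampler = TransformerBatchSampler(
    dataset,
    batch_size=4096,
    pool_size=200000,
    sort_type=SortType.GLOBAL,
    use_token_batch=True,
    shuffle=False,
    shuffle_batch=False)

train_loader = paddle.io.DataLoader(
    dataset=dataset,
    batch_sampler=batch_sampler,
    collate_fn=partial(prepare_train_input, pad_id=pad_id),
    num_workers=0,
    return_list=True)

for batch in train_loader:
    # each element follows whatever structure prepare_train_input produces
    break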
...@@ -43,9 +43,7 @@ def do_train(args): ...@@ -43,9 +43,7 @@ def do_train(args):
paddle.seed(random_seed) paddle.seed(random_seed)
# Define data loader # Define data loader
(train_loader, train_steps_fn), (eval_loader, (train_loader), (eval_loader) = reader.create_data_loader(args)
eval_steps_fn) = reader.create_data_loader(
args, trainer_count, rank)
# Define model # Define model
transformer = TransformerModel( transformer = TransformerModel(
...@@ -150,7 +148,6 @@ def do_train(args): ...@@ -150,7 +148,6 @@ def do_train(args):
if step_idx % args.save_step == 0 and step_idx != 0: if step_idx % args.save_step == 0 and step_idx != 0:
# Validation # Validation
if args.validation_file:
transformer.eval() transformer.eval()
total_sum_cost = 0 total_sum_cost = 0
total_token_num = 0 total_token_num = 0
......
...@@ -137,7 +137,7 @@ class SamplerHelper(object): ...@@ -137,7 +137,7 @@ class SamplerHelper(object):
""" """
Sort samples according to given callable cmp or key. Sort samples according to given callable cmp or key.
Args: Args:
cmp (callable): The funcation of comparison. Default: None. cmp (callable): The function of comparison. Default: None.
key (callable): Return element to be compared. Default: None. key (callable): Return element to be compared. Default: None.
reverse (bool): If True, it means in descending order, and False means in ascending order. Default: False. reverse (bool): If True, it means in descending order, and False means in ascending order. Default: False.
buffer_size (int): Buffer size for sort. If buffer_size < 0 or buffer_size is more than the length of the data, buffer_size (int): Buffer size for sort. If buffer_size < 0 or buffer_size is more than the length of the data,
......
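To make the `sort` semantics concrete, here is a small hedged sketch. The construction of the helper from a dataset and the `(index, data_source)` signature of the key callable are assumptions based on the surrounding API, not taken from this diff.
from paddlenlp.data import SamplerHelper

# Assumed setup: `train_dataset` yields (source, target) pairs.
sampler = SamplerHelper(train_dataset)

# Sort by source length within a buffer of 1000 samples; with buffer_size < 0
# the whole dataset would be sorted globally instead.
key = lambda idx, data_source: len(data_source[idx][0])
sampler = sampler.sort(key=key, buffer_size=1000)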
...@@ -16,6 +16,7 @@ import collections ...@@ -16,6 +16,7 @@ import collections
import io import io
import json import json
import os import os
import warnings
class Vocab(object): class Vocab(object):
...@@ -179,7 +180,12 @@ class Vocab(object): ...@@ -179,7 +180,12 @@ class Vocab(object):
tokens = [] tokens = []
for idx in indices: for idx in indices:
if not isinstance(idx, int) or idx > max_idx: if not isinstance(idx, int):
warnings.warn(
"The type of `to_tokens()`'s input `indices` is not `int` which will be forcibly transfered to `int`. "
)
idx = int(idx)
if idx > max_idx:
raise ValueError( raise ValueError(
'Token index {} in the provided `indices` is invalid.'. 'Token index {} in the provided `indices` is invalid.'.
format(idx)) format(idx))
......
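The effect of this change is that `to_tokens` now tolerates non-`int` indices (for example numpy integers coming out of a decoded tensor) by warning and casting instead of raising. A small illustrative sketch follows; the toy vocabulary is assumed for demonstration only.
import numpy as np

from paddlenlp.data import Vocab

# Assumed toy vocabulary, for illustration only.
vocab = Vocab.build_vocab([["hello", "world"]], unk_token="<unk>")

# Plain Python ints behave as before.
print(vocab.to_tokens([1, 2]))

# numpy integers (e.g. ids taken from a prediction tensor) no longer raise;
# each one triggers the warning above and is cast to int.
print(vocab.to_tokens(np.array([1, 2])))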
...@@ -13,7 +13,7 @@ from paddlenlp.data.sampler import SamplerHelper ...@@ -13,7 +13,7 @@ from paddlenlp.data.sampler import SamplerHelper
from paddlenlp.utils.env import DATA_HOME from paddlenlp.utils.env import DATA_HOME
from paddle.dataset.common import md5file from paddle.dataset.common import md5file
__all__ = ['TranslationDataset', 'IWSLT15'] __all__ = ['TranslationDataset', 'IWSLT15', 'WMT14ende']
def sequential_transforms(*transforms): def sequential_transforms(*transforms):
...@@ -29,8 +29,8 @@ def get_default_tokenizer(): ...@@ -29,8 +29,8 @@ def get_default_tokenizer():
"""Only support split tokenizer """Only support split tokenizer
""" """
def _split_tokenizer(x): def _split_tokenizer(x, delimiter=None):
return x.split() return x.split(delimiter)
return _split_tokenizer return _split_tokenizer
...@@ -50,9 +50,9 @@ class TranslationDataset(paddle.io.Dataset): ...@@ -50,9 +50,9 @@ class TranslationDataset(paddle.io.Dataset):
MD5 = None MD5 = None
VOCAB_INFO = None VOCAB_INFO = None
UNK_TOKEN = None UNK_TOKEN = None
PAD_TOKEN = None
BOS_TOKEN = None BOS_TOKEN = None
EOS_TOKEN = None EOS_TOKEN = None
PAD_TOKEN = None
def __init__(self, data): def __init__(self, data):
self.data = data self.data = data
...@@ -143,14 +143,14 @@ class TranslationDataset(paddle.io.Dataset): ...@@ -143,14 +143,14 @@ class TranslationDataset(paddle.io.Dataset):
tgt_file_path = os.path.join(root, tgt_vocab_filename) tgt_file_path = os.path.join(root, tgt_vocab_filename)
src_vocab = Vocab.load_vocabulary( src_vocab = Vocab.load_vocabulary(
src_file_path, filepath=src_file_path,
unk_token=cls.UNK_TOKEN, unk_token=cls.UNK_TOKEN,
pad_token=cls.PAD_TOKEN, pad_token=cls.PAD_TOKEN,
bos_token=cls.BOS_TOKEN, bos_token=cls.BOS_TOKEN,
eos_token=cls.EOS_TOKEN) eos_token=cls.EOS_TOKEN)
tgt_vocab = Vocab.load_vocabulary( tgt_vocab = Vocab.load_vocabulary(
tgt_file_path, filepath=tgt_file_path,
unk_token=cls.UNK_TOKEN, unk_token=cls.UNK_TOKEN,
pad_token=cls.PAD_TOKEN, pad_token=cls.PAD_TOKEN,
bos_token=cls.BOS_TOKEN, bos_token=cls.BOS_TOKEN,
...@@ -273,6 +273,90 @@ class IWSLT15(TranslationDataset): ...@@ -273,6 +273,90 @@ class IWSLT15(TranslationDataset):
transform_func[1](data[1])) for data in self.data] transform_func[1](data[1])) for data in self.data]
class WMT14ende(TranslationDataset):
"""
WMT14 English to German translation dataset.
Args:
        mode(str, optional): It could be 'train', 'dev', 'test', 'dev-eval' or 'test-eval'. Default: 'train'.
root(str, optional): If None, dataset will be downloaded in
`/root/.paddlenlp/datasets/machine_translation/WMT14ende/`. Default: None.
transform_func(callable, optional): If not None, it transforms raw data
to index data. Default: None.
Examples:
.. code-block:: python
from paddlenlp.datasets import WMT14ende
transform_func = WMT14ende.get_default_transform_func(root=root)
train_dataset = WMT14ende.get_datasets(mode="train", transform_func=transform_func)
"""
URL = "https://paddlenlp.bj.bcebos.com/datasets/WMT14.en-de.tar.gz"
SPLITS = {
'train': TranslationDataset.META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"train.tok.clean.bpe.33708.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"train.tok.clean.bpe.33708.de"),
"c7c0b77e672fc69f20be182ae37ff62c",
"1865ece46948fda1209d3b7794770a0a"),
'dev': TranslationDataset.META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"newstest2013.tok.bpe.33708.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"newstest2013.tok.bpe.33708.de"),
"aa4228a4bedb6c45d67525fbfbcee75e",
"9b1eeaff43a6d5e78a381a9b03170501"),
'test': TranslationDataset.META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"newstest2014.tok.bpe.33708.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"newstest2014.tok.bpe.33708.de"),
"c9403eacf623c6e2d9e5a1155bdff0b5",
"0058855b55e37c4acfcb8cffecba1050"),
'dev-eval': TranslationDataset.META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data",
"newstest2013.tok.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data",
"newstest2013.tok.de"),
"d74712eb35578aec022265c439831b0e",
"6ff76ced35b70e63a61ecec77a1c418f"),
'test-eval': TranslationDataset.META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data",
"newstest2014.tok.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data",
"newstest2014.tok.de"),
"8cce2028e4ca3d4cc039dfd33adbfb43",
"a1b1f4c47f487253e1ac88947b68b3b8")
}
VOCAB_INFO = (os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"vocab_all.bpe.33708"),
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"vocab_all.bpe.33708"),
"2fc775b7df37368e936a8e1f63846bb0",
"2fc775b7df37368e936a8e1f63846bb0")
UNK_TOKEN = "<unk>"
BOS_TOKEN = "<s>"
EOS_TOKEN = "<e>"
MD5 = "5506d213dba4124121c682368257bae4"
def __init__(self, mode="train", root=None, transform_func=None):
if mode not in ("train", "dev", "test", "dev-eval", "test-eval"):
raise TypeError(
                'Only `train`, `dev`, `test`, `dev-eval` and `test-eval` are supported, but `{}` was passed in.'.
format(mode))
        if transform_func is not None:
            if len(transform_func) != 2:
                raise ValueError("`transform_func` must have length of two for"
                                 " source and target.")
self.data = WMT14ende.get_data(mode=mode, root=root)
self.mode = mode
if transform_func is not None:
self.data = [(transform_func[0](data[0]),
transform_func[1](data[1])) for data in self.data]
super(WMT14ende, self).__init__(self.data)
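Following the docstring above, here is a minimal end-to-end sketch of loading the new dataset. Passing `root=None` (default cache directory) and unpacking an item into a (source ids, target ids) pair follow the code above; printing the first sample is only for illustration.
from paddlenlp.datasets import WMT14ende

# Build the default BPE transform (data and vocab are downloaded on first use
# when root is None) and instantiate the training split, as in the docstring.
transform_func = WMT14ende.get_default_transform_func(root=None)
train_dataset = WMT14ende.get_datasets(
    mode="train", transform_func=transform_func)

# After the transform, each item is a (source ids, target ids) pair.
src_ids, trg_ids = train_dataset[0]
print(len(src_ids), len(trg_ids))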
# For test, not API # For test, not API
def prepare_train_input(insts, pad_id): def prepare_train_input(insts, pad_id):
src, src_length = Pad(pad_val=pad_id, ret_length=True)( src, src_length = Pad(pad_val=pad_id, ret_length=True)(
......