Unverified Commit 30ccfc67 authored by liu zhengxi, committed by GitHub

[Transformer] Simplify transformer reader and fix TranslationDataset (#5035)

* fix translation dataset and simplify transformer reader
Parent 8c9d8f56
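In short, this commit replaces the file-pattern based `TransformerDataset` reader with the `WMT14ende` dataset from `paddlenlp.datasets`: `create_data_loader` now returns plain train/eval data loaders, and `create_infer_loader` returns the test loader together with the target vocab's `to_tokens` callable. A minimal sketch of the new call sites, mirroring the diff below (`args` is assumed to be the parsed config, e.g. `configs/transformer.base.yaml`):

```python
# Hedged sketch of the simplified reader API introduced by this commit.
# `args` is assumed to be the parsed YAML config (e.g. configs/transformer.base.yaml).
import reader

# Training / validation: plain DataLoaders, no step-count functions anymore.
(train_loader), (eval_loader) = reader.create_data_loader(args)

# Inference: the loader plus a callable that maps token ids back to tokens.
test_loader, to_tokens = reader.create_infer_loader(args)
for (src_word, ) in test_loader:
    # Feed src_word to InferTransformerModel as in predict.py, then
    # convert the decoded ids back to text with to_tokens(id_list).
    break
```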
......@@ -10,35 +10,21 @@ init_from_pretrain_model: ""
init_from_params: "./trained_models/step_final/"
# The directory for saving model
save_model: "trained_models"
# The directory for saving inference model.
inference_model_dir: "infer_model"
# Set seed for CE or debug
random_seed: None
# The pattern to match training data files.
training_file: "../gen_data/wmt14_ende_data_bpe/train.tok.clean.bpe.33708.en-de"
# The pattern to match validation data files.
validation_file: "../gen_data/wmt14_ende_data_bpe/newstest2013.tok.bpe.33708.en-de"
# The pattern to match test data files.
predict_file: "../gen_data/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en-de"
# The file to output the translation results of predict_file to.
output_file: "predict.txt"
# The path of vocabulary file of source language.
src_vocab_fpath: "../gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
# The path of vocabulary file of target language.
trg_vocab_fpath: "../gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
# The <bos>, <eos> and <unk> tokens in the dictionary.
special_token: ["<s>", "<e>", "<unk>"]
# The directory to store data.
root: None
# Whether to use cuda
use_gpu: True
# Args for reader, see reader.py for details
token_delimiter: " "
use_token_batch: True
pool_size: 200000
sort_type: "global"
shuffle: False
shuffle_batch: False
batch_size: 4096
infer_batch_size: 16
......
......@@ -52,8 +52,7 @@ def do_predict(args):
paddle.set_device(place)
# Define data loader
(test_loader,
test_steps_fn), trg_idx2word = reader.create_infer_loader(args)
test_loader, to_tokens = reader.create_infer_loader(args)
# Define model
transformer = InferTransformerModel(
......@@ -90,6 +89,7 @@ def do_predict(args):
transformer.eval()
f = open(args.output_file, "w")
with paddle.no_grad():
for (src_word, ) in test_loader:
finished_seq = transformer(src_word=src_word)
finished_seq = finished_seq.numpy().transpose([0, 2, 1])
......@@ -98,7 +98,7 @@ def do_predict(args):
if beam_idx >= args.n_best:
break
id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
word_list = [trg_idx2word[id] for id in id_list]
word_list = to_tokens(id_list)
sequence = " ".join(word_list) + "\n"
f.write(sequence)
......
......@@ -51,9 +51,7 @@ def do_train(args):
paddle.seed(random_seed)
# Define data loader
(train_loader, train_steps_fn), (eval_loader,
eval_steps_fn) = reader.create_data_loader(
args, trainer_count, rank)
(train_loader), (eval_loader) = reader.create_data_loader(args)
# Define model
transformer = TransformerModel(
......@@ -176,7 +174,6 @@ def do_train(args):
if step_idx % args.save_step == 0 and step_idx != 0:
# Validation
if args.validation_file:
transformer.eval()
total_sum_cost = 0
total_token_num = 0
......
#! /usr/bin/env bash
set -e
OUTPUT_DIR=$PWD/gen_data
###############################################################################
# change these variables for other WMT data
###############################################################################
OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_ende_data"
OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_ende_data_bpe"
LANG1="en"
LANG2="de"
# each of TRAIN_DATA: data_url data_file_lang1 data_file_lang2
TRAIN_DATA=(
'http://statmt.org/wmt13/training-parallel-europarl-v7.tgz'
'europarl-v7.de-en.en' 'europarl-v7.de-en.de'
'http://statmt.org/wmt13/training-parallel-commoncrawl.tgz'
'commoncrawl.de-en.en' 'commoncrawl.de-en.de'
'http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz'
'news-commentary-v12.de-en.en' 'news-commentary-v12.de-en.de'
)
# each of DEV_TEST_DATA: data_url data_file_lang1 data_file_lang2
# source & reference
DEV_TEST_DATA=(
'http://data.statmt.org/wmt17/translation-task/dev.tgz'
'newstest2013-ref.de.sgm' 'newstest2013-src.en.sgm'
'http://statmt.org/wmt14/test-full.tgz'
'newstest2014-deen-ref.en.sgm' 'newstest2014-deen-src.de.sgm'
)
###############################################################################
###############################################################################
# change these variables for other WMT data
###############################################################################
# OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_enfr_data"
# OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_enfr_data_bpe"
# LANG1="en"
# LANG2="fr"
# # each of TRAIN_DATA: data_url data_file_lang1 data_file_lang2
# TRAIN_DATA=(
# 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
# 'commoncrawl.fr-en.en' 'commoncrawl.fr-en.fr'
# 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz'
# 'training/europarl-v7.fr-en.en' 'training/europarl-v7.fr-en.fr'
# 'http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz'
# 'training/news-commentary-v9.fr-en.en' 'training/news-commentary-v9.fr-en.fr'
# 'http://www.statmt.org/wmt10/training-giga-fren.tar'
# 'giga-fren.release2.fixed.en.*' 'giga-fren.release2.fixed.fr.*'
# 'http://www.statmt.org/wmt13/training-parallel-un.tgz'
# 'un/undoc.2000.fr-en.en' 'un/undoc.2000.fr-en.fr'
# )
# # each of DEV_TEST_DATA: data_url data_tgz data_file_lang1 data_file_lang2
# DEV_TEST_DATA=(
# 'http://data.statmt.org/wmt16/translation-task/dev.tgz'
# '.*/newstest201[45]-fren-ref.en.sgm' '.*/newstest201[45]-fren-src.fr.sgm'
# 'http://data.statmt.org/wmt16/translation-task/test.tgz'
# '.*/newstest2016-fren-ref.en.sgm' '.*/newstest2016-fren-src.fr.sgm'
# )
###############################################################################
mkdir -p $OUTPUT_DIR_DATA $OUTPUT_DIR_BPE_DATA
# Extract training data
for ((i=0;i<${#TRAIN_DATA[@]};i+=3)); do
data_url=${TRAIN_DATA[i]}
data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
data=${data_tgz%.*} # training-parallel-commoncrawl
data_lang1=${TRAIN_DATA[i+1]}
data_lang2=${TRAIN_DATA[i+2]}
if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
echo "Download "${data_url}
echo "Dir "${OUTPUT_DIR_DATA}/${data_tgz}
wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
fi
if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
echo "Extract "${data_tgz}
mkdir -p ${OUTPUT_DIR_DATA}/${data}
tar_type=${data_tgz:0-3}
if [ ${tar_type} == "tar" ]; then
tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
else
tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
fi
fi
# concatenate all training data
for data_lang in $data_lang1 $data_lang2; do
for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
data_dir=`dirname $f`
data_file=`basename $f`
f_base=${f%.*}
f_ext=${f##*.}
if [ $f_ext == "gz" ]; then
gunzip $f
l=${f_base##*.}
f_base=${f_base%.*}
else
l=${f_ext}
fi
if [ $i -eq 0 ]; then
cat ${f_base}.$l > ${OUTPUT_DIR_DATA}/train.$l
else
cat ${f_base}.$l >> ${OUTPUT_DIR_DATA}/train.$l
fi
done
done
done
# Clone mosesdecoder
if [ ! -d ${OUTPUT_DIR}/mosesdecoder ]; then
echo "Cloning moses for data processing"
git clone https://github.com/moses-smt/mosesdecoder.git ${OUTPUT_DIR}/mosesdecoder
fi
# Extract develop and test data
dev_test_data=""
for ((i=0;i<${#DEV_TEST_DATA[@]};i+=3)); do
data_url=${DEV_TEST_DATA[i]}
data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
data=${data_tgz%.*} # training-parallel-commoncrawl
data_lang1=${DEV_TEST_DATA[i+1]}
data_lang2=${DEV_TEST_DATA[i+2]}
if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
echo "Download "${data_url}
wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
fi
if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
echo "Extract "${data_tgz}
mkdir -p ${OUTPUT_DIR_DATA}/${data}
tar_type=${data_tgz:0-3}
if [ ${tar_type} == "tar" ]; then
tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
else
tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
fi
fi
for data_lang in $data_lang1 $data_lang2; do
for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
echo "input-from-sgm"
data_dir=`dirname $f`
data_file=`basename $f`
data_out=`echo ${data_file} | cut -d '-' -f 1` # newstest2016
l=`echo ${data_file} | cut -d '.' -f 2` # en
dev_test_data="${dev_test_data}\|${data_out}" # to make regexp
if [ ! -e ${OUTPUT_DIR_DATA}/${data_out}.$l ]; then
${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
< $f > ${OUTPUT_DIR_DATA}/${data_out}.$l
fi
done
done
done
# Tokenize data
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train\|newstest2013\)\.$l$"`; do
f_base=${f%.*} # dir/train dir/newstest2013
f_out=$f_base.tok.$l
f_tmp=$f_base.tmp.$l
if [ ! -e $f_out ]; then
echo "Tokenize "$f
cat $f | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/normalize-punctuation.perl $l | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/remove-non-printing-char.perl | \
tee -a $tmp/valid.raw.$l | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l $l -threads 8 >> $f_out
echo $f_out
fi
done
done
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(newstest2014\)\.$l$"`; do
f_base=${f%.*} # dir/newstest2014
f_out=$f_base.tok.$l
if [ ! -e $f_out ]; then
echo "Tokenize "$f
cat $f | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l $l -threads 8 >> $f_out
echo $f_out
fi
done
done
# Clean data
for f in ${OUTPUT_DIR_DATA}/train.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.${LANG1}; do
f_base=${f%.*} # dir/train dir/train.tok
f_out=${f_base}.clean
if [ ! -e $f_out.${LANG1} ] && [ ! -e $f_out.${LANG2} ]; then
echo "Clean "${f_base}
${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $f_base ${LANG1} ${LANG2} ${f_out} 1 256
fi
done
python -m pip install subword-nmt
# Generate BPE data and vocabulary
for num_operations in 33708; do
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} ]; then
echo "Learn BPE with ${num_operations} merge operations"
cat ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG2} | \
subword-nmt learn-bpe -s $num_operations > ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations}
fi
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.tok\(\.clean\)\?\.$l$"`; do
f_base=${f%.*} # dir/train.tok dir/train.tok.clean dir/newstest2016.tok
f_base=${f_base##*/} # train.tok train.tok.clean newstest2016.tok
f_out=${OUTPUT_DIR_BPE_DATA}/${f_base}.bpe.${num_operations}.$l
if [ ! -e $f_out ]; then
echo "Apply BPE to "$f
subword-nmt apply-bpe -c ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} < $f > $f_out
fi
done
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} ]; then
echo "Create vocabulary for BPE data"
cat ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG1} ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG2} | \
subword-nmt get-vocab | cut -f1 -d ' ' > ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations}
fi
done
# Adapt to the reader
for f in ${OUTPUT_DIR_BPE_DATA}/*.bpe.${num_operations}.${LANG1}; do
f_base=${f%.*} # dir/train.tok.clean.bpe.32000 dir/newstest2016.tok.bpe.32000
f_out=${f_base}.${LANG1}-${LANG2}
if [ ! -e $f_out ]; then
paste -d '\t' $f_base.${LANG1} $f_base.${LANG2} > $f_out
fi
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations} ]; then
sed '1i\<s>\n<e>\n<unk>' ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} > ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations}
fi
echo "All done."
......@@ -22,79 +22,105 @@ from functools import partial
import numpy as np
from paddle.io import BatchSampler, DataLoader, Dataset
from paddlenlp.data import Pad
from paddlenlp.datasets import WMT14ende
from paddlenlp.data.sampler import SamplerHelper
def create_infer_loader(args):
dataset = TransformerDataset(
fpattern=args.predict_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
trg_idx2word = TransformerDataset.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
batch_sampler = TransformerBatchSampler(
dataset=dataset,
use_token_batch=False,
batch_size=args.infer_batch_size,
max_length=args.max_length)
def min_max_filer(data, max_len, min_len=0):
# 1 for special tokens.
data_min_len = min(len(data[0]), len(data[1])) + 1
data_max_len = max(len(data[0]), len(data[1])) + 1
return (data_min_len >= min_len) and (data_max_len <= max_len)
def create_data_loader(args):
root = None if args.root == "None" else args.root
(src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
transform_func = WMT14ende.get_default_transform_func(root=root)
datasets = [
WMT14ende.get_datasets(
mode=m, transform_func=transform_func) for m in ["train", "dev"]
]
def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
data_source):
return max(tokens_sofar,
len(data_source[current_idx][0]) + 1,
len(data_source[current_idx][1]) + 1)
def _key(size_so_far, minibatch_len):
return size_so_far * minibatch_len
data_loaders = [(None)] * 2
for i, dataset in enumerate(datasets):
m = dataset.mode
dataset = dataset.filter(
partial(
min_max_filer, max_len=args.max_length))
sampler = SamplerHelper(dataset)
src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
if args.sort_type == SortType.GLOBAL:
buffer_size = -1
trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
# Sort twice
sampler = sampler.sort(
key=trg_key, buffer_size=buffer_size).sort(
key=src_key, buffer_size=buffer_size)
else:
sampler = sampler.shuffle()
if args.sort_type == SortType.POOL:
buffer_size = args.pool_size
sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
batch_sampler = sampler.batch(
batch_size=args.batch_size,
drop_last=False,
batch_size_fn=_max_token_fn,
key=_key)
if m == "train":
batch_sampler = batch_sampler.shard()
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
collate_fn=partial(
prepare_infer_input,
prepare_train_input,
bos_idx=args.bos_idx,
eos_idx=args.eos_idx,
pad_idx=args.eos_idx),
pad_idx=args.bos_idx),
num_workers=0,
return_list=True)
data_loaders = (data_loader, batch_sampler.__len__)
return data_loaders, trg_idx2word
data_loaders[i] = (data_loader)
return data_loaders
def create_data_loader(args, world_size=1, rank=0):
data_loaders = [(None, None)] * 2
data_files = [args.training_file, args.validation_file
] if args.validation_file else [args.training_file]
for i, data_file in enumerate(data_files):
dataset = TransformerDataset(
fpattern=data_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
batch_sampler = TransformerBatchSampler(
dataset=dataset,
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
use_token_batch=args.use_token_batch,
max_length=args.max_length,
distribute_mode=True if i == 0 else False,
world_size=world_size,
rank=rank)
def create_infer_loader(args):
root = None if args.root == "None" else args.root
(src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
transform_func = WMT14ende.get_default_transform_func(root=root)
dataset = WMT14ende.get_datasets(
mode="test", transform_func=transform_func).filter(
partial(
min_max_filer, max_len=args.max_length))
batch_sampler = SamplerHelper(dataset).batch(
batch_size=args.infer_batch_size, drop_last=False)
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
collate_fn=partial(
prepare_train_input,
prepare_infer_input,
bos_idx=args.bos_idx,
eos_idx=args.eos_idx,
pad_idx=args.bos_idx),
num_workers=0,
return_list=True)
data_loaders[i] = (data_loader, batch_sampler.__len__)
return data_loaders
return data_loader, trg_vocab.to_tokens
def prepare_train_input(insts, bos_idx, eos_idx, pad_idx):
......@@ -126,301 +152,3 @@ class SortType(object):
GLOBAL = 'global'
POOL = 'pool'
NONE = "none"
class Converter(object):
def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
self._vocab = vocab
self._beg = beg
self._end = end
self._unk = unk
self._delimiter = delimiter
self._add_beg = add_beg
self._add_end = add_end
def __call__(self, sentence):
return ([self._beg] if self._add_beg else []) + [
self._vocab.get(w, self._unk)
for w in sentence.split(self._delimiter)
] + ([self._end] if self._add_end else [])
class ComposedConverter(object):
def __init__(self, converters):
self._converters = converters
def __call__(self, fields):
return [
converter(field)
for field, converter in zip(fields, self._converters)
]
class SentenceBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self._batch_size = batch_size
def append(self, info):
self.batch.append(info)
if len(self.batch) == self._batch_size:
tmp = self.batch
self.batch = []
return tmp
class TokenBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self.max_len = -1
self._batch_size = batch_size
def append(self, info):
cur_len = info.max_len
max_len = max(self.max_len, cur_len)
if max_len * (len(self.batch) + 1) > self._batch_size:
result = self.batch
self.batch = [info]
self.max_len = cur_len
return result
else:
self.max_len = max_len
self.batch.append(info)
class SampleInfo(object):
def __init__(self, i, lens):
self.i = i
# take bos and eos into account
self.min_len = min(lens[0] + 1, lens[1] + 1)
self.max_len = max(lens[0] + 1, lens[1] + 1)
self.src_len = lens[0]
self.trg_len = lens[1]
class MinMaxFilter(object):
def __init__(self, max_len, min_len, underlying_creator):
self._min_len = min_len
self._max_len = max_len
self._creator = underlying_creator
def append(self, info):
if info.max_len > self._max_len or info.min_len < self._min_len:
return
else:
return self._creator.append(info)
@property
def batch(self):
return self._creator.batch
class TransformerDataset(Dataset):
def __init__(self,
src_vocab_fpath,
trg_vocab_fpath,
fpattern,
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
trg_fpattern=None):
self._src_vocab = self.load_dict(src_vocab_fpath)
self._trg_vocab = self.load_dict(trg_vocab_fpath)
self._bos_idx = self._src_vocab[start_mark]
self._eos_idx = self._src_vocab[end_mark]
self._unk_idx = self._src_vocab[unk_mark]
self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter
self.load_src_trg_ids(fpattern, trg_fpattern)
def load_src_trg_ids(self, fpattern, trg_fpattern=None):
src_converter = Converter(
vocab=self._src_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False,
add_end=False)
trg_converter = Converter(
vocab=self._trg_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False,
add_end=False)
converters = ComposedConverter([src_converter, trg_converter])
self._src_seq_ids = []
self._trg_seq_ids = []
self._sample_infos = []
slots = [self._src_seq_ids, self._trg_seq_ids]
for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
lens = []
for field, slot in zip(converters(line), slots):
slot.append(field)
lens.append(len(field))
self._sample_infos.append(SampleInfo(i, lens))
def _load_lines(self, fpattern, trg_fpattern=None):
fpaths = glob.glob(fpattern)
fpaths = sorted(fpaths) # TODO: Add custom sort
assert len(fpaths) > 0, "no matching file to the provided data path"
(f_mode, f_encoding, endl) = ("r", "utf8", "\n")
if trg_fpattern is None:
for fpath in fpaths:
with io.open(fpath, f_mode, encoding=f_encoding) as f:
for line in f:
fields = line.strip(endl).split(self._field_delimiter)
yield fields
else:
# separated source and target language data files
# assume we can get aligned data by sort the two language files
trg_fpaths = glob.glob(trg_fpattern)
trg_fpaths = sorted(trg_fpaths)
assert len(fpaths) == len(
trg_fpaths
), "the number of source language data files must equal \
with that of source language"
for fpath, trg_fpath in zip(fpaths, trg_fpaths):
with io.open(fpath, f_mode, encoding=f_encoding) as f:
with io.open(
trg_fpath, f_mode, encoding=f_encoding) as trg_f:
for line in zip(f, trg_f):
fields = [field.strip(endl) for field in line]
yield fields
@staticmethod
def load_dict(dict_path, reverse=False):
word_dict = {}
(f_mode, f_encoding, endl) = ("r", "utf8", "\n")
with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip(endl)
else:
word_dict[line.strip(endl)] = idx
return word_dict
def get_vocab_summary(self):
return len(self._src_vocab), len(
self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
def __getitem__(self, idx):
return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
) if self._trg_seq_ids else self._src_seq_ids[idx]
def __len__(self):
return len(self._sample_infos)
class TransformerBatchSampler(BatchSampler):
def __init__(self,
dataset,
batch_size,
pool_size=10000,
sort_type=SortType.NONE,
min_length=0,
max_length=100,
shuffle=False,
shuffle_batch=False,
use_token_batch=False,
clip_last_batch=False,
distribute_mode=True,
seed=0,
world_size=1,
rank=0):
for arg, value in locals().items():
if arg != "self":
setattr(self, "_" + arg, value)
self._random = np.random
self._random.seed(seed)
# for multi-devices
self._distribute_mode = distribute_mode
self._nranks = world_size
self._local_rank = rank
def __iter__(self):
# global sort or global shuffle
if self._sort_type == SortType.GLOBAL:
infos = sorted(self._dataset._sample_infos, key=lambda x: x.trg_len)
infos = sorted(infos, key=lambda x: x.src_len)
else:
if self._shuffle:
infos = self._dataset._sample_infos
self._random.shuffle(infos)
else:
infos = self._dataset._sample_infos
if self._sort_type == SortType.POOL:
reverse = True
for i in range(0, len(infos), self._pool_size):
# to avoid placing short next to long sentences
reverse = not reverse
infos[i:i + self._pool_size] = sorted(
infos[i:i + self._pool_size],
key=lambda x: x.max_len,
reverse=reverse)
batches = []
batch_creator = TokenBatchCreator(
self.
_batch_size) if self._use_token_batch else SentenceBatchCreator(
self._batch_size * self._nranks)
batch_creator = MinMaxFilter(self._max_length, self._min_length,
batch_creator)
for info in infos:
batch = batch_creator.append(info)
if batch is not None:
batches.append(batch)
if not self._clip_last_batch and len(batch_creator.batch) != 0:
batches.append(batch_creator.batch)
if self._shuffle_batch:
self._random.shuffle(batches)
if not self._use_token_batch:
# When producing batches by sentence count, neighboring batches that will
# be fed and run in parallel should have similar lengths (and thus similar
# computational cost) after shuffling, so they are shuffled as a whole and
# split here.
batches = [[
batch[self._batch_size * i:self._batch_size * (i + 1)]
for i in range(self._nranks)
] for batch in batches]
batches = list(itertools.chain.from_iterable(batches))
self.batch_number = (len(batches) + self._nranks - 1) // self._nranks
# for multi-device
for batch_id, batch in enumerate(batches):
if not self._distribute_mode or (
batch_id % self._nranks == self._local_rank):
batch_indices = [info.i for info in batch]
yield batch_indices
if self._distribute_mode and len(batches) % self._nranks != 0:
if self._local_rank >= len(batches) % self._nranks:
# use previous data to pad
yield batch_indices
def __len__(self):
if hasattr(self, "batch_number"): #
return self.batch_number
if not self._use_token_batch:
batch_number = (
len(self._dataset) + self._batch_size * self._nranks - 1) // (
self._batch_size * self._nranks)
else:
# for uncertain batch number, the actual value is self.batch_number
batch_number = sys.maxsize
return batch_number
......@@ -63,9 +63,7 @@ def do_train(args):
paddle.seed(random_seed)
# Define data loader
# NOTE: To guarantee all data is involved, use world_size=1 and rank=0.
(train_loader, train_steps_fn), (
eval_loader, eval_steps_fn) = reader.create_data_loader(args)
(train_loader), (eval_loader) = reader.create_data_loader(args)
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
......
......@@ -39,7 +39,7 @@
| Dataset | Description | Usage |
| ---- | --------- | ------ |
| [IWSLT15](https://workshop2015.iwslt.org/) | IWSLT'15 English-Vietnamese translation dataset | `paddlenlp.datasets.IWSLT15`|
| [WMT14](http://www.statmt.org/wmt14/translation-task.html) | WMT14 EN-DE English-German translation dataset | `paddlenlp.datasets.WMT14`|
| [WMT14](http://www.statmt.org/wmt14/translation-task.html) | WMT14 EN-DE English-German translation dataset | `paddlenlp.datasets.WMT14ende`|
## Time Series Forecasting
......
......@@ -4,13 +4,12 @@
```text
.
├── images # images referenced in the README
├── images/ # images referenced in the README
├── predict.py # prediction script
├── reader.py # data reading interface
├── README.md # documentation
├── train.py # training script
├── transformer.py # model definition
└── transformer.yaml # configuration file
└── configs/ # configuration files
```
## Model Overview
......@@ -46,6 +45,15 @@
Public dataset: The WMT translation shared task is the most authoritative international evaluation campaign in machine translation. Its English-German track provides a medium-sized dataset that is used in many papers, including the Transformer paper. We provide the [WMT'14 EN-DE dataset](http://www.statmt.org/wmt14/translation-task.html) as an example.
We also provide a preprocessed copy of this dataset. With the following code, it will be downloaded automatically and extracted to `~/.paddlenlp/datasets/machine_translation/WMT14ende/`:
``` python
# Get the default data transformation
transform_func = WMT14ende.get_default_transform_func(root=root)
# Download and preprocess the WMT14 en-de translation dataset
dataset = WMT14ende.get_datasets(mode="train", transform_func=transform_func)
```
### Single-Machine Training
### Single-GPU Training
......@@ -55,10 +63,10 @@
```sh
# setting visible devices for training
export CUDA_VISIBLE_DEVICES=0
python train.py
python train.py --config ./configs/transformer.base.yaml
```
The relevant parameters, such as `max_iter`, which controls the maximum number of training iterations, can be set in transformer.yaml.
The relevant parameters can be set in `configs/transformer.big.yaml` and `configs/transformer.base.yaml`.
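To double-check which settings will actually be picked up, the config can be inspected before launching training; a minimal sketch, assuming PyYAML is installed and that the file lives at the path below (keys taken from the config files in this commit):

```python
# Hedged sketch: peek at a few settings in configs/transformer.base.yaml
# (batch_size, infer_batch_size and use_gpu appear in the config diff above).
import yaml

with open("./configs/transformer.base.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["batch_size"], cfg["infer_batch_size"], cfg["use_gpu"])
```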
### Multi-GPU Training (single machine)
......@@ -66,7 +74,7 @@ python train.py
```sh
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train.py
python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train.py --config ./configs/transformer.base.yaml
```
......@@ -80,7 +88,7 @@ export CUDA_VISIBLE_DEVICES=0
python predict.py
```
The translations of the text in the file specified by `predict_file` are written to the file specified by `output_file`. When running prediction, `init_from_params` must point to the directory of the trained model; the remaining options are documented in the comments of `transformer.yaml`, where they can also be changed. Note that prediction currently supports a single GPU only, because the subsequent model evaluation depends on the order in which translations are written to the output file, and this order cannot yet be guaranteed with multiple GPUs.
The translations of the text in the file specified by `predict_file` are written to the file specified by `output_file`. When running prediction, `init_from_params` must point to the directory of the trained model; the remaining options are documented in the comments of `configs/transformer.big.yaml` and `configs/transformer.base.yaml`, where they can also be changed. Note that prediction currently supports a single GPU only, because the subsequent model evaluation depends on the order in which translations are written to the output file, and this order cannot yet be guaranteed with multiple GPUs.
### Model Evaluation
......@@ -91,13 +99,13 @@ python predict.py
# Restore the predictions in predict.txt to tokenized text
sed -r 's/(@@ )|(@@ ?$)//g' predict.txt > predict.tok.txt
# If the BLEU evaluation tool is not available, download it first
# git clone https://github.com/moses-smt/mosesdecoder.git
git clone https://github.com/moses-smt/mosesdecoder.git
# Take the English-German newstest2014 test set as an example
perl gen_data/mosesdecoder/scripts/generic/multi-bleu.perl gen_data/wmt14_ende_data/newstest2014.tok.de < predict.tok.txt
perl mosesdecoder/scripts/generic/multi-bleu.perl ~/.paddlenlp/datasets/machine_translation/WMT14ende/WMT14.en-de/wmt14_ende_data/newstest2014.tok.de < predict.tok.txt
```
You should see results similar to the following:
You should see results similar to the following (here, the result of the big model on newstest2014):
```
BLEU = 26.35, 57.7/32.1/20.0/13.0 (BP=1.000, ratio=1.013, hyp_len=63903, ref_len=63078)
BLEU = 27.48, 58.6/33.2/21.1/13.9 (BP=1.000, ratio=1.012, hyp_len=65312, ref_len=64506)
```
## Advanced Usage
......
......@@ -10,37 +10,23 @@ init_from_pretrain_model: ""
init_from_params: "./trained_models/step_final/"
# The directory for saving model
save_model: "trained_models"
# The directory for saving inference model.
inference_model_dir: "infer_model"
# Set seed for CE or debug
random_seed: None
# The pattern to match training data files.
training_file: "gen_data/wmt14_ende_data_bpe/train.tok.clean.bpe.33708.en-de"
# The pattern to match validation data files.
validation_file: "gen_data/wmt14_ende_data_bpe/newstest2013.tok.bpe.33708.en-de"
# The pattern to match test data files.
predict_file: "gen_data/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en-de"
# The file to output the translation results of predict_file to.
output_file: "predict.txt"
# The path of vocabulary file of source language.
src_vocab_fpath: "gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
# The path of vocabulary file of target language.
trg_vocab_fpath: "gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
# The <bos>, <eos> and <unk> tokens in the dictionary.
special_token: ["<s>", "<e>", "<unk>"]
# The directory to store data.
root: None
# Whether to use cuda
use_gpu: True
# Args for reader, see reader.py for details
token_delimiter: " "
use_token_batch: True
pool_size: 200000
sort_type: "pool"
shuffle: True
shuffle_batch: True
batch_size: 4096
infer_batch_size: 32
infer_batch_size: 8
# Hyperparams for training:
# The number of epochs for training
......
......@@ -10,37 +10,23 @@ init_from_pretrain_model: ""
init_from_params: "./trained_models/step_final/"
# The directory for saving model
save_model: "trained_models"
# The directory for saving inference model.
inference_model_dir: "infer_model"
# Set seed for CE or debug
random_seed: None
# The pattern to match training data files.
training_file: "gen_data/wmt14_ende_data_bpe/train.tok.clean.bpe.33708.en-de"
# The pattern to match validation data files.
validation_file: "gen_data/wmt14_ende_data_bpe/newstest2013.tok.bpe.33708.en-de"
# The pattern to match test data files.
predict_file: "gen_data/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en-de"
# The file to output the translation results of predict_file to.
output_file: "predict.txt"
# The path of vocabulary file of source language.
src_vocab_fpath: "gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
# The path of vocabulary file of target language.
trg_vocab_fpath: "gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
# The <bos>, <eos> and <unk> tokens in the dictionary.
special_token: ["<s>", "<e>", "<unk>"]
# The directory to store data.
root: None
# Whether to use cuda
use_gpu: True
# Args for reader, see reader.py for details
token_delimiter: " "
use_token_batch: True
pool_size: 200000
sort_type: "pool"
shuffle: True
shuffle_batch: True
batch_size: 4096
infer_batch_size: 16
infer_batch_size: 8
# Hyperparams for training:
# The number of epochs for training
......
#! /usr/bin/env bash
set -e
OUTPUT_DIR=$PWD/gen_data
###############################################################################
# change these variables for other WMT data
###############################################################################
OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_ende_data"
OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_ende_data_bpe"
LANG1="en"
LANG2="de"
# each of TRAIN_DATA: data_url data_file_lang1 data_file_lang2
TRAIN_DATA=(
'http://statmt.org/wmt13/training-parallel-europarl-v7.tgz'
'europarl-v7.de-en.en' 'europarl-v7.de-en.de'
'http://statmt.org/wmt13/training-parallel-commoncrawl.tgz'
'commoncrawl.de-en.en' 'commoncrawl.de-en.de'
'http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz'
'news-commentary-v12.de-en.en' 'news-commentary-v12.de-en.de'
)
# each of DEV_TEST_DATA: data_url data_file_lang1 data_file_lang2
# source & reference
DEV_TEST_DATA=(
'http://data.statmt.org/wmt17/translation-task/dev.tgz'
'newstest2013-ref.de.sgm' 'newstest2013-src.en.sgm'
'http://statmt.org/wmt14/test-full.tgz'
'newstest2014-deen-ref.en.sgm' 'newstest2014-deen-src.de.sgm'
)
###############################################################################
###############################################################################
# change these variables for other WMT data
###############################################################################
# OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_enfr_data"
# OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_enfr_data_bpe"
# LANG1="en"
# LANG2="fr"
# # each of TRAIN_DATA: data_url data_file_lang1 data_file_lang2
# TRAIN_DATA=(
# 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
# 'commoncrawl.fr-en.en' 'commoncrawl.fr-en.fr'
# 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz'
# 'training/europarl-v7.fr-en.en' 'training/europarl-v7.fr-en.fr'
# 'http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz'
# 'training/news-commentary-v9.fr-en.en' 'training/news-commentary-v9.fr-en.fr'
# 'http://www.statmt.org/wmt10/training-giga-fren.tar'
# 'giga-fren.release2.fixed.en.*' 'giga-fren.release2.fixed.fr.*'
# 'http://www.statmt.org/wmt13/training-parallel-un.tgz'
# 'un/undoc.2000.fr-en.en' 'un/undoc.2000.fr-en.fr'
# )
# # each of DEV_TEST_DATA: data_url data_tgz data_file_lang1 data_file_lang2
# DEV_TEST_DATA=(
# 'http://data.statmt.org/wmt16/translation-task/dev.tgz'
# '.*/newstest201[45]-fren-ref.en.sgm' '.*/newstest201[45]-fren-src.fr.sgm'
# 'http://data.statmt.org/wmt16/translation-task/test.tgz'
# '.*/newstest2016-fren-ref.en.sgm' '.*/newstest2016-fren-src.fr.sgm'
# )
###############################################################################
mkdir -p $OUTPUT_DIR_DATA $OUTPUT_DIR_BPE_DATA
# Extract training data
for ((i=0;i<${#TRAIN_DATA[@]};i+=3)); do
data_url=${TRAIN_DATA[i]}
data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
data=${data_tgz%.*} # training-parallel-commoncrawl
data_lang1=${TRAIN_DATA[i+1]}
data_lang2=${TRAIN_DATA[i+2]}
if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
echo "Download "${data_url}
echo "Dir "${OUTPUT_DIR_DATA}/${data_tgz}
wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
fi
if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
echo "Extract "${data_tgz}
mkdir -p ${OUTPUT_DIR_DATA}/${data}
tar_type=${data_tgz:0-3}
if [ ${tar_type} == "tar" ]; then
tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
else
tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
fi
fi
# concatenate all training data
for data_lang in $data_lang1 $data_lang2; do
for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
data_dir=`dirname $f`
data_file=`basename $f`
f_base=${f%.*}
f_ext=${f##*.}
if [ $f_ext == "gz" ]; then
gunzip $f
l=${f_base##*.}
f_base=${f_base%.*}
else
l=${f_ext}
fi
if [ $i -eq 0 ]; then
cat ${f_base}.$l > ${OUTPUT_DIR_DATA}/train.$l
else
cat ${f_base}.$l >> ${OUTPUT_DIR_DATA}/train.$l
fi
done
done
done
# Clone mosesdecoder
if [ ! -d ${OUTPUT_DIR}/mosesdecoder ]; then
echo "Cloning moses for data processing"
git clone https://github.com/moses-smt/mosesdecoder.git ${OUTPUT_DIR}/mosesdecoder
fi
# Extract develop and test data
dev_test_data=""
for ((i=0;i<${#DEV_TEST_DATA[@]};i+=3)); do
data_url=${DEV_TEST_DATA[i]}
data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
data=${data_tgz%.*} # training-parallel-commoncrawl
data_lang1=${DEV_TEST_DATA[i+1]}
data_lang2=${DEV_TEST_DATA[i+2]}
if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
echo "Download "${data_url}
wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
fi
if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
echo "Extract "${data_tgz}
mkdir -p ${OUTPUT_DIR_DATA}/${data}
tar_type=${data_tgz:0-3}
if [ ${tar_type} == "tar" ]; then
tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
else
tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
fi
fi
for data_lang in $data_lang1 $data_lang2; do
for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
echo "input-from-sgm"
data_dir=`dirname $f`
data_file=`basename $f`
data_out=`echo ${data_file} | cut -d '-' -f 1` # newstest2016
l=`echo ${data_file} | cut -d '.' -f 2` # en
dev_test_data="${dev_test_data}\|${data_out}" # to make regexp
if [ ! -e ${OUTPUT_DIR_DATA}/${data_out}.$l ]; then
${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
< $f > ${OUTPUT_DIR_DATA}/${data_out}.$l
fi
done
done
done
# Tokenize data
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train\|newstest2013\)\.$l$"`; do
f_base=${f%.*} # dir/train dir/newstest2013
f_out=$f_base.tok.$l
f_tmp=$f_base.tmp.$l
if [ ! -e $f_out ]; then
echo "Tokenize "$f
cat $f | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/normalize-punctuation.perl $l | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/remove-non-printing-char.perl | \
tee -a $tmp/valid.raw.$l | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l $l -threads 8 >> $f_out
echo $f_out
fi
done
done
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(newstest2014\)\.$l$"`; do
f_base=${f%.*} # dir/newstest2014
f_out=$f_base.tok.$l
if [ ! -e $f_out ]; then
echo "Tokenize "$f
cat $f | \
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l $l -threads 8 >> $f_out
echo $f_out
fi
done
done
# Clean data
for f in ${OUTPUT_DIR_DATA}/train.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.${LANG1}; do
f_base=${f%.*} # dir/train dir/train.tok
f_out=${f_base}.clean
if [ ! -e $f_out.${LANG1} ] && [ ! -e $f_out.${LANG2} ]; then
echo "Clean "${f_base}
${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $f_base ${LANG1} ${LANG2} ${f_out} 1 256
fi
done
python -m pip install subword-nmt
# Generate BPE data and vocabulary
for num_operations in 33708; do
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} ]; then
echo "Learn BPE with ${num_operations} merge operations"
cat ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG2} | \
subword-nmt learn-bpe -s $num_operations > ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations}
fi
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.tok\(\.clean\)\?\.$l$"`; do
f_base=${f%.*} # dir/train.tok dir/train.tok.clean dir/newstest2016.tok
f_base=${f_base##*/} # train.tok train.tok.clean newstest2016.tok
f_out=${OUTPUT_DIR_BPE_DATA}/${f_base}.bpe.${num_operations}.$l
if [ ! -e $f_out ]; then
echo "Apply BPE to "$f
subword-nmt apply-bpe -c ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} < $f > $f_out
fi
done
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} ]; then
echo "Create vocabulary for BPE data"
cat ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG1} ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG2} | \
subword-nmt get-vocab | cut -f1 -d ' ' > ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations}
fi
done
# Adapt to the reader
for f in ${OUTPUT_DIR_BPE_DATA}/*.bpe.${num_operations}.${LANG1}; do
f_base=${f%.*} # dir/train.tok.clean.bpe.32000 dir/newstest2016.tok.bpe.32000
f_out=${f_base}.${LANG1}-${LANG2}
if [ ! -e $f_out ]; then
paste -d '\t' $f_base.${LANG1} $f_base.${LANG2} > $f_out
fi
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations} ]; then
sed '1i\<s>\n<e>\n<unk>' ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} > ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations}
fi
echo "All done."
......@@ -48,8 +48,7 @@ def do_predict(args):
paddle.set_device(place)
# Define data loader
(test_loader,
test_steps_fn), trg_idx2word = reader.create_infer_loader(args)
test_loader, to_tokens = reader.create_infer_loader(args)
# Define model
transformer = InferTransformerModel(
......@@ -95,7 +94,7 @@ def do_predict(args):
if beam_idx >= args.n_best:
break
id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
word_list = [trg_idx2word[id] for id in id_list]
word_list = to_tokens(id_list)
sequence = " ".join(word_list) + "\n"
f.write(sequence)
......
......@@ -22,79 +22,105 @@ from functools import partial
import numpy as np
from paddle.io import BatchSampler, DataLoader, Dataset
from paddlenlp.data import Pad
from paddlenlp.datasets import WMT14ende
from paddlenlp.data.sampler import SamplerHelper
def create_infer_loader(args):
dataset = TransformerDataset(
fpattern=args.predict_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
trg_idx2word = TransformerDataset.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
batch_sampler = TransformerBatchSampler(
dataset=dataset,
use_token_batch=False,
batch_size=args.infer_batch_size,
max_length=args.max_length)
def min_max_filer(data, max_len, min_len=0):
# 1 for special tokens.
data_min_len = min(len(data[0]), len(data[1])) + 1
data_max_len = max(len(data[0]), len(data[1])) + 1
return (data_min_len >= min_len) and (data_max_len <= max_len)
def create_data_loader(args):
root = None if args.root == "None" else args.root
(src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
transform_func = WMT14ende.get_default_transform_func(root=root)
datasets = [
WMT14ende.get_datasets(
mode=m, transform_func=transform_func) for m in ["train", "dev"]
]
def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
data_source):
return max(tokens_sofar,
len(data_source[current_idx][0]) + 1,
len(data_source[current_idx][1]) + 1)
def _key(size_so_far, minibatch_len):
return size_so_far * minibatch_len
data_loaders = [(None)] * 2
for i, dataset in enumerate(datasets):
m = dataset.mode
dataset = dataset.filter(
partial(
min_max_filer, max_len=args.max_length))
sampler = SamplerHelper(dataset)
src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
if args.sort_type == SortType.GLOBAL:
buffer_size = -1
trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
# Sort twice
sampler = sampler.sort(
key=trg_key, buffer_size=buffer_size).sort(
key=src_key, buffer_size=buffer_size)
else:
sampler = sampler.shuffle()
if args.sort_type == SortType.POOL:
buffer_size = args.pool_size
sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
batch_sampler = sampler.batch(
batch_size=args.batch_size,
drop_last=False,
batch_size_fn=_max_token_fn,
key=_key)
if m == "train":
batch_sampler = batch_sampler.shard()
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
collate_fn=partial(
prepare_infer_input,
prepare_train_input,
bos_idx=args.bos_idx,
eos_idx=args.eos_idx,
pad_idx=args.bos_idx),
num_workers=0,
return_list=True)
data_loaders = (data_loader, batch_sampler.__len__)
return data_loaders, trg_idx2word
data_loaders[i] = (data_loader)
return data_loaders
def create_data_loader(args, world_size=1, rank=0):
data_loaders = [(None, None)] * 2
data_files = [args.training_file, args.validation_file
] if args.validation_file else [args.training_file]
for i, data_file in enumerate(data_files):
dataset = TransformerDataset(
fpattern=data_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
batch_sampler = TransformerBatchSampler(
dataset=dataset,
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
use_token_batch=args.use_token_batch,
max_length=args.max_length,
distribute_mode=True if i == 0 else False,
world_size=world_size,
rank=rank)
def create_infer_loader(args):
root = None if args.root == "None" else args.root
(src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
transform_func = WMT14ende.get_default_transform_func(root=root)
dataset = WMT14ende.get_datasets(
mode="test", transform_func=transform_func).filter(
partial(
min_max_filer, max_len=args.max_length))
batch_sampler = SamplerHelper(dataset).batch(
batch_size=args.infer_batch_size, drop_last=False)
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
collate_fn=partial(
prepare_train_input,
prepare_infer_input,
bos_idx=args.bos_idx,
eos_idx=args.eos_idx,
pad_idx=args.bos_idx),
num_workers=0,
return_list=True)
data_loaders[i] = (data_loader, batch_sampler.__len__)
return data_loaders
return data_loader, trg_vocab.to_tokens
def prepare_train_input(insts, bos_idx, eos_idx, pad_idx):
......@@ -126,301 +152,3 @@ class SortType(object):
GLOBAL = 'global'
POOL = 'pool'
NONE = "none"
class Converter(object):
def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
self._vocab = vocab
self._beg = beg
self._end = end
self._unk = unk
self._delimiter = delimiter
self._add_beg = add_beg
self._add_end = add_end
def __call__(self, sentence):
return ([self._beg] if self._add_beg else []) + [
self._vocab.get(w, self._unk)
for w in sentence.split(self._delimiter)
] + ([self._end] if self._add_end else [])
class ComposedConverter(object):
def __init__(self, converters):
self._converters = converters
def __call__(self, fields):
return [
converter(field)
for field, converter in zip(fields, self._converters)
]
class SentenceBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self._batch_size = batch_size
def append(self, info):
self.batch.append(info)
if len(self.batch) == self._batch_size:
tmp = self.batch
self.batch = []
return tmp
class TokenBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self.max_len = -1
self._batch_size = batch_size
def append(self, info):
cur_len = info.max_len
max_len = max(self.max_len, cur_len)
if max_len * (len(self.batch) + 1) > self._batch_size:
result = self.batch
self.batch = [info]
self.max_len = cur_len
return result
else:
self.max_len = max_len
self.batch.append(info)
class SampleInfo(object):
def __init__(self, i, lens):
self.i = i
# take bos and eos into account
self.min_len = min(lens[0] + 1, lens[1] + 1)
self.max_len = max(lens[0] + 1, lens[1] + 1)
self.src_len = lens[0]
self.trg_len = lens[1]
class MinMaxFilter(object):
def __init__(self, max_len, min_len, underlying_creator):
self._min_len = min_len
self._max_len = max_len
self._creator = underlying_creator
def append(self, info):
if info.max_len > self._max_len or info.min_len < self._min_len:
return
else:
return self._creator.append(info)
@property
def batch(self):
return self._creator.batch
class TransformerDataset(Dataset):
def __init__(self,
src_vocab_fpath,
trg_vocab_fpath,
fpattern,
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
trg_fpattern=None):
self._src_vocab = self.load_dict(src_vocab_fpath)
self._trg_vocab = self.load_dict(trg_vocab_fpath)
self._bos_idx = self._src_vocab[start_mark]
self._eos_idx = self._src_vocab[end_mark]
self._unk_idx = self._src_vocab[unk_mark]
self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter
self.load_src_trg_ids(fpattern, trg_fpattern)
def load_src_trg_ids(self, fpattern, trg_fpattern=None):
src_converter = Converter(
vocab=self._src_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False,
add_end=False)
trg_converter = Converter(
vocab=self._trg_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False,
add_end=False)
converters = ComposedConverter([src_converter, trg_converter])
self._src_seq_ids = []
self._trg_seq_ids = []
self._sample_infos = []
slots = [self._src_seq_ids, self._trg_seq_ids]
for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
lens = []
for field, slot in zip(converters(line), slots):
slot.append(field)
lens.append(len(field))
self._sample_infos.append(SampleInfo(i, lens))
def _load_lines(self, fpattern, trg_fpattern=None):
fpaths = glob.glob(fpattern)
fpaths = sorted(fpaths) # TODO: Add custom sort
assert len(fpaths) > 0, "no matching file to the provided data path"
(f_mode, f_encoding, endl) = ("r", "utf8", "\n")
if trg_fpattern is None:
for fpath in fpaths:
with io.open(fpath, f_mode, encoding=f_encoding) as f:
for line in f:
fields = line.strip(endl).split(self._field_delimiter)
yield fields
else:
# separated source and target language data files
# assume we can get aligned data by sort the two language files
trg_fpaths = glob.glob(trg_fpattern)
trg_fpaths = sorted(trg_fpaths)
assert len(fpaths) == len(
trg_fpaths
), "the number of source language data files must equal \
with that of source language"
for fpath, trg_fpath in zip(fpaths, trg_fpaths):
with io.open(fpath, f_mode, encoding=f_encoding) as f:
with io.open(
trg_fpath, f_mode, encoding=f_encoding) as trg_f:
for line in zip(f, trg_f):
fields = [field.strip(endl) for field in line]
yield fields
@staticmethod
def load_dict(dict_path, reverse=False):
word_dict = {}
(f_mode, f_encoding, endl) = ("r", "utf8", "\n")
with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip(endl)
else:
word_dict[line.strip(endl)] = idx
return word_dict
def get_vocab_summary(self):
return len(self._src_vocab), len(
self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
def __getitem__(self, idx):
return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
) if self._trg_seq_ids else self._src_seq_ids[idx]
def __len__(self):
return len(self._sample_infos)
class TransformerBatchSampler(BatchSampler):
def __init__(self,
dataset,
batch_size,
pool_size=10000,
sort_type=SortType.NONE,
min_length=0,
max_length=100,
shuffle=False,
shuffle_batch=False,
use_token_batch=False,
clip_last_batch=False,
distribute_mode=True,
seed=0,
world_size=1,
rank=0):
for arg, value in locals().items():
if arg != "self":
setattr(self, "_" + arg, value)
self._random = np.random
self._random.seed(seed)
# for multi-devices
self._distribute_mode = distribute_mode
self._nranks = world_size
self._local_rank = rank
def __iter__(self):
# global sort or global shuffle
if self._sort_type == SortType.GLOBAL:
infos = sorted(self._dataset._sample_infos, key=lambda x: x.trg_len)
infos = sorted(infos, key=lambda x: x.src_len)
else:
if self._shuffle:
infos = self._dataset._sample_infos
self._random.shuffle(infos)
else:
infos = self._dataset._sample_infos
if self._sort_type == SortType.POOL:
reverse = True
for i in range(0, len(infos), self._pool_size):
# to avoid placing short next to long sentences
reverse = not reverse
infos[i:i + self._pool_size] = sorted(
infos[i:i + self._pool_size],
key=lambda x: x.max_len,
reverse=reverse)
batches = []
batch_creator = TokenBatchCreator(
self.
_batch_size) if self._use_token_batch else SentenceBatchCreator(
self._batch_size * self._nranks)
batch_creator = MinMaxFilter(self._max_length, self._min_length,
batch_creator)
for info in infos:
batch = batch_creator.append(info)
if batch is not None:
batches.append(batch)
if not self._clip_last_batch and len(batch_creator.batch) != 0:
batches.append(batch_creator.batch)
if self._shuffle_batch:
self._random.shuffle(batches)
if not self._use_token_batch:
# When producing batches by sentence count, neighboring batches that will
# be fed and run in parallel should have similar lengths (and thus similar
# computational cost) after shuffling, so they are shuffled as a whole and
# split here.
batches = [[
batch[self._batch_size * i:self._batch_size * (i + 1)]
for i in range(self._nranks)
] for batch in batches]
batches = list(itertools.chain.from_iterable(batches))
self.batch_number = (len(batches) + self._nranks - 1) // self._nranks
# for multi-device
for batch_id, batch in enumerate(batches):
if not self._distribute_mode or (
batch_id % self._nranks == self._local_rank):
batch_indices = [info.i for info in batch]
yield batch_indices
if self._distribute_mode and len(batches) % self._nranks != 0:
if self._local_rank >= len(batches) % self._nranks:
# use previous data to pad
yield batch_indices
def __len__(self):
if hasattr(self, "batch_number"): #
return self.batch_number
if not self._use_token_batch:
batch_number = (
len(self._dataset) + self._batch_size * self._nranks - 1) // (
self._batch_size * self._nranks)
else:
# for uncertain batch number, the actual value is self.batch_number
batch_number = sys.maxsize
return batch_number
......@@ -43,9 +43,7 @@ def do_train(args):
paddle.seed(random_seed)
# Define data loader
(train_loader, train_steps_fn), (eval_loader,
eval_steps_fn) = reader.create_data_loader(
args, trainer_count, rank)
(train_loader), (eval_loader) = reader.create_data_loader(args)
# Define model
transformer = TransformerModel(
......@@ -150,7 +148,6 @@ def do_train(args):
if step_idx % args.save_step == 0 and step_idx != 0:
# Validation
if args.validation_file:
transformer.eval()
total_sum_cost = 0
total_token_num = 0
......
......@@ -137,7 +137,7 @@ class SamplerHelper(object):
"""
Sort samples according to given callable cmp or key.
Args:
cmp (callable): The funcation of comparison. Default: None.
cmp (callable): The function of comparison. Default: None.
key (callable): Return element to be compared. Default: None.
reverse (bool): If True, it means in descending order, and False means in ascending order. Default: False.
buffer_size (int): Buffer size for sort. If buffer_size < 0 or buffer_size is more than the length of the data,
......
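For context, the simplified reader builds its batch sampler on this `sort` method. A minimal sketch of the call pattern as it appears in reader.py above (the pool size 200000 is the `pool_size` value from the configs and is only illustrative):

```python
# Hedged sketch of SamplerHelper.sort as used by the simplified reader.
from paddlenlp.data.sampler import SamplerHelper
from paddlenlp.datasets import WMT14ende

transform_func = WMT14ende.get_default_transform_func(root=None)
dataset = WMT14ende.get_datasets(mode="dev", transform_func=transform_func)

sampler = SamplerHelper(dataset)
# Sort by source length (+1 for the special token) within a pool of
# 200000 samples; buffer_size=-1 would sort the whole dataset, as in the
# "global" sort_type branch of create_data_loader.
sampler = sampler.sort(
    key=lambda idx, data_source: len(data_source[idx][0]) + 1,
    buffer_size=200000)
```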
......@@ -16,6 +16,7 @@ import collections
import io
import json
import os
import warnings
class Vocab(object):
......@@ -179,7 +180,12 @@ class Vocab(object):
tokens = []
for idx in indices:
if not isinstance(idx, int) or idx > max_idx:
if not isinstance(idx, int):
warnings.warn(
"The type of `to_tokens()`'s input `indices` is not `int` which will be forcibly transfered to `int`. "
)
idx = int(idx)
if idx > max_idx:
raise ValueError(
'Token index {} in the provided `indices` is invalid.'.
format(idx))
......
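The cast matters because beam-search output typically arrives as numpy integers rather than Python `int`; previously such indices raised a `ValueError`, now they are coerced with a warning. A small sketch of the behavior, assuming the WMT14ende vocabulary as used in `reader.create_infer_loader`:

```python
# Hedged sketch: numpy integer ids are now coerced to int inside to_tokens(),
# emitting a warning instead of raising a ValueError.
import numpy as np
from paddlenlp.datasets import WMT14ende

_, trg_vocab = WMT14ende.get_vocab(root=None)   # as in reader.create_infer_loader
ids = np.array([31, 7, 256])                    # e.g. post-processed beam ids
tokens = trg_vocab.to_tokens(list(ids))         # elements are np.int64, not int
print(" ".join(tokens))
```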
......@@ -13,7 +13,7 @@ from paddlenlp.data.sampler import SamplerHelper
from paddlenlp.utils.env import DATA_HOME
from paddle.dataset.common import md5file
__all__ = ['TranslationDataset', 'IWSLT15']
__all__ = ['TranslationDataset', 'IWSLT15', 'WMT14ende']
def sequential_transforms(*transforms):
......@@ -29,8 +29,8 @@ def get_default_tokenizer():
"""Only support split tokenizer
"""
def _split_tokenizer(x):
return x.split()
def _split_tokenizer(x, delimiter=None):
return x.split(delimiter)
return _split_tokenizer
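With the added `delimiter` argument, the default tokenizer can now split on an explicit separator instead of arbitrary whitespace; a brief sketch using the function defined in the hunk above:

```python
# Hedged sketch of the updated default tokenizer.
tokenizer = get_default_tokenizer()               # defined in the hunk above
print(tokenizer("a b  c"))                        # whitespace split -> ['a', 'b', 'c']
print(tokenizer("a b  c", delimiter=" "))         # explicit separator -> ['a', 'b', '', 'c']
```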
......@@ -50,9 +50,9 @@ class TranslationDataset(paddle.io.Dataset):
MD5 = None
VOCAB_INFO = None
UNK_TOKEN = None
PAD_TOKEN = None
BOS_TOKEN = None
EOS_TOKEN = None
PAD_TOKEN = None
def __init__(self, data):
self.data = data
......@@ -143,14 +143,14 @@ class TranslationDataset(paddle.io.Dataset):
tgt_file_path = os.path.join(root, tgt_vocab_filename)
src_vocab = Vocab.load_vocabulary(
src_file_path,
filepath=src_file_path,
unk_token=cls.UNK_TOKEN,
pad_token=cls.PAD_TOKEN,
bos_token=cls.BOS_TOKEN,
eos_token=cls.EOS_TOKEN)
tgt_vocab = Vocab.load_vocabulary(
tgt_file_path,
filepath=tgt_file_path,
unk_token=cls.UNK_TOKEN,
pad_token=cls.PAD_TOKEN,
bos_token=cls.BOS_TOKEN,
......@@ -273,6 +273,90 @@ class IWSLT15(TranslationDataset):
transform_func[1](data[1])) for data in self.data]
class WMT14ende(TranslationDataset):
"""
WMT14 English to German translation dataset.
Args:
mode(str, optional): It could be 'train', 'dev' or 'test'. Default: 'train'.
root(str, optional): If None, dataset will be downloaded in
`/root/.paddlenlp/datasets/machine_translation/WMT14ende/`. Default: None.
transform_func(callable, optional): If not None, it transforms raw data
to index data. Default: None.
Examples:
.. code-block:: python
from paddlenlp.datasets import WMT14ende
transform_func = WMT14ende.get_default_transform_func(root=root)
train_dataset = WMT14ende.get_datasets(mode="train", transform_func=transform_func)
"""
URL = "https://paddlenlp.bj.bcebos.com/datasets/WMT14.en-de.tar.gz"
SPLITS = {
'train': TranslationDataset.META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"train.tok.clean.bpe.33708.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"train.tok.clean.bpe.33708.de"),
"c7c0b77e672fc69f20be182ae37ff62c",
"1865ece46948fda1209d3b7794770a0a"),
'dev': TranslationDataset.META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"newstest2013.tok.bpe.33708.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"newstest2013.tok.bpe.33708.de"),
"aa4228a4bedb6c45d67525fbfbcee75e",
"9b1eeaff43a6d5e78a381a9b03170501"),
'test': TranslationDataset.META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"newstest2014.tok.bpe.33708.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"newstest2014.tok.bpe.33708.de"),
"c9403eacf623c6e2d9e5a1155bdff0b5",
"0058855b55e37c4acfcb8cffecba1050"),
'dev-eval': TranslationDataset.META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data",
"newstest2013.tok.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data",
"newstest2013.tok.de"),
"d74712eb35578aec022265c439831b0e",
"6ff76ced35b70e63a61ecec77a1c418f"),
'test-eval': TranslationDataset.META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data",
"newstest2014.tok.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data",
"newstest2014.tok.de"),
"8cce2028e4ca3d4cc039dfd33adbfb43",
"a1b1f4c47f487253e1ac88947b68b3b8")
}
VOCAB_INFO = (os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"vocab_all.bpe.33708"),
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"vocab_all.bpe.33708"),
"2fc775b7df37368e936a8e1f63846bb0",
"2fc775b7df37368e936a8e1f63846bb0")
UNK_TOKEN = "<unk>"
BOS_TOKEN = "<s>"
EOS_TOKEN = "<e>"
MD5 = "5506d213dba4124121c682368257bae4"
def __init__(self, mode="train", root=None, transform_func=None):
if mode not in ("train", "dev", "test", "dev-eval", "test-eval"):
raise TypeError(
'`train`, `dev`, `test`, `dev-eval` or `test-eval` is supported but `{}` is passed in'.
format(mode))
if transform_func is not None and len(transform_func) != 2:
if len(transform_func) != 2:
raise ValueError("`transform_func` must have length of two for"
"source and target.")
self.data = WMT14ende.get_data(mode=mode, root=root)
self.mode = mode
if transform_func is not None:
self.data = [(transform_func[0](data[0]),
transform_func[1](data[1])) for data in self.data]
super(WMT14ende, self).__init__(self.data)
# For test, not API
def prepare_train_input(insts, pad_id):
src, src_length = Pad(pad_val=pad_id, ret_length=True)(
......
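For reference, `Pad(..., ret_length=True)` as called in `prepare_train_input` above returns both the padded batch and the original sequence lengths; a minimal sketch with made-up ids:

```python
# Hedged sketch of paddlenlp.data.Pad as called in prepare_train_input above.
from paddlenlp.data import Pad

pad_id = 0  # illustrative pad id; the real one comes from the vocab
src, src_length = Pad(pad_val=pad_id, ret_length=True)(
    [[2, 5, 7], [3, 9]])          # two variable-length id sequences
# src is padded to shape (2, 3); src_length holds the original lengths [3, 2]
```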