未验证 提交 53937db0 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #1055 from guoshengCS/add-WMT-enfr

Add wmt enfr
......@@ -9,13 +9,14 @@
```text
.
├── images # README 文档中的图片
├── optim.py # learning rate scheduling 计算程序
├── config.py # 训练、预测以及模型参数配置
├── infer.py # 预测脚本
├── model.py # 模型定义
├── optim.py # learning rate scheduling 计算程序
├── reader.py # 数据读取接口
├── README.md # 文档
├── train.py # 训练脚本
└── config.py # 训练、预测以及模型参数配置
└── util.py # wordpiece 数据解码工具
```
### 简介
......@@ -58,34 +59,43 @@ Decoder 具有和 Encoder 类似的结构,只是相比于组成 Encoder 的 la
### 数据准备
我们以 [WMT'16 EN-DE 数据集](http://www.statmt.org/wmt16/translation-task.html)作为示例,同时参照论文中的设置使用 BPE(byte-pair encoding)[4]编码的数据,使用这种方式表示的数据能够更好的解决未登录词(out-of-vocabulary,OOV)的问题。用到的 BPE 数据可以参照[这里](https://github.com/google/seq2seq/blob/master/docs/data.md)进行下载,下载后解压,其中 `train.tok.clean.bpe.32000.en``train.tok.clean.bpe.32000.de` 为使用 BPE 的训练数据(平行语料,分别对应了英语和德语,经过了 tokenize 和 BPE 的处理),`newstest2013.tok.bpe.32000.en``newstest2013.tok.bpe.32000.de` 等为测试数据(`newstest2013.tok.en``newstest2013.tok.de` 等则为对应的未使用 BPE 的测试数据),`vocab.bpe.32000` 为相应的词典文件(源语言和目标语言共享该词典文件)。
WMT 数据集是机器翻译领域公认的主流数据集;WMT 英德和英法数据集也是 Transformer 论文中所用数据集,其中英德数据集使用了 BPE(byte-pair encoding)[4]编码的数据,英法数据集使用了 wordpiece [5]的数据。我们这里也将使用 WMT 英德和英法翻译数据,并和论文保持一致使用 BPE 和 wordpiece 的数据,下面给出了使用的方法。对于其他自定义数据,参照下文遵循或转换为类似的数据格式即可。
#### WMT 英德翻译数据
由于本示例中的数据读取脚本 `reader.py` 使用的样本数据的格式为 `\t` 分隔的的源语言和目标语言句子对(句子中的词之间使用空格分隔), 因此需要将源语言到目标语言的平行语料库文件合并为一个文件,可以执行以下命令进行合并:
[WMT'16 EN-DE 数据集](http://www.statmt.org/wmt16/translation-task.html)是一个中等规模的数据集。参照论文,英德数据集我们使用 BPE 编码的数据,这能够更好的解决未登录词(out-of-vocabulary,OOV)的问题[4]。用到的 BPE 数据可以参照[这里](https://github.com/google/seq2seq/blob/master/docs/data.md)进行下载(如果希望在自定义数据中使用 BPE 编码,可以参照[这里](https://github.com/rsennrich/subword-nmt)进行预处理),下载后解压,其中 `train.tok.clean.bpe.32000.en``train.tok.clean.bpe.32000.de` 为使用 BPE 的训练数据(平行语料,分别对应了英语和德语,经过了 tokenize 和 BPE 的处理),`newstest2013.tok.bpe.32000.en``newstest2013.tok.bpe.32000.de` 等为测试数据(`newstest2013.tok.en``newstest2013.tok.de` 等则为对应的未使用 BPE 的测试数据),`vocab.bpe.32000` 为相应的词典文件(源语言和目标语言共享该词典文件)。
由于本示例中的数据读取脚本 `reader.py` 默认使用的样本数据的格式为 `\t` 分隔的的源语言和目标语言句子对(默认句子中的词之间使用空格分隔),因此需要将源语言到目标语言的平行语料库文件合并为一个文件,可以执行以下命令进行合并:
```sh
paste -d '\t' train.tok.clean.bpe.32000.en train.tok.clean.bpe.32000.de > train.tok.clean.bpe.32000.en-de
```
此外,下载的词典文件 `vocab.bpe.32000` 中未包含表示序列开始、序列结束和未登录词的特殊符号,可以使用如下命令在词典中加入 `<s>``<e>``<unk>` 作为这三个特殊符号。
此外,下载的词典文件 `vocab.bpe.32000` 中未包含表示序列开始、序列结束和未登录词的特殊符号,可以使用如下命令在词典中加入 `<s>``<e>``<unk>` 作为这三个特殊符号(用 BPE 表示数据已有效避免了未登录词的问题,这里加入只是做通用处理)
```sh
sed -i '1i\<s>\n<e>\n<unk>' vocab.bpe.32000
```
对于其他自定义数据,遵循或转换为上述的数据格式即可。如果希望在自定义数据中使用 BPE 编码,可以参照[这里](https://github.com/rsennrich/subword-nmt)进行预处理。
#### WMT 英法翻译数据
[WMT'14 EN-FR 数据集](http://www.statmt.org/wmt14/translation-task.html)是一个较大规模的数据集。参照论文,英法数据我们使用 wordpiece 表示的数据,wordpiece 和 BPE 类似同为采用 sub-word units 来解决 OOV 问题的方法[5]。我们提供了已完成预处理的 wordpiece 数据的下载,可以从[这里](http://transformer-data.bj.bcebos.com/wmt14_enfr.tar)下载,其中 `train.wordpiece.en-fr` 为使用 wordpiece 的训练数据,`newstest2014.wordpiece.en-fr` 为测试数据(`newstest2014.tok.en``newstest2014.tok.fr` 为对应的未经 wordpiece 处理过的测试数据,使用[脚本](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl)进行了 tokenize 的处理),`vocab.wordpiece.en-fr` 为相应的词典文件(源语言和目标语言共享该词典文件)。
提供的英法翻译数据无需进行额外的处理,可以直接使用;需要注意的是,这些用 wordpiece 表示的数据中句子内的 token 之间使用 `\x01` 而非空格进行分隔(因部分 token 内包含空格),这需要在训练时进行指定。
### 模型训练
`train.py` 是模型训练脚本,可以执行以下命令进行模型训练:
`train.py` 是模型训练脚本。以英德翻译数据为例,可以执行以下命令进行模型训练:
```sh
python -u train.py \
--src_vocab_fpath data/vocab.bpe.32000 \
--trg_vocab_fpath data/vocab.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--train_file_pattern data/train.tok.clean.bpe.32000.en-de \
--token_delimiter ' ' \
--use_token_batch True \
--batch_size 3200 \
--sort_type pool \
--pool_size 200000 \
--pool_size 200000
```
上述命令中设置了源语言词典文件路径(`src_vocab_fpath`)、目标语言词典文件路径(`trg_vocab_fpath`)、训练数据文件(`train_file_pattern`,支持通配符)等数据相关的参数和构造 batch 方式(`use_token_batch`数据按照 token 数目或者 sequence 数目组成 batch)等 reader 相关的参数。有关这些参数更详细的信息可以通过执行以下命令查看:
上述命令中设置了源语言词典文件路径(`src_vocab_fpath`)、目标语言词典文件路径(`trg_vocab_fpath`)、训练数据文件(`train_file_pattern`,支持通配符)等数据相关的参数和构造 batch 方式(`use_token_batch`定了数据按照 token 数目或者 sequence 数目组成 batch)等 reader 相关的参数。有关这些参数更详细的信息可以通过执行以下命令查看:
```sh
python train.py --help
```
......@@ -98,19 +108,20 @@ python -u train.py \
--trg_vocab_fpath data/vocab.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--train_file_pattern data/train.tok.clean.bpe.32000.en-de \
--token_delimiter ' ' \
--use_token_batch True \
--batch_size 3200 \
--sort_type pool \
--pool_size 200000 \
n_layer 8 \
n_layer 6 \
n_head 16 \
d_model 1024 \
d_inner_hid 4096 \
dropout 0.3
```
有关这些参数更详细信息的还请参考 `config.py` 中的注释说明
有关这些参数更详细信息的请参考 `config.py` 中的注释说明。对于英法翻译数据,执行训练和英德翻译训练类似,修改命令中的词典和数据文件为英法数据相应文件的路径,另外要注意的是由于英法翻译数据 token 间不是使用空格进行分隔,需要修改 `token_delimiter` 参数的设置为 `--token_delimiter '\x01'`
训练时默认使用所有 GPU,可以通过 `CUDA_VISIBLE_DEVICES` 环境变量来设置使用的 GPU 数目。也可以只使用CPU训练(通过参数--divice CPU),训练速度相对较慢。在训练过程中,每个 epoch 结束后将保存模型到参数 `model_dir` 指定的目录,每个 iteration 将打印如下的日志到标准输出:
训练时默认使用所有 GPU,可以通过 `CUDA_VISIBLE_DEVICES` 环境变量来设置使用的 GPU 数目。也可以只使用 CPU 训练(通过参数 `--divice CPU` 设置),训练速度相对较慢。在训练过程中,每个 epoch 结束后将保存模型到参数 `model_dir` 指定的目录,每个 epoch 内也会每隔1000个 iteration 进行一次保存,每个 iteration 将打印如下的日志到标准输出:
```txt
epoch: 0, batch: 0, sum loss: 258793.343750, avg loss: 11.069005, ppl: 64151.644531
epoch: 0, batch: 1, sum loss: 256140.718750, avg loss: 11.059616, ppl: 63552.148438
......@@ -126,37 +137,45 @@ epoch: 0, batch: 9, sum loss: 245157.500000, avg loss: 10.966562, ppl: 57905.187
### 模型预测
`infer.py` 是模型预测脚本,模型训练完成后可以执行以下命令对指定文件中的文本进行翻译:
`infer.py` 是模型预测脚本。以英德翻译数据为例,模型训练完成后可以执行以下命令对指定文件中的文本进行翻译:
```sh
python -u infer.py \
--src_vocab_fpath data/vocab.bpe.32000 \
--trg_vocab_fpath data/vocab.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--test_file_pattern data/newstest2013.tok.bpe.32000.en-de \
--use_wordpiece False \
--token_delimiter ' ' \
--batch_size 4 \
model_path trained_models/pass_20.infer.model \
beam_size 5
beam_size 5 \
max_out_len 256
```
和模型训练时类似,预测时也需要设置数据和 reader 相关的参数,并可以执行 `python infer.py --help` 查看这些参数的说明(部分参数意义和训练时略有不同);同样可以在预测命令中设置模型超参数,但应与模型训练时的设置一致;此外相比于模型训练,预测时还有一些额外的参数,如需要设置 `model_path` 来给出模型所在目录,可以设置 `beam_size``max_out_len` 来指定 Beam Search 算法的搜索宽度和最大深度(翻译长度),这些参数也可以在 `config.py` 中的 `InferTaskConfig` 内查阅注释说明并进行更改设置。
执行以上预测命令会打印翻译结果到标准输出,每行输出是对应行输入的得分最高的翻译。需要注意,对于使用 BPE 的数据,预测出的翻译结果也将是 BPE 表示的数据,要恢复成原始的数据(这里指 tokenize 后的数据)才能进行正确的评估,可以使用以下命令来恢复 `predict.txt` 内的翻译结果到 `predict.tok.txt` 中。
执行以上预测命令会打印翻译结果到标准输出,每行输出是对应行输入的得分最高的翻译。对于使用 BPE 的英德数据,预测出的翻译结果也将是 BPE 表示的数据,要还原成原始的数据(这里指 tokenize 后的数据)才能进行正确的评估,可以使用以下命令来恢复 `predict.txt` 内的翻译结果到 `predict.tok.txt` 中(无需再次 tokenize 处理):
```sh
sed 's/@@ //g' predict.txt > predict.tok.txt
```
接下来就可以使用参考翻译(这里使用的是 `newstest2013.tok.de`)对翻译结果进行 BLEU 指标的评估了。计算 BLEU 值的一个较为广泛使用的脚本可以从[这里](https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl)获取,获取后执行如下命令:
对于英法翻译的 wordpiece 数据,执行预测和英德翻译预测类似,修改命令中的词典和数据文件为英法数据相应文件的路径,另外需要注意修改 `token_delimiter` 参数的设置为 `--token_delimiter '\x01'`;同时要修改 `use_wordpiece` 参数的设置为 `--use_wordpiece True`,这会在预测时将翻译得到的 wordpiece 数据还原为原始数据输出。为了使用 tokenize 的数据进行评估,还需要对翻译结果进行 tokenize 的处理,[Moses](https://github.com/moses-smt/mosesdecoder) 提供了一系列机器翻译相关的脚本。执行 `git clone https://github.com/moses-smt/mosesdecoder.git` 克隆 mosesdecoder 仓库后,可以使用其中的 `tokenizer.perl` 脚本对 `predict.txt` 内的翻译结果进行 tokenize 处理并输出到 `predict.tok.txt` 中,如下:
```sh
perl mosesdecoder/scripts/tokenizer/tokenizer.perl -l fr < predict.txt > predict.tok.txt
```
接下来就可以使用参考翻译对翻译结果进行 BLEU 指标的评估了。计算 BLEU 值的脚本也在 Moses 中包含,以英德翻译 `newstest2013.tok.de` 数据为例,执行如下命令:
```sh
perl multi-bleu.perl data/newstest2013.tok.de < predict.tok.txt
perl mosesdecoder/scripts/generic/multi-bleu.perl data/newstest2013.tok.de < predict.tok.txt
```
可以看到类似如下的结果。
```
BLEU = 25.08, 58.3/31.5/19.6/12.6 (BP=0.966, ratio=0.967, hyp_len=61321, ref_len=63412)
```
目前在未使用 model average 的情况下,使用默认配置单机八卡(同论文中 base model 的配置)进行训练,英德翻译在 `newstest2013` 上测试 BLEU 值为25.,在 `newstest2014` 上测试 BLEU 值为26.;英法翻译在 `newstest2014` 上测试 BLEU 值为36.。
### 分布式训练
transformer 模型支持同步或者异步的分布式训练。分布式的配置主要两个方面:
Transformer 模型支持同步或者异步的分布式训练。分布式的配置主要两个方面:
1 命令行配置
......@@ -234,3 +253,4 @@ export PADDLE_PORT=6177
2. He K, Zhang X, Ren S, et al. [Deep residual learning for image recognition](http://openaccess.thecvf.com/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf)[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778.
3. Ba J L, Kiros J R, Hinton G E. [Layer normalization](https://arxiv.org/pdf/1607.06450.pdf)[J]. arXiv preprint arXiv:1607.06450, 2016.
4. Sennrich R, Haddow B, Birch A. [Neural machine translation of rare words with subword units](https://arxiv.org/pdf/1508.07909)[J]. arXiv preprint arXiv:1508.07909, 2015.
5. Wu Y, Schuster M, Chen Z, et al. [Google's neural machine translation system: Bridging the gap between human and machine translation](https://arxiv.org/pdf/1609.08144.pdf)[J]. arXiv preprint arXiv:1609.08144, 2016.
import argparse
import ast
import numpy as np
from functools import partial
import paddle
import paddle.fluid as fluid
......@@ -11,6 +13,7 @@ from model import fast_decode as fast_decoder
from config import *
from train import pad_batch_data
import reader
import util
def parse_args():
......@@ -46,6 +49,22 @@ def parse_args():
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--use_wordpiece",
type=ast.literal_eval,
default=False,
help="The flag indicating if the data in wordpiece. The EN-FR data "
"we provided is wordpiece data. For wordpiece data, converting ids to "
"original words is a little different and some special codes are "
"provided in util.py to do this.")
parser.add_argument(
"--token_delimiter",
type=partial(
str.decode, encoding="string-escape"),
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter.; "
"For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")
parser.add_argument(
'opts',
help='See config.py for all options',
......@@ -320,7 +339,7 @@ def post_process_seq(seq,
seq)
def py_infer(test_data, trg_idx2word):
def py_infer(test_data, trg_idx2word, use_wordpiece):
"""
Inference by beam search implented by python, while the calculations from
symbols to probilities execute by Fluid operators.
......@@ -399,7 +418,10 @@ def py_infer(test_data, trg_idx2word):
seqs = map(post_process_seq, batch_seqs[i])
scores = batch_scores[i]
for seq in seqs:
print(" ".join([trg_idx2word[idx] for idx in seq]))
if use_wordpiece:
print(util.subword_ids_to_str(seq, trg_idx2word))
else:
print(" ".join([trg_idx2word[idx] for idx in seq]))
def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
......@@ -465,7 +487,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
return input_dict
def fast_infer(test_data, trg_idx2word):
def fast_infer(test_data, trg_idx2word, use_wordpiece):
"""
Inference by beam search decoder based solely on Fluid operators.
"""
......@@ -520,7 +542,9 @@ def fast_infer(test_data, trg_idx2word):
trg_idx2word[idx]
for idx in post_process_seq(
np.array(seq_ids)[sub_start:sub_end])
]))
]) if not use_wordpiece else util.subtoken_ids_to_str(
post_process_seq(np.array(seq_ids)[sub_start:sub_end]),
trg_idx2word))
scores[i].append(np.array(seq_scores)[sub_end - 1])
print hyps[i][-1]
if len(hyps[i]) >= InferTaskConfig.n_best:
......@@ -534,8 +558,9 @@ def infer(args, inferencer=fast_infer):
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.test_file_pattern,
batch_size=args.batch_size,
token_delimiter=args.token_delimiter,
use_token_batch=False,
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
......@@ -548,7 +573,7 @@ def infer(args, inferencer=fast_infer):
clip_last_batch=False)
trg_idx2word = test_data.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
inferencer(test_data, trg_idx2word)
inferencer(test_data, trg_idx2word, args.use_wordpiece)
if __name__ == "__main__":
......
......@@ -116,9 +116,12 @@ class DataReader(object):
:param use_token_batch: Whether to produce batch data according to
token number.
:type use_token_batch: bool
:param delimiter: The delimiter used to split source and target in each
line of data file.
:type delimiter: basestring
:param field_delimiter: The delimiter used to split source and target in
each line of data file.
:type field_delimiter: basestring
:param token_delimiter: The delimiter used to split tokens in source or
target sentences.
:type token_delimiter: basestring
:param start_mark: The token representing for the beginning of
sentences in dictionary.
:type start_mark: basestring
......@@ -145,7 +148,8 @@ class DataReader(object):
shuffle=True,
shuffle_batch=False,
use_token_batch=False,
delimiter="\t",
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
......@@ -164,7 +168,8 @@ class DataReader(object):
self._shuffle_batch = shuffle_batch
self._min_length = min_length
self._max_length = max_length
self._delimiter = delimiter
self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter
self._epoch_batches = []
src_seq_words, trg_seq_words = self._load_data(fpattern, tar_fname)
......@@ -196,7 +201,7 @@ class DataReader(object):
trg_seq_words = []
for line in f_obj:
fields = line.strip().split(self._delimiter)
fields = line.strip().split(self._field_delimiter)
if (not self._only_src and len(fields) != 2) or (self._only_src and
len(fields) != 1):
......@@ -207,7 +212,7 @@ class DataReader(object):
max_len = -1
for i, seq in enumerate(fields):
seq_words = seq.split()
seq_words = seq.split(self._token_delimiter)
max_len = max(max_len, len(seq_words))
if len(seq_words) == 0 or \
len(seq_words) < self._min_length or \
......@@ -258,9 +263,9 @@ class DataReader(object):
with open(dict_path, "r") as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip()
word_dict[idx] = line.strip('\n')
else:
word_dict[line.strip()] = idx
word_dict[line.strip('\n')] = idx
return word_dict
def _sample_generator(self):
......
......@@ -4,6 +4,7 @@ import argparse
import ast
import numpy as np
import multiprocessing
from functools import partial
import paddle
import paddle.fluid as fluid
......@@ -76,6 +77,14 @@ def parse_args():
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--token_delimiter",
type=partial(
str.decode, encoding="string-escape"),
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. "
"For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")
parser.add_argument(
'opts',
help='See config.py for all options',
......@@ -273,6 +282,7 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.val_file_pattern,
token_delimiter=args.token_delimiter,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
pool_size=args.pool_size,
......@@ -335,6 +345,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.train_file_pattern,
token_delimiter=args.token_delimiter,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
pool_size=args.pool_size,
......@@ -413,6 +424,10 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" %
(pass_id, batch_id, total_sum_cost, total_avg_cost,
np.exp([min(total_avg_cost, 100)])))
if batch_id > 0 and batch_id % 1000 == 0:
fluid.io.save_persistables(
exe,
os.path.join(TrainTaskConfig.ckpt_dir, "latest.checkpoint"))
init = True
# Validate and save the model for inference.
print("epoch: %d, " % pass_id +
......
import sys
import re
import six
import unicodedata
# Regular expression for unescaping token strings.
# '\u' is converted to '_'
# '\\' is converted to '\'
# '\213;' is converted to unichr(213)
# Inverse of escaping.
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
# This set contains all letter and number characters.
_ALPHANUMERIC_CHAR_SET = set(
six.unichr(i) for i in range(sys.maxunicode)
if (unicodedata.category(six.unichr(i)).startswith("L") or
unicodedata.category(six.unichr(i)).startswith("N")))
def unescape_token(escaped_token):
"""
Inverse of encoding escaping.
"""
def match(m):
if m.group(1) is None:
return u"_" if m.group(0) == u"\\u" else u"\\"
try:
return six.unichr(int(m.group(1)))
except (ValueError, OverflowError) as _:
return u"\u3013" # Unicode for undefined character.
trimmed = escaped_token[:-1] if escaped_token.endswith(
"_") else escaped_token
return _UNESCAPE_REGEX.sub(match, trimmed)
def subtoken_ids_to_str(subtoken_ids, vocabs):
"""
Convert a list of subtoken(word piece) ids to a native string.
Refer to SubwordTextEncoder in Tensor2Tensor.
"""
subtokens = [vocabs.get(subtoken_id, u"") for subtoken_id in subtoken_ids]
# Convert a list of subtokens to a list of tokens.
concatenated = "".join([
t if isinstance(t, unicode) else t.decode("utf-8") for t in subtokens
])
split = concatenated.split("_")
tokens = []
for t in split:
if t:
unescaped = unescape_token(t + "_")
if unescaped:
tokens.append(unescaped)
# Convert a list of tokens to a unicode string (by inserting spaces bewteen
# word tokens).
token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens]
ret = []
for i, token in enumerate(tokens):
if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
ret.append(u" ")
ret.append(token)
seq = "".join(ret)
return seq.encode("utf-8")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册