```shell
$ hub install transformer_en-de==1.0.0
```
## Overview
In 2017, the Google machine translation team proposed the Transformer in their paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762): a novel network architecture for machine translation and other sequence-to-sequence (Seq2Seq) learning tasks. The Transformer models sequence-to-sequence relationships entirely with attention mechanisms and achieves strong results.
transformer_en-de is a 6-layer Transformer with 8 attention heads, a hidden size of 512, and 64M parameters. It was pretrained on the [WMT'14 EN-DE dataset](http://www.statmt.org/wmt14/translation-task.html) and can be used for prediction directly after loading, translating English text into German.
For details on how the Transformer translation model is trained, see [Machine Translation using Transformer](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/machine_translation/transformer).
## API
```python
def __init__(max_length: int = 256,
max_out_len: int = 256,
beam_size: int = 5):
```
Initializes the module; the maximum lengths of the input and output text and the beam width used in beam search decoding are configurable.
**Parameters**
- `max_length`(int): Maximum length of the input text. Defaults to 256.
- `max_out_len`(int): Maximum decoding length of the output text. Defaults to 256.
- `beam_size`(int): Beam width for beam search decoding. Defaults to 5.
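For instance, the decoding budget can be adjusted when the module is loaded; keyword arguments passed to `hub.Module` are forwarded to `__init__` (a minimal sketch with illustrative values):
```python
import paddlehub as hub

# Illustrative settings: a shorter input/output budget with the default beam width.
model = hub.Module(name='transformer_en-de', max_length=128, max_out_len=128, beam_size=5)
```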
```python
def predict(data: List[str],
batch_size: int = 1,
n_best: int = 1,
use_gpu: bool = False):
```
Prediction API. Takes text in the source language and returns candidate translations in the target language after decoding.
**Parameters**
- `data`(List[str]): List of texts in the source language.
- `batch_size`(int): Batch size used for prediction. Defaults to 1.
- `n_best`(int): Number of top-scoring candidate translations returned for each input text after decoding; must not exceed `beam_size`. Defaults to 1.
- `use_gpu`(bool): Whether to run prediction on GPU. Defaults to False.
**Returns**
* `results`(List[str]): Candidate translations in the target language; the list has length `len(data)*n_best`.
**Code example**
```python
import paddlehub as hub
model = hub.Module(name='transformer_en-de', beam_size=5)
src_texts = [
'What are you doing now?',
'The change was for the better; I eat well, I exercise, I take my drugs.',
'Such experiments are not conducted for ethical reasons.',
]
n_best = 3  # Number of candidate translations returned per input sample
trg_texts = model.predict(src_texts, n_best=n_best)
for idx, st in enumerate(src_texts):
print('-'*30)
print(f'src: {st}')
for i in range(n_best):
print(f'trg[{i+1}]: {trg_texts[idx*n_best+i]}')
```
## Service Deployment
With PaddleHub Serving, the model can be deployed as an online translation service.
### Step 1: Start PaddleHub Serving
Run the start command:
```shell
$ hub serving start -m transformer_en-de
```
This command deploys an English-to-German machine translation API; the default port is 8866.
**NOTE:** To run prediction on GPU, set the CUDA_VISIBLE_DEVICES environment variable before starting the service; otherwise it does not need to be set.
### Step 2: Send a prediction request
With the server up, the few lines of code below send a prediction request and retrieve the result:
```python
import requests
import json
texts = [
'What are you doing now?',
'The change was for the better; I eat well, I exercise, I take my drugs.',
'Such experiments are not conducted for ethical reasons.',
]
data = {"data": texts}
# Send a POST request with a JSON content type; change the IP address in the url to that of the serving machine.
url = "http://127.0.0.1:8866/predict/transformer_en-de"
# Set the POST request headers to application/json.
headers = {"Content-Type": "application/json"}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json())
```
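The response is JSON; assuming the standard PaddleHub Serving wrapper, the candidate translations returned by `predict` sit under the `results` key (a minimal sketch):
```python
# Assumption: PaddleHub Serving wraps the module output as
# {"msg": ..., "results": ..., "status": ...}.
translations = r.json()["results"]
for t in translations:
    print(t)
```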
## View Code
https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/machine_translation/transformer
## Dependencies
paddlepaddle >= 2.0.0
paddlehub >= 2.1.0
## Release History
* 1.0.0
  Initial release
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import os
import paddle
import paddle.nn as nn
from paddlehub.env import MODULE_HOME
from paddlehub.module.module import moduleinfo, serving
from paddlenlp.data import Pad, Vocab
from paddlenlp.transformers import InferTransformerModel, position_encoding_init
from transformer_en_de.utils import MTTokenizer, post_process_seq
@moduleinfo(
name="transformer_en-de",
version="1.0.0",
summary="",
author="PaddlePaddle",
author_email="",
type="nlp/machine_translation",
)
class MTTransformer(nn.Layer):
"""
Transformer model for machine translation.
"""
# Language config
lang_config = {'source': 'en', 'target': 'de'}
# Model config
model_config = {
# Number of sub-layers to be stacked in the encoder and decoder.
"n_layer": 6,
# Number of head used in multi-head attention.
"n_head": 8,
# The dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
"d_model": 512,
# Size of the hidden layer in position-wise feed-forward networks.
"d_inner_hid": 2048,
# The flag indicating whether to share embedding and softmax weights.
# Vocabularies in source and target should be same for weight sharing.
"weight_sharing": True,
# Dropout rate
'dropout': 0
}
# Vocab config
vocab_config = {
# Used to pad vocab size to be multiple of pad_factor.
"pad_factor": 8,
# Index for <bos> token
"bos_id": 0,
"bos_token": "<s>",
# Index for <eos> token
"eos_id": 1,
"eos_token": "<e>",
# Index for <unk> token
"unk_id": 2,
"unk_token": "<unk>",
}
def __init__(self, max_length: int = 256, max_out_len: int = 256, beam_size: int = 5):
super(MTTransformer, self).__init__()
bpe_codes_file = os.path.join(MODULE_HOME, 'transformer_en_de', 'assets', 'bpe.33708')
vocab_file = os.path.join(MODULE_HOME, 'transformer_en_de', 'assets', 'vocab_all.bpe.33708')
checkpoint = os.path.join(MODULE_HOME, 'transformer_en_de', 'assets', 'transformer.pdparams')
self.max_length = max_length
self.beam_size = beam_size
self.tokenizer = MTTokenizer(bpe_codes_file=bpe_codes_file,
lang_src=self.lang_config['source'],
lang_trg=self.lang_config['target'])
self.vocab = Vocab.load_vocabulary(filepath=vocab_file,
unk_token=self.vocab_config['unk_token'],
bos_token=self.vocab_config['bos_token'],
eos_token=self.vocab_config['eos_token'])
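        # Round the vocabulary size up to the nearest multiple of pad_factor.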
self.vocab_size = (len(self.vocab) + self.vocab_config['pad_factor'] - 1) \
// self.vocab_config['pad_factor'] * self.vocab_config['pad_factor']
self.transformer = InferTransformerModel(src_vocab_size=self.vocab_size,
trg_vocab_size=self.vocab_size,
bos_id=self.vocab_config['bos_id'],
eos_id=self.vocab_config['eos_id'],
max_length=self.max_length + 1,
max_out_len=max_out_len,
beam_size=self.beam_size,
**self.model_config)
state_dict = paddle.load(checkpoint)
# To avoid a longer length than training, reset the size of position
# encoding to max_length
state_dict["encoder.pos_encoder.weight"] = position_encoding_init(self.max_length + 1,
self.model_config['d_model'])
state_dict["decoder.pos_encoder.weight"] = position_encoding_init(self.max_length + 1,
self.model_config['d_model'])
self.transformer.set_state_dict(state_dict)
def forward(self, src_words: paddle.Tensor):
return self.transformer(src_words)
def _convert_text_to_input(self, text: str):
"""
Convert input string to ids.
"""
bpe_tokens = self.tokenizer.tokenize(text)
if len(bpe_tokens) > self.max_length:
bpe_tokens = bpe_tokens[:self.max_length]
return self.vocab.to_indices(bpe_tokens)
def _batchify(self, data: List[str], batch_size: int):
"""
Generate input batches.
"""
pad_func = Pad(self.vocab_config['eos_id'])
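        # Append <eos> to every sample, then pad the batch to equal length with eos_id.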
def _parse_batch(batch_ids):
return pad_func([ids + [self.vocab_config['eos_id']] for ids in batch_ids])
examples = []
for text in data:
examples.append(self._convert_text_to_input(text))
        # Separates data into batches.
one_batch = []
for example in examples:
one_batch.append(example)
if len(one_batch) == batch_size:
yield _parse_batch(one_batch)
one_batch = []
if one_batch:
yield _parse_batch(one_batch)
@serving
def predict(self, data: List[str], batch_size: int = 1, n_best: int = 1, use_gpu: bool = False):
if n_best > self.beam_size:
            raise ValueError(f'Predict arg "n_best" must be smaller than or equal to self.beam_size, '
                             f'but got {n_best} > {self.beam_size}')
paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
batches = self._batchify(data, batch_size)
results = []
self.eval()
for batch in batches:
src_batch_ids = paddle.to_tensor(batch)
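            # The decoder output has shape [batch_size, length, beam_size];
            # transpose to [batch, beam, length] to iterate candidates per sample.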
trg_batch_beams = self(src_batch_ids).numpy().transpose([0, 2, 1])
for trg_sample_beams in trg_batch_beams:
for beam_idx, beam in enumerate(trg_sample_beams):
if beam_idx >= n_best:
break
trg_sample_ids = post_process_seq(beam, self.vocab_config['bos_id'], self.vocab_config['eos_id'])
trg_sample_words = self.vocab.to_tokens(trg_sample_ids)
trg_sample_text = self.tokenizer.detokenize(trg_sample_words)
results.append(trg_sample_text)
return results
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import List
import codecs
try:
import nltk
nltk.data.find('misc/perluniprops')
nltk.data.find('corpora/nonbreaking_prefixes')
except LookupError:
nltk.download('perluniprops')
nltk.download('nonbreaking_prefixes')
from nltk.tokenize.moses import MosesTokenizer, MosesDetokenizer
from subword_nmt.apply_bpe import BPE
class MTTokenizer(object):
def __init__(self, bpe_codes_file: str, lang_src: str = 'en', lang_trg: str = 'de', separator='@@'):
self.moses_tokenizer = MosesTokenizer(lang=lang_src)
self.moses_detokenizer = MosesDetokenizer(lang=lang_trg)
self.bpe_tokenizer = BPE(codes=codecs.open(bpe_codes_file, encoding='utf-8'),
merges=-1,
separator=separator,
vocab=None,
glossaries=None)
def tokenize(self, text: str):
"""
Convert source string into bpe tokens.
"""
moses_tokens = self.moses_tokenizer.tokenize(text)
tokenized_text = ' '.join(moses_tokens)
tokenized_bpe_text = self.bpe_tokenizer.process_line(tokenized_text) # Apply bpe to text
bpe_tokens = tokenized_bpe_text.split(' ')
return bpe_tokens
def detokenize(self, tokens: List[str]):
"""
Convert target bpe tokens into string.
"""
separator = self.bpe_tokenizer.separator
text_with_separators = ' '.join(tokens)
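        # Strip the BPE continuation markers ("@@ ") so subword pieces merge back into words.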
clean_text = re.sub(f'({separator} )|({separator} ?$)', '', text_with_separators)
clean_tokens = clean_text.split(' ')
detokenized_text = self.moses_detokenizer.tokenize(clean_tokens, return_str=True)
return detokenized_text
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
"""
Post-process the decoded sequence.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [int(idx) for idx in seq[:eos_pos + 1] if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)]
return seq
```shell
$ hub install transformer_zh-en==1.0.0
```
## Overview
In 2017, the Google machine translation team proposed the Transformer in their paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762): a novel network architecture for machine translation and other sequence-to-sequence (Seq2Seq) learning tasks. The Transformer models sequence-to-sequence relationships entirely with attention mechanisms and achieves strong results.
transformer_zh-en is a 6-layer Transformer with 8 attention heads, a hidden size of 512, and 64M parameters. It was pretrained on the [CWMT2021 dataset](http://nlp.nju.edu.cn/cwmt-wmt) and can be used for prediction directly after loading, translating Chinese text into English.
For details on how the Transformer translation model is trained, see [Machine Translation using Transformer](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/machine_translation/transformer).
## API
```python
def __init__(max_length: int = 256,
max_out_len: int = 256,
beam_size: int = 5):
```
Initializes the module; the maximum lengths of the input and output text and the beam width used in beam search decoding are configurable.
**Parameters**
- `max_length`(int): Maximum length of the input text. Defaults to 256.
- `max_out_len`(int): Maximum decoding length of the output text. Defaults to 256.
- `beam_size`(int): Beam width for beam search decoding. Defaults to 5.
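For instance, the decoding budget can be adjusted when the module is loaded; keyword arguments passed to `hub.Module` are forwarded to `__init__` (a minimal sketch with illustrative values):
```python
import paddlehub as hub

# Illustrative settings: a shorter input/output budget with the default beam width.
model = hub.Module(name='transformer_zh-en', max_length=128, max_out_len=128, beam_size=5)
```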
```python
def predict(data: List[str],
batch_size: int = 1,
n_best: int = 1,
use_gpu: bool = False):
```
Prediction API. Takes text in the source language and returns candidate translations in the target language after decoding.
**Parameters**
- `data`(List[str]): List of texts in the source language.
- `batch_size`(int): Batch size used for prediction. Defaults to 1.
- `n_best`(int): Number of top-scoring candidate translations returned for each input text after decoding; must not exceed `beam_size`. Defaults to 1.
- `use_gpu`(bool): Whether to run prediction on GPU. Defaults to False.
**Returns**
* `results`(List[str]): Candidate translations in the target language; the list has length `len(data)*n_best`.
**Code example**
```python
import paddlehub as hub
model = hub.Module(name='transformer_zh-en', beam_size=5)
src_texts = [
'今天天气怎么样?',
'我们一起去吃饭吧。',
]
n_best = 3  # Number of candidate translations returned per input sample
trg_texts = model.predict(src_texts, n_best=n_best)
for idx, st in enumerate(src_texts):
print('-'*30)
print(f'src: {st}')
for i in range(n_best):
print(f'trg[{i+1}]: {trg_texts[idx*n_best+i]}')
```
## Service Deployment
With PaddleHub Serving, the model can be deployed as an online translation service.
### Step 1: Start PaddleHub Serving
Run the start command:
```shell
$ hub serving start -m transformer_zh-en
```
This command deploys a Chinese-to-English machine translation API; the default port is 8866.
**NOTE:** To run prediction on GPU, set the CUDA_VISIBLE_DEVICES environment variable before starting the service; otherwise it does not need to be set.
### Step 2: Send a prediction request
With the server up, the few lines of code below send a prediction request and retrieve the result:
```python
import requests
import json
texts = [
'今天天气怎么样啊?',
'我们一起去吃饭吧。',
]
data = {"data": texts}
# Send a POST request with a JSON content type; change the IP address in the url to that of the serving machine.
url = "http://127.0.0.1:8866/predict/transformer_zh-en"
# Set the POST request headers to application/json.
headers = {"Content-Type": "application/json"}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json())
```
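The response is JSON; assuming the standard PaddleHub Serving wrapper, the candidate translations returned by `predict` sit under the `results` key (a minimal sketch):
```python
# Assumption: PaddleHub Serving wraps the module output as
# {"msg": ..., "results": ..., "status": ...}.
translations = r.json()["results"]
for t in translations:
    print(t)
```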
## View Code
https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/machine_translation/transformer
## Dependencies
paddlepaddle >= 2.0.0
paddlehub >= 2.1.0
## Release History
* 1.0.0
  Initial release
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import os
import paddle
import paddle.nn as nn
from paddlehub.env import MODULE_HOME
from paddlehub.module.module import moduleinfo, serving
from paddlenlp.data import Pad, Vocab
from paddlenlp.transformers import InferTransformerModel, position_encoding_init
from transformer_zh_en.utils import MTTokenizer, post_process_seq
@moduleinfo(
name="transformer_zh-en",
version="1.0.0",
summary="",
author="PaddlePaddle",
author_email="",
type="nlp/machine_translation",
)
class MTTransformer(nn.Layer):
"""
Transformer model for machine translation.
"""
# Language config
lang_config = {'source': 'zh', 'target': 'en'}
# Model config
model_config = {
# Number of sub-layers to be stacked in the encoder and decoder.
"n_layer": 6,
# Number of head used in multi-head attention.
"n_head": 8,
# The dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
"d_model": 512,
# Size of the hidden layer in position-wise feed-forward networks.
"d_inner_hid": 2048,
# The flag indicating whether to share embedding and softmax weights.
# Vocabularies in source and target should be same for weight sharing.
"weight_sharing": False,
# Dropout rate
'dropout': 0
}
# Vocab config
vocab_config = {
# Used to pad vocab size to be multiple of pad_factor.
"pad_factor": 8,
# Index for <bos> token
"bos_id": 0,
"bos_token": "<s>",
# Index for <eos> token
"eos_id": 1,
"eos_token": "<e>",
# Index for <unk> token
"unk_id": 2,
"unk_token": "<unk>",
}
def __init__(self, max_length: int = 256, max_out_len: int = 256, beam_size: int = 5):
super(MTTransformer, self).__init__()
bpe_codes_file = os.path.join(MODULE_HOME, 'transformer_zh_en', 'assets', '2M.zh2en.dict4bpe.zh')
src_vocab_file = os.path.join(MODULE_HOME, 'transformer_zh_en', 'assets', 'vocab.zh')
trg_vocab_file = os.path.join(MODULE_HOME, 'transformer_zh_en', 'assets', 'vocab.en')
checkpoint = os.path.join(MODULE_HOME, 'transformer_zh_en', 'assets', 'transformer.pdparams')
self.max_length = max_length
self.beam_size = beam_size
self.tokenizer = MTTokenizer(bpe_codes_file=bpe_codes_file,
lang_src=self.lang_config['source'],
lang_trg=self.lang_config['target'])
self.src_vocab = Vocab.load_vocabulary(filepath=src_vocab_file,
unk_token=self.vocab_config['unk_token'],
bos_token=self.vocab_config['bos_token'],
eos_token=self.vocab_config['eos_token'])
self.trg_vocab = Vocab.load_vocabulary(filepath=trg_vocab_file,
unk_token=self.vocab_config['unk_token'],
bos_token=self.vocab_config['bos_token'],
eos_token=self.vocab_config['eos_token'])
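        # Round each vocabulary size up to the nearest multiple of pad_factor.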
self.src_vocab_size = (len(self.src_vocab) + self.vocab_config['pad_factor'] - 1) \
// self.vocab_config['pad_factor'] * self.vocab_config['pad_factor']
self.trg_vocab_size = (len(self.trg_vocab) + self.vocab_config['pad_factor'] - 1) \
// self.vocab_config['pad_factor'] * self.vocab_config['pad_factor']
self.transformer = InferTransformerModel(src_vocab_size=self.src_vocab_size,
trg_vocab_size=self.trg_vocab_size,
bos_id=self.vocab_config['bos_id'],
eos_id=self.vocab_config['eos_id'],
max_length=self.max_length + 1,
max_out_len=max_out_len,
beam_size=self.beam_size,
**self.model_config)
state_dict = paddle.load(checkpoint)
# To avoid a longer length than training, reset the size of position
# encoding to max_length
state_dict["encoder.pos_encoder.weight"] = position_encoding_init(self.max_length + 1,
self.model_config['d_model'])
state_dict["decoder.pos_encoder.weight"] = position_encoding_init(self.max_length + 1,
self.model_config['d_model'])
self.transformer.set_state_dict(state_dict)
def forward(self, src_words: paddle.Tensor):
return self.transformer(src_words)
def _convert_text_to_input(self, text: str):
"""
Convert input string to ids.
"""
bpe_tokens = self.tokenizer.tokenize(text)
if len(bpe_tokens) > self.max_length:
bpe_tokens = bpe_tokens[:self.max_length]
return self.src_vocab.to_indices(bpe_tokens)
def _batchify(self, data: List[str], batch_size: int):
"""
Generate input batches.
"""
pad_func = Pad(self.vocab_config['eos_id'])
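        # Append <eos> to every sample, then pad the batch to equal length with eos_id.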
def _parse_batch(batch_ids):
return pad_func([ids + [self.vocab_config['eos_id']] for ids in batch_ids])
examples = []
for text in data:
examples.append(self._convert_text_to_input(text))
        # Separates data into batches.
one_batch = []
for example in examples:
one_batch.append(example)
if len(one_batch) == batch_size:
yield _parse_batch(one_batch)
one_batch = []
if one_batch:
yield _parse_batch(one_batch)
@serving
def predict(self, data: List[str], batch_size: int = 1, n_best: int = 1, use_gpu: bool = False):
if n_best > self.beam_size:
            raise ValueError(f'Predict arg "n_best" must be smaller than or equal to self.beam_size, '
                             f'but got {n_best} > {self.beam_size}')
paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
batches = self._batchify(data, batch_size)
results = []
self.eval()
for batch in batches:
src_batch_ids = paddle.to_tensor(batch)
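            # The decoder output has shape [batch_size, length, beam_size];
            # transpose to [batch, beam, length] to iterate candidates per sample.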
trg_batch_beams = self(src_batch_ids).numpy().transpose([0, 2, 1])
for trg_sample_beams in trg_batch_beams:
for beam_idx, beam in enumerate(trg_sample_beams):
if beam_idx >= n_best:
break
trg_sample_ids = post_process_seq(beam, self.vocab_config['bos_id'], self.vocab_config['eos_id'])
trg_sample_words = self.trg_vocab.to_tokens(trg_sample_ids)
trg_sample_text = self.tokenizer.detokenize(trg_sample_words)
results.append(trg_sample_text)
return results
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import List
import codecs
import jieba
jieba.setLogLevel(logging.INFO)
try:
import nltk
nltk.data.find('misc/perluniprops')
nltk.data.find('corpora/nonbreaking_prefixes')
except LookupError:
nltk.download('perluniprops')
nltk.download('nonbreaking_prefixes')
from nltk.tokenize.moses import MosesDetokenizer
from subword_nmt.apply_bpe import BPE
class MTTokenizer(object):
def __init__(self, bpe_codes_file: str, lang_src: str = 'zh', lang_trg: str = 'en', separator='@@'):
self.moses_detokenizer = MosesDetokenizer(lang=lang_trg)
self.bpe_tokenizer = BPE(codes=codecs.open(bpe_codes_file, encoding='utf-8'),
merges=-1,
separator=separator,
vocab=None,
glossaries=None)
def tokenize(self, text: str):
"""
Convert source string into bpe tokens.
"""
text = text.replace(' ', '') # Remove blanks in Chinese text.
jieba_tokens = list(jieba.cut(text))
tokenized_text = ' '.join(jieba_tokens)
tokenized_bpe_text = self.bpe_tokenizer.process_line(tokenized_text) # Apply bpe to text
bpe_tokens = tokenized_bpe_text.split(' ')
return bpe_tokens
def detokenize(self, tokens: List[str]):
"""
Convert target bpe tokens into string.
"""
separator = self.bpe_tokenizer.separator
text_with_separators = ' '.join(tokens)
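        # Strip the BPE continuation markers ("@@ ") so subword pieces merge back into words.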
clean_text = re.sub(f'({separator} )|({separator} ?$)', '', text_with_separators)
clean_tokens = clean_text.split(' ')
detokenized_text = self.moses_detokenizer.tokenize(clean_tokens, return_str=True)
return detokenized_text
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
"""
Post-process the decoded sequence.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [int(idx) for idx in seq[:eos_pos + 1] if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)]
return seq