Unverified commit f80c221c, authored by KP, committed by GitHub

Solve code formatting problem for ci

Parent fda276b1
@@ -82,23 +82,24 @@ class MTTransformer(nn.Layer):
         self.max_length = max_length
         self.beam_size = beam_size
-        self.tokenizer = MTTokenizer(bpe_codes_file=bpe_codes_file,
-                                     lang_src=self.lang_config['source'],
-                                     lang_trg=self.lang_config['target'])
-        self.vocab = Vocab.load_vocabulary(filepath=vocab_file,
-                                           unk_token=self.vocab_config['unk_token'],
-                                           bos_token=self.vocab_config['bos_token'],
-                                           eos_token=self.vocab_config['eos_token'])
+        self.tokenizer = MTTokenizer(
+            bpe_codes_file=bpe_codes_file, lang_src=self.lang_config['source'], lang_trg=self.lang_config['target'])
+        self.vocab = Vocab.load_vocabulary(
+            filepath=vocab_file,
+            unk_token=self.vocab_config['unk_token'],
+            bos_token=self.vocab_config['bos_token'],
+            eos_token=self.vocab_config['eos_token'])
         self.vocab_size = (len(self.vocab) + self.vocab_config['pad_factor'] - 1) \
             // self.vocab_config['pad_factor'] * self.vocab_config['pad_factor']
-        self.transformer = InferTransformerModel(src_vocab_size=self.vocab_size,
-                                                 trg_vocab_size=self.vocab_size,
-                                                 bos_id=self.vocab_config['bos_id'],
-                                                 eos_id=self.vocab_config['eos_id'],
-                                                 max_length=self.max_length + 1,
-                                                 max_out_len=max_out_len,
-                                                 beam_size=self.beam_size,
-                                                 **self.model_config)
+        self.transformer = InferTransformerModel(
+            src_vocab_size=self.vocab_size,
+            trg_vocab_size=self.vocab_size,
+            bos_id=self.vocab_config['bos_id'],
+            eos_id=self.vocab_config['eos_id'],
+            max_length=self.max_length + 1,
+            max_out_len=max_out_len,
+            beam_size=self.beam_size,
+            **self.model_config)
         state_dict = paddle.load(checkpoint)
...
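Side note (not part of the commit): the vocab_size expression above rounds the vocabulary size up to the nearest multiple of pad_factor, presumably so the embedding and output-projection shapes land on a padding-friendly multiple. A minimal sketch of that arithmetic; the concrete numbers are illustrative assumptions, not values from this module:

# Sketch only: the rounding used for self.vocab_size above.
def pad_vocab_size(vocab_len: int, pad_factor: int) -> int:
    # Ceiling-divide by pad_factor, then scale back up to the padded size.
    return (vocab_len + pad_factor - 1) // pad_factor * pad_factor

assert pad_vocab_size(36549, 8) == 36552  # rounded up to a multiple of 8 (assumed pad_factor)
assert pad_vocab_size(36552, 8) == 36552  # already a multiple, left unchanged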
@@ -24,11 +24,12 @@ class MTTokenizer(object):
     def __init__(self, bpe_codes_file: str, lang_src: str = 'en', lang_trg: str = 'de', separator='@@'):
         self.moses_tokenizer = MosesTokenizer(lang=lang_src)
         self.moses_detokenizer = MosesDetokenizer(lang=lang_trg)
-        self.bpe_tokenizer = BPE(codes=codecs.open(bpe_codes_file, encoding='utf-8'),
-                                 merges=-1,
-                                 separator=separator,
-                                 vocab=None,
-                                 glossaries=None)
+        self.bpe_tokenizer = BPE(
+            codes=codecs.open(bpe_codes_file, encoding='utf-8'),
+            merges=-1,
+            separator=separator,
+            vocab=None,
+            glossaries=None)

     def tokenize(self, text: str):
         """
...
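Side note (not part of the commit): the MTTokenizer above wraps Moses pre- and post-processing around subword-nmt BPE. A rough usage sketch, assuming sacremoses and subword_nmt are installed; the codes-file path and sample sentence are made up for illustration, and this is not the module's actual tokenize/detokenize implementation:

import codecs
import re

from sacremoses import MosesDetokenizer, MosesTokenizer
from subword_nmt.apply_bpe import BPE

bpe_codes_file = 'bpe.codes.en-de'  # hypothetical path; the real module ships its own codes file

moses_tokenizer = MosesTokenizer(lang='en')
moses_detokenizer = MosesDetokenizer(lang='de')
bpe_tokenizer = BPE(codes=codecs.open(bpe_codes_file, encoding='utf-8'), merges=-1, separator='@@')

# Source side: Moses word tokenization, then BPE segmentation into subword units.
tokenized = moses_tokenizer.tokenize('Machine translation is fun.', return_str=True)
subwords = bpe_tokenizer.process_line(tokenized).split()

# Target side: strip the '@@ ' joiners, then Moses-detokenize back to plain text.
merged = re.sub(r'(@@ )|(@@ ?$)', '', ' '.join(subwords))
sentence = moses_detokenizer.detokenize(merged.split())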
@@ -83,29 +83,31 @@ class MTTransformer(nn.Layer):
         self.max_length = max_length
         self.beam_size = beam_size
-        self.tokenizer = MTTokenizer(bpe_codes_file=bpe_codes_file,
-                                     lang_src=self.lang_config['source'],
-                                     lang_trg=self.lang_config['target'])
-        self.src_vocab = Vocab.load_vocabulary(filepath=src_vocab_file,
-                                               unk_token=self.vocab_config['unk_token'],
-                                               bos_token=self.vocab_config['bos_token'],
-                                               eos_token=self.vocab_config['eos_token'])
-        self.trg_vocab = Vocab.load_vocabulary(filepath=trg_vocab_file,
-                                               unk_token=self.vocab_config['unk_token'],
-                                               bos_token=self.vocab_config['bos_token'],
-                                               eos_token=self.vocab_config['eos_token'])
+        self.tokenizer = MTTokenizer(
+            bpe_codes_file=bpe_codes_file, lang_src=self.lang_config['source'], lang_trg=self.lang_config['target'])
+        self.src_vocab = Vocab.load_vocabulary(
+            filepath=src_vocab_file,
+            unk_token=self.vocab_config['unk_token'],
+            bos_token=self.vocab_config['bos_token'],
+            eos_token=self.vocab_config['eos_token'])
+        self.trg_vocab = Vocab.load_vocabulary(
+            filepath=trg_vocab_file,
+            unk_token=self.vocab_config['unk_token'],
+            bos_token=self.vocab_config['bos_token'],
+            eos_token=self.vocab_config['eos_token'])
         self.src_vocab_size = (len(self.src_vocab) + self.vocab_config['pad_factor'] - 1) \
             // self.vocab_config['pad_factor'] * self.vocab_config['pad_factor']
         self.trg_vocab_size = (len(self.trg_vocab) + self.vocab_config['pad_factor'] - 1) \
             // self.vocab_config['pad_factor'] * self.vocab_config['pad_factor']
-        self.transformer = InferTransformerModel(src_vocab_size=self.src_vocab_size,
-                                                 trg_vocab_size=self.trg_vocab_size,
-                                                 bos_id=self.vocab_config['bos_id'],
-                                                 eos_id=self.vocab_config['eos_id'],
-                                                 max_length=self.max_length + 1,
-                                                 max_out_len=max_out_len,
-                                                 beam_size=self.beam_size,
-                                                 **self.model_config)
+        self.transformer = InferTransformerModel(
+            src_vocab_size=self.src_vocab_size,
+            trg_vocab_size=self.trg_vocab_size,
+            bos_id=self.vocab_config['bos_id'],
+            eos_id=self.vocab_config['eos_id'],
+            max_length=self.max_length + 1,
+            max_out_len=max_out_len,
+            beam_size=self.beam_size,
+            **self.model_config)
         state_dict = paddle.load(checkpoint)
...
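Side note (not part of the commit): the two Vocab objects loaded above map BPE subwords to the integer ids the transformer consumes, one vocabulary per language side. A minimal usage sketch; the file path, special-token strings, and sample subwords are assumptions for illustration, not values taken from this module:

from paddlenlp.data import Vocab

src_vocab = Vocab.load_vocabulary(
    filepath='vocab.zh-en.zh',  # hypothetical file; the real module downloads its own vocab files
    unk_token='<unk>',
    bos_token='<s>',
    eos_token='<e>')

# Subword tokens in, integer ids out, and back again when decoding model output.
ids = src_vocab.to_indices(['机器', '翻译'])
tokens = src_vocab.to_tokens(ids)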
@@ -28,11 +28,12 @@ from subword_nmt.apply_bpe import BPE
 class MTTokenizer(object):
     def __init__(self, bpe_codes_file: str, lang_src: str = 'zh', lang_trg: str = 'en', separator='@@'):
         self.moses_detokenizer = MosesDetokenizer(lang=lang_trg)
-        self.bpe_tokenizer = BPE(codes=codecs.open(bpe_codes_file, encoding='utf-8'),
-                                 merges=-1,
-                                 separator=separator,
-                                 vocab=None,
-                                 glossaries=None)
+        self.bpe_tokenizer = BPE(
+            codes=codecs.open(bpe_codes_file, encoding='utf-8'),
+            merges=-1,
+            separator=separator,
+            vocab=None,
+            glossaries=None)

     def tokenize(self, text: str):
         """
...