diff --git a/modules/text/machine_translation/transformer/en-de/module.py b/modules/text/machine_translation/transformer/en-de/module.py
index a3826dd1eb0c9a2dab8873679046b7aadae3eda4..5508108a93cd1339452c8bfc9f81a63042c49233 100644
--- a/modules/text/machine_translation/transformer/en-de/module.py
+++ b/modules/text/machine_translation/transformer/en-de/module.py
@@ -82,23 +82,24 @@ class MTTransformer(nn.Layer):
         self.max_length = max_length
         self.beam_size = beam_size
 
-        self.tokenizer = MTTokenizer(bpe_codes_file=bpe_codes_file,
-                                     lang_src=self.lang_config['source'],
-                                     lang_trg=self.lang_config['target'])
-        self.vocab = Vocab.load_vocabulary(filepath=vocab_file,
-                                           unk_token=self.vocab_config['unk_token'],
-                                           bos_token=self.vocab_config['bos_token'],
-                                           eos_token=self.vocab_config['eos_token'])
+        self.tokenizer = MTTokenizer(
+            bpe_codes_file=bpe_codes_file, lang_src=self.lang_config['source'], lang_trg=self.lang_config['target'])
+        self.vocab = Vocab.load_vocabulary(
+            filepath=vocab_file,
+            unk_token=self.vocab_config['unk_token'],
+            bos_token=self.vocab_config['bos_token'],
+            eos_token=self.vocab_config['eos_token'])
         self.vocab_size = (len(self.vocab) + self.vocab_config['pad_factor'] - 1) \
             // self.vocab_config['pad_factor'] * self.vocab_config['pad_factor']
-        self.transformer = InferTransformerModel(src_vocab_size=self.vocab_size,
-                                                 trg_vocab_size=self.vocab_size,
-                                                 bos_id=self.vocab_config['bos_id'],
-                                                 eos_id=self.vocab_config['eos_id'],
-                                                 max_length=self.max_length + 1,
-                                                 max_out_len=max_out_len,
-                                                 beam_size=self.beam_size,
-                                                 **self.model_config)
+        self.transformer = InferTransformerModel(
+            src_vocab_size=self.vocab_size,
+            trg_vocab_size=self.vocab_size,
+            bos_id=self.vocab_config['bos_id'],
+            eos_id=self.vocab_config['eos_id'],
+            max_length=self.max_length + 1,
+            max_out_len=max_out_len,
+            beam_size=self.beam_size,
+            **self.model_config)
 
         state_dict = paddle.load(checkpoint)
diff --git a/modules/text/machine_translation/transformer/en-de/utils.py b/modules/text/machine_translation/transformer/en-de/utils.py
index 761f656b00d4533d469d29eecb2eef552305a2d8..ea3ceba9b327da9a5b7d879650e5e1e75b3094d2 100644
--- a/modules/text/machine_translation/transformer/en-de/utils.py
+++ b/modules/text/machine_translation/transformer/en-de/utils.py
@@ -24,11 +24,12 @@ class MTTokenizer(object):
     def __init__(self, bpe_codes_file: str, lang_src: str = 'en', lang_trg: str = 'de', separator='@@'):
         self.moses_tokenizer = MosesTokenizer(lang=lang_src)
         self.moses_detokenizer = MosesDetokenizer(lang=lang_trg)
-        self.bpe_tokenizer = BPE(codes=codecs.open(bpe_codes_file, encoding='utf-8'),
-                                 merges=-1,
-                                 separator=separator,
-                                 vocab=None,
-                                 glossaries=None)
+        self.bpe_tokenizer = BPE(
+            codes=codecs.open(bpe_codes_file, encoding='utf-8'),
+            merges=-1,
+            separator=separator,
+            vocab=None,
+            glossaries=None)
 
     def tokenize(self, text: str):
         """
diff --git a/modules/text/machine_translation/transformer/zh-en/module.py b/modules/text/machine_translation/transformer/zh-en/module.py
index 647c08f9a955dbe0b3352d5d74131b5e11635749..ad594d95f7b579dfea71cfbdd3efe23a8ce9431c 100644
--- a/modules/text/machine_translation/transformer/zh-en/module.py
+++ b/modules/text/machine_translation/transformer/zh-en/module.py
@@ -83,29 +83,31 @@ class MTTransformer(nn.Layer):
         self.max_length = max_length
         self.beam_size = beam_size
 
-        self.tokenizer = MTTokenizer(bpe_codes_file=bpe_codes_file,
-                                     lang_src=self.lang_config['source'],
-                                     lang_trg=self.lang_config['target'])
-        self.src_vocab = Vocab.load_vocabulary(filepath=src_vocab_file,
-                                               unk_token=self.vocab_config['unk_token'],
-                                               bos_token=self.vocab_config['bos_token'],
-                                               eos_token=self.vocab_config['eos_token'])
-        self.trg_vocab = Vocab.load_vocabulary(filepath=trg_vocab_file,
-                                               unk_token=self.vocab_config['unk_token'],
-                                               bos_token=self.vocab_config['bos_token'],
-                                               eos_token=self.vocab_config['eos_token'])
+        self.tokenizer = MTTokenizer(
+            bpe_codes_file=bpe_codes_file, lang_src=self.lang_config['source'], lang_trg=self.lang_config['target'])
+        self.src_vocab = Vocab.load_vocabulary(
+            filepath=src_vocab_file,
+            unk_token=self.vocab_config['unk_token'],
+            bos_token=self.vocab_config['bos_token'],
+            eos_token=self.vocab_config['eos_token'])
+        self.trg_vocab = Vocab.load_vocabulary(
+            filepath=trg_vocab_file,
+            unk_token=self.vocab_config['unk_token'],
+            bos_token=self.vocab_config['bos_token'],
+            eos_token=self.vocab_config['eos_token'])
         self.src_vocab_size = (len(self.src_vocab) + self.vocab_config['pad_factor'] - 1) \
             // self.vocab_config['pad_factor'] * self.vocab_config['pad_factor']
         self.trg_vocab_size = (len(self.trg_vocab) + self.vocab_config['pad_factor'] - 1) \
             // self.vocab_config['pad_factor'] * self.vocab_config['pad_factor']
-        self.transformer = InferTransformerModel(src_vocab_size=self.src_vocab_size,
-                                                 trg_vocab_size=self.trg_vocab_size,
-                                                 bos_id=self.vocab_config['bos_id'],
-                                                 eos_id=self.vocab_config['eos_id'],
-                                                 max_length=self.max_length + 1,
-                                                 max_out_len=max_out_len,
-                                                 beam_size=self.beam_size,
-                                                 **self.model_config)
+        self.transformer = InferTransformerModel(
+            src_vocab_size=self.src_vocab_size,
+            trg_vocab_size=self.trg_vocab_size,
+            bos_id=self.vocab_config['bos_id'],
+            eos_id=self.vocab_config['eos_id'],
+            max_length=self.max_length + 1,
+            max_out_len=max_out_len,
+            beam_size=self.beam_size,
+            **self.model_config)
 
         state_dict = paddle.load(checkpoint)
diff --git a/modules/text/machine_translation/transformer/zh-en/utils.py b/modules/text/machine_translation/transformer/zh-en/utils.py
index 8a556ac668748d3e8ad031c994c295374ee27340..aea02ca859462b0a8820e4f116091c5ab47689d2 100644
--- a/modules/text/machine_translation/transformer/zh-en/utils.py
+++ b/modules/text/machine_translation/transformer/zh-en/utils.py
@@ -28,11 +28,12 @@ from subword_nmt.apply_bpe import BPE
 class MTTokenizer(object):
     def __init__(self, bpe_codes_file: str, lang_src: str = 'zh', lang_trg: str = 'en', separator='@@'):
         self.moses_detokenizer = MosesDetokenizer(lang=lang_trg)
-        self.bpe_tokenizer = BPE(codes=codecs.open(bpe_codes_file, encoding='utf-8'),
-                                 merges=-1,
-                                 separator=separator,
-                                 vocab=None,
-                                 glossaries=None)
+        self.bpe_tokenizer = BPE(
+            codes=codecs.open(bpe_codes_file, encoding='utf-8'),
+            merges=-1,
+            separator=separator,
+            vocab=None,
+            glossaries=None)
 
     def tokenize(self, text: str):
         """
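Note: the one piece of logic these hunks keep as unchanged context in both module.py files is the vocab-size rounding, which pads the vocabulary length up to the next multiple of pad_factor (presumably so the padded embedding shapes divide evenly). A minimal standalone sketch of that arithmetic; the pad_to_multiple helper name is ours, not the module's:

    def pad_to_multiple(vocab_len: int, pad_factor: int) -> int:
        # Mirrors the context lines in both module.py hunks:
        #   (len(vocab) + pad_factor - 1) // pad_factor * pad_factor
        return (vocab_len + pad_factor - 1) // pad_factor * pad_factor

    assert pad_to_multiple(31538, 8) == 31544  # rounded up to the next multiple of 8
    assert pad_to_multiple(32000, 8) == 32000  # already a multiple; unchanged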