From f80c221c9317a973da753154964e763e32037cf9 Mon Sep 17 00:00:00 2001
From: KP <109694228@qq.com>
Date: Wed, 9 Jun 2021 10:29:19 +0800
Subject: [PATCH] Solve code formatting problem for ci

---
 .../transformer/en-de/module.py | 31 +++++++-------
 .../transformer/en-de/utils.py  | 11 ++---
 .../transformer/zh-en/module.py | 40 ++++++++++---------
 .../transformer/zh-en/utils.py  | 11 ++---
 4 files changed, 49 insertions(+), 44 deletions(-)

diff --git a/modules/text/machine_translation/transformer/en-de/module.py b/modules/text/machine_translation/transformer/en-de/module.py
index a3826dd1..5508108a 100644
--- a/modules/text/machine_translation/transformer/en-de/module.py
+++ b/modules/text/machine_translation/transformer/en-de/module.py
@@ -82,23 +82,24 @@ class MTTransformer(nn.Layer):
         self.max_length = max_length
         self.beam_size = beam_size
-        self.tokenizer = MTTokenizer(bpe_codes_file=bpe_codes_file,
-                                     lang_src=self.lang_config['source'],
-                                     lang_trg=self.lang_config['target'])
-        self.vocab = Vocab.load_vocabulary(filepath=vocab_file,
-                                           unk_token=self.vocab_config['unk_token'],
-                                           bos_token=self.vocab_config['bos_token'],
-                                           eos_token=self.vocab_config['eos_token'])
+        self.tokenizer = MTTokenizer(
+            bpe_codes_file=bpe_codes_file, lang_src=self.lang_config['source'], lang_trg=self.lang_config['target'])
+        self.vocab = Vocab.load_vocabulary(
+            filepath=vocab_file,
+            unk_token=self.vocab_config['unk_token'],
+            bos_token=self.vocab_config['bos_token'],
+            eos_token=self.vocab_config['eos_token'])
         self.vocab_size = (len(self.vocab) + self.vocab_config['pad_factor'] - 1) \
             // self.vocab_config['pad_factor'] * self.vocab_config['pad_factor']
-        self.transformer = InferTransformerModel(src_vocab_size=self.vocab_size,
-                                                 trg_vocab_size=self.vocab_size,
-                                                 bos_id=self.vocab_config['bos_id'],
-                                                 eos_id=self.vocab_config['eos_id'],
-                                                 max_length=self.max_length + 1,
-                                                 max_out_len=max_out_len,
-                                                 beam_size=self.beam_size,
-                                                 **self.model_config)
+        self.transformer = InferTransformerModel(
+            src_vocab_size=self.vocab_size,
+            trg_vocab_size=self.vocab_size,
+            bos_id=self.vocab_config['bos_id'],
+            eos_id=self.vocab_config['eos_id'],
+            max_length=self.max_length + 1,
+            max_out_len=max_out_len,
+            beam_size=self.beam_size,
+            **self.model_config)
 
         state_dict = paddle.load(checkpoint)
diff --git a/modules/text/machine_translation/transformer/en-de/utils.py b/modules/text/machine_translation/transformer/en-de/utils.py
index 761f656b..ea3ceba9 100644
--- a/modules/text/machine_translation/transformer/en-de/utils.py
+++ b/modules/text/machine_translation/transformer/en-de/utils.py
@@ -24,11 +24,12 @@ class MTTokenizer(object):
     def __init__(self, bpe_codes_file: str, lang_src: str = 'en', lang_trg: str = 'de', separator='@@'):
         self.moses_tokenizer = MosesTokenizer(lang=lang_src)
         self.moses_detokenizer = MosesDetokenizer(lang=lang_trg)
-        self.bpe_tokenizer = BPE(codes=codecs.open(bpe_codes_file, encoding='utf-8'),
-                                 merges=-1,
-                                 separator=separator,
-                                 vocab=None,
-                                 glossaries=None)
+        self.bpe_tokenizer = BPE(
+            codes=codecs.open(bpe_codes_file, encoding='utf-8'),
+            merges=-1,
+            separator=separator,
+            vocab=None,
+            glossaries=None)
 
     def tokenize(self, text: str):
         """
diff --git a/modules/text/machine_translation/transformer/zh-en/module.py b/modules/text/machine_translation/transformer/zh-en/module.py
index 647c08f9..ad594d95 100644
--- a/modules/text/machine_translation/transformer/zh-en/module.py
+++ b/modules/text/machine_translation/transformer/zh-en/module.py
@@ -83,29 +83,31 @@ class MTTransformer(nn.Layer):
         self.max_length = max_length
         self.beam_size = beam_size
-        self.tokenizer = MTTokenizer(bpe_codes_file=bpe_codes_file,
-                                     lang_src=self.lang_config['source'],
-                                     lang_trg=self.lang_config['target'])
-        self.src_vocab = Vocab.load_vocabulary(filepath=src_vocab_file,
-                                               unk_token=self.vocab_config['unk_token'],
-                                               bos_token=self.vocab_config['bos_token'],
-                                               eos_token=self.vocab_config['eos_token'])
-        self.trg_vocab = Vocab.load_vocabulary(filepath=trg_vocab_file,
-                                               unk_token=self.vocab_config['unk_token'],
-                                               bos_token=self.vocab_config['bos_token'],
-                                               eos_token=self.vocab_config['eos_token'])
+        self.tokenizer = MTTokenizer(
+            bpe_codes_file=bpe_codes_file, lang_src=self.lang_config['source'], lang_trg=self.lang_config['target'])
+        self.src_vocab = Vocab.load_vocabulary(
+            filepath=src_vocab_file,
+            unk_token=self.vocab_config['unk_token'],
+            bos_token=self.vocab_config['bos_token'],
+            eos_token=self.vocab_config['eos_token'])
+        self.trg_vocab = Vocab.load_vocabulary(
+            filepath=trg_vocab_file,
+            unk_token=self.vocab_config['unk_token'],
+            bos_token=self.vocab_config['bos_token'],
+            eos_token=self.vocab_config['eos_token'])
         self.src_vocab_size = (len(self.src_vocab) + self.vocab_config['pad_factor'] - 1) \
             // self.vocab_config['pad_factor'] * self.vocab_config['pad_factor']
         self.trg_vocab_size = (len(self.trg_vocab) + self.vocab_config['pad_factor'] - 1) \
             // self.vocab_config['pad_factor'] * self.vocab_config['pad_factor']
-        self.transformer = InferTransformerModel(src_vocab_size=self.src_vocab_size,
-                                                 trg_vocab_size=self.trg_vocab_size,
-                                                 bos_id=self.vocab_config['bos_id'],
-                                                 eos_id=self.vocab_config['eos_id'],
-                                                 max_length=self.max_length + 1,
-                                                 max_out_len=max_out_len,
-                                                 beam_size=self.beam_size,
-                                                 **self.model_config)
+        self.transformer = InferTransformerModel(
+            src_vocab_size=self.src_vocab_size,
+            trg_vocab_size=self.trg_vocab_size,
+            bos_id=self.vocab_config['bos_id'],
+            eos_id=self.vocab_config['eos_id'],
+            max_length=self.max_length + 1,
+            max_out_len=max_out_len,
+            beam_size=self.beam_size,
+            **self.model_config)
 
         state_dict = paddle.load(checkpoint)
diff --git a/modules/text/machine_translation/transformer/zh-en/utils.py b/modules/text/machine_translation/transformer/zh-en/utils.py
index 8a556ac6..aea02ca8 100644
--- a/modules/text/machine_translation/transformer/zh-en/utils.py
+++ b/modules/text/machine_translation/transformer/zh-en/utils.py
@@ -28,11 +28,12 @@ from subword_nmt.apply_bpe import BPE
 class MTTokenizer(object):
     def __init__(self, bpe_codes_file: str, lang_src: str = 'zh', lang_trg: str = 'en', separator='@@'):
         self.moses_detokenizer = MosesDetokenizer(lang=lang_trg)
-        self.bpe_tokenizer = BPE(codes=codecs.open(bpe_codes_file, encoding='utf-8'),
-                                 merges=-1,
-                                 separator=separator,
-                                 vocab=None,
-                                 glossaries=None)
+        self.bpe_tokenizer = BPE(
+            codes=codecs.open(bpe_codes_file, encoding='utf-8'),
+            merges=-1,
+            separator=separator,
+            vocab=None,
+            glossaries=None)
 
     def tokenize(self, text: str):
         """
--
GitLab