From e3ee127f56d7b9b21bb8f3363dda267266c0106d Mon Sep 17 00:00:00 2001 From: jm12138 <2286040843@qq.com> Date: Thu, 29 Dec 2022 10:13:11 +0800 Subject: [PATCH] add gradio app (#2161) --- .../transformer/zh-en/README.md | 16 +++- .../transformer/zh-en/module.py | 77 +++++++++++-------- .../transformer/zh-en/utils.py | 17 ++-- 3 files changed, 67 insertions(+), 43 deletions(-) diff --git a/modules/text/machine_translation/transformer/zh-en/README.md b/modules/text/machine_translation/transformer/zh-en/README.md index db4135f8..75f5a67d 100644 --- a/modules/text/machine_translation/transformer/zh-en/README.md +++ b/modules/text/machine_translation/transformer/zh-en/README.md @@ -1,6 +1,6 @@ # transformer_zh-en |模型名称|transformer_zh-en| -| :--- | :---: | +| :--- | :---: | |类别|文本-机器翻译| |网络|Transformer| |数据集|CWMT2021| @@ -24,7 +24,7 @@ - ### 1、环境依赖 - paddlepaddle >= 2.1.0 - + - paddlehub >= 2.1.0 | [如何安装PaddleHub](../../../../docs/docs_ch/get_start/installation.rst) - ### 2、安装 @@ -55,7 +55,7 @@ print('-'*30) print(f'src: {st}') for i in range(n_best): - print(f'trg[{i+1}]: {trg_texts[idx*n_best+i]}') + print(f'trg[{i+1}]: {trg_texts[idx*n_best+i]}') ``` - ### 2、API @@ -132,6 +132,9 @@ - 关于PaddleHub Serving更多信息参考:[服务部署](../../../../docs/docs_ch/tutorial/serving.md) +- ### Gradio APP 支持 + 从 PaddleHub 2.3.1 开始支持使用链接 http://127.0.0.1:8866/gradio/transformer_zh-en 在浏览器中访问 transformer_zh-en 的 Gradio APP。 + ## 五、更新历史 * 1.0.0 @@ -141,6 +144,11 @@ * 1.0.1 修复模型初始化的兼容性问题 + +* 1.1.0 + + 添加 Gradio APP 支持 + - ```shell - $ hub install transformer_zh-en==1.0.1 + $ hub install transformer_zh-en==1.1.0 ``` diff --git a/modules/text/machine_translation/transformer/zh-en/module.py b/modules/text/machine_translation/transformer/zh-en/module.py index 318d5728..2589fd9d 100644 --- a/modules/text/machine_translation/transformer/zh-en/module.py +++ b/modules/text/machine_translation/transformer/zh-en/module.py @@ -11,24 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os from typing import List import paddle import paddle.nn as nn -from paddlehub.env import MODULE_HOME -from paddlehub.module.module import moduleinfo, serving -import paddlenlp -from paddlenlp.data import Pad, Vocab -from paddlenlp.transformers import InferTransformerModel, position_encoding_init +from paddlenlp.data import Pad +from paddlenlp.data import Vocab +from paddlenlp.transformers import InferTransformerModel +from paddlenlp.transformers import position_encoding_init +from transformer_zh_en.utils import MTTokenizer +from transformer_zh_en.utils import post_process_seq -from transformer_zh_en.utils import MTTokenizer, post_process_seq +from paddlehub.env import MODULE_HOME +from paddlehub.module.module import moduleinfo +from paddlehub.module.module import serving @moduleinfo( name="transformer_zh-en", - version="1.0.1", + version="1.1.0", summary="", author="PaddlePaddle", author_email="", @@ -57,7 +59,7 @@ class MTTransformer(nn.Layer): # Dropout rate 'dropout': 0, # Number of sub-layers to be stacked in the encoder and decoder. - "num_encoder_layers": 6, + "num_encoder_layers": 6, "num_decoder_layers": 6 } @@ -85,31 +87,29 @@ class MTTransformer(nn.Layer): self.max_length = max_length self.beam_size = beam_size - self.tokenizer = MTTokenizer( - bpe_codes_file=bpe_codes_file, lang_src=self.lang_config['source'], lang_trg=self.lang_config['target']) - self.src_vocab = Vocab.load_vocabulary( - filepath=src_vocab_file, - unk_token=self.vocab_config['unk_token'], - bos_token=self.vocab_config['bos_token'], - eos_token=self.vocab_config['eos_token']) - self.trg_vocab = Vocab.load_vocabulary( - filepath=trg_vocab_file, - unk_token=self.vocab_config['unk_token'], - bos_token=self.vocab_config['bos_token'], - eos_token=self.vocab_config['eos_token']) + self.tokenizer = MTTokenizer(bpe_codes_file=bpe_codes_file, + lang_src=self.lang_config['source'], + lang_trg=self.lang_config['target']) + self.src_vocab = Vocab.load_vocabulary(filepath=src_vocab_file, + unk_token=self.vocab_config['unk_token'], + bos_token=self.vocab_config['bos_token'], + eos_token=self.vocab_config['eos_token']) + self.trg_vocab = Vocab.load_vocabulary(filepath=trg_vocab_file, + unk_token=self.vocab_config['unk_token'], + bos_token=self.vocab_config['bos_token'], + eos_token=self.vocab_config['eos_token']) self.src_vocab_size = (len(self.src_vocab) + self.vocab_config['pad_factor'] - 1) \ // self.vocab_config['pad_factor'] * self.vocab_config['pad_factor'] self.trg_vocab_size = (len(self.trg_vocab) + self.vocab_config['pad_factor'] - 1) \ // self.vocab_config['pad_factor'] * self.vocab_config['pad_factor'] - self.transformer = InferTransformerModel( - src_vocab_size=self.src_vocab_size, - trg_vocab_size=self.trg_vocab_size, - bos_id=self.vocab_config['bos_id'], - eos_id=self.vocab_config['eos_id'], - max_length=self.max_length + 1, - max_out_len=max_out_len, - beam_size=self.beam_size, - **self.model_config) + self.transformer = InferTransformerModel(src_vocab_size=self.src_vocab_size, + trg_vocab_size=self.trg_vocab_size, + bos_id=self.vocab_config['bos_id'], + eos_id=self.vocab_config['eos_id'], + max_length=self.max_length + 1, + max_out_len=max_out_len, + beam_size=self.beam_size, + **self.model_config) state_dict = paddle.load(checkpoint) @@ -184,3 +184,20 @@ class MTTransformer(nn.Layer): results.append(trg_sample_text) return results + + def create_gradio_app(self): + import gradio as gr + + def inference(text): + results = self.predict(data=[text]) + return results[0] + + examples = [['今天是个好日子']] + + interface = gr.Interface(inference, + "text", [gr.outputs.Textbox(label="Translation")], + title="transformer_zh-en", + examples=examples, + allow_flagging='never') + + return interface diff --git a/modules/text/machine_translation/transformer/zh-en/utils.py b/modules/text/machine_translation/transformer/zh-en/utils.py index aea02ca8..1356a678 100644 --- a/modules/text/machine_translation/transformer/zh-en/utils.py +++ b/modules/text/machine_translation/transformer/zh-en/utils.py @@ -11,29 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import codecs import logging import re from typing import List -import codecs import jieba jieba.setLogLevel(logging.INFO) -from sacremoses import MosesTokenizer, MosesDetokenizer +from sacremoses import MosesDetokenizer from subword_nmt.apply_bpe import BPE class MTTokenizer(object): + def __init__(self, bpe_codes_file: str, lang_src: str = 'zh', lang_trg: str = 'en', separator='@@'): self.moses_detokenizer = MosesDetokenizer(lang=lang_trg) - self.bpe_tokenizer = BPE( - codes=codecs.open(bpe_codes_file, encoding='utf-8'), - merges=-1, - separator=separator, - vocab=None, - glossaries=None) + self.bpe_tokenizer = BPE(codes=codecs.open(bpe_codes_file, encoding='utf-8'), + merges=-1, + separator=separator, + vocab=None, + glossaries=None) def tokenize(self, text: str): """ -- GitLab