From a22fa4b306e4735a459a4d919f16f6dceb6c5013 Mon Sep 17 00:00:00 2001 From: Jiaqi Liu Date: Fri, 5 Feb 2021 17:21:35 +0800 Subject: [PATCH] add seq2seq infer (#5205) * add seq2seq infer * update argument description, remove useless import * add deploy directory * add deploy directory, add relative import * update arg usage in README, fix import order. --- .../machine_translation/seq2seq/README.md | 39 +++++- .../machine_translation/seq2seq/args.py | 14 ++- .../seq2seq/deploy/python/infer.py | 116 ++++++++++++++++++ .../seq2seq/export_model.py | 62 ++++++++++ 4 files changed, 222 insertions(+), 9 deletions(-) create mode 100644 PaddleNLP/examples/machine_translation/seq2seq/deploy/python/infer.py create mode 100644 PaddleNLP/examples/machine_translation/seq2seq/export_model.py diff --git a/PaddleNLP/examples/machine_translation/seq2seq/README.md b/PaddleNLP/examples/machine_translation/seq2seq/README.md index 3de3c4d2..3a7562b9 100644 --- a/PaddleNLP/examples/machine_translation/seq2seq/README.md +++ b/PaddleNLP/examples/machine_translation/seq2seq/README.md @@ -4,12 +4,15 @@ ``` . +├── deploy # 预测部署目录 +│ └── python +│ └── infer.py # 用预测模型进行推理的程序 ├── README.md # 文档,本文件 -├── args.py # 训练、预测以及模型参数配置程序 +├── args.py # 训练、预测、导出模型以及模型参数配置程序 ├── data.py # 数据读入程序 -├── download.py # 数据下载程序 ├── train.py # 训练主程序 ├── predict.py # 预测主程序 +├── export_model.py # 导出预测模型的程序 └── seq2seq_attn.py # 带注意力机制的翻译模型程序 ``` @@ -45,9 +48,8 @@ python train.py \ --dropout 0.2 \ --init_scale 0.1 \ --max_grad_norm 5.0 \ - --use_gpu True \ + --select_device gpu \ --model_path ./attention_models - ``` 各参数的具体说明请参阅 `args.py` 。训练程序会在每个epoch训练结束之后,save一次模型。 @@ -69,10 +71,37 @@ python predict.py \ --init_from_ckpt attention_models/9 \ --infer_output_file infer_output.txt \ --beam_size 10 \ - --use_gpu True + --select_device gpu ``` 各参数的具体说明请参阅 `args.py` ,注意预测时所用模型超参数需和训练时一致。 ## 预测效果评价 取第10个epoch的结果,用取beam_size为10的beam search解码,`predict.py`脚本在生成翻译结果之后,会调用`paddlenlp.metrics.BLEU`计算翻译结果的BLEU指标,最终计算出的BLEU分数为0.24074304399683688。 + +## 保存预测模型 +这里指定的参数`export_path` 表示导出预测模型文件的前缀。保存时会添加后缀(`pdiparams`,`pdiparams.info`,`pdmodel`)。 +```shell +python export_model.py \ + --num_layers 2 \ + --hidden_size 512 \ + --batch_size 128 \ + --dropout 0.2 \ + --init_scale 0.1 \ + --max_grad_norm 5.0 \ + --init_from_ckpt attention_models/9.pdparams \ + --beam_size 10 \ + --export_path ./infer_model/model +``` + +## 基于预测引擎推理 +然后按照如下的方式对IWSLT15数据集中的测试集(有标注的)进行预测(基于Paddle的[Python预测API](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/guides/05_inference_deployment/inference/python_infer_cn.html)): + +```shell +cd deploy/python +python infer.py \ + --export_path ../../infer_model/model \ + --select_device gpu \ + --batch_size 128 \ + --infer_output_file infer_output.txt +``` diff --git a/PaddleNLP/examples/machine_translation/seq2seq/args.py b/PaddleNLP/examples/machine_translation/seq2seq/args.py index 9882b72c..9d325cfc 100644 --- a/PaddleNLP/examples/machine_translation/seq2seq/args.py +++ b/PaddleNLP/examples/machine_translation/seq2seq/args.py @@ -91,10 +91,10 @@ def parse_args(): "--beam_size", type=int, default=10, help="file name for inference") parser.add_argument( - '--use_gpu', - type=eval, - default=False, - help='Whether using gpu [True|False]') + "--select_device", + default="gpu", + choices=["gpu", "cpu", "xpu"], + help="Device selected for inference.") parser.add_argument( "--init_from_ckpt", @@ -102,5 +102,11 @@ def parse_args(): default=None, help="The path of checkpoint to be loaded.") + parser.add_argument( + "--export_path", + type=str, + default=None, + help="The output file prefix used to save the exported inference model.") + args = parser.parse_args() return args diff --git a/PaddleNLP/examples/machine_translation/seq2seq/deploy/python/infer.py b/PaddleNLP/examples/machine_translation/seq2seq/deploy/python/infer.py new file mode 100644 index 00000000..ead974ed --- /dev/null +++ b/PaddleNLP/examples/machine_translation/seq2seq/deploy/python/infer.py @@ -0,0 +1,116 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import io +import os +import sys +sys.path.append("../../") + +from functools import partial +import numpy as np + +import paddle +from paddle import inference +from paddlenlp.datasets import IWSLT15 +from paddlenlp.metrics import BLEU + +from args import parse_args +from data import create_infer_loader +from predict import post_process_seq + + +class Predictor(object): + def __init__(self, predictor, input_handles, output_handles): + self.predictor = predictor + self.input_handles = input_handles + self.output_handles = output_handles + + @classmethod + def create_predictor(cls, args): + config = paddle.inference.Config(args.export_path + ".pdmodel", + args.export_path + ".pdiparams") + if args.select_device == "gpu": + # set GPU configs accordingly + config.enable_use_gpu(100, 0) + elif args.select_device == "cpu": + # set CPU configs accordingly, + # such as enable_mkldnn, set_cpu_math_library_num_threads + config.disable_gpu() + elif args.select_device == "xpu": + # set XPU configs accordingly + config.enable_xpu(100) + config.switch_use_feed_fetch_ops(False) + predictor = paddle.inference.create_predictor(config) + input_handles = [ + predictor.get_input_handle(name) + for name in predictor.get_input_names() + ] + output_handles = [ + predictor.get_input_handle(name) + for name in predictor.get_output_names() + ] + return cls(predictor, input_handles, output_handles) + + def predict_batch(self, data): + for input_field, input_handle in zip(data, self.input_handles): + input_handle.copy_from_cpu(input_field.numpy() if isinstance( + input_field, paddle.Tensor) else input_field) + self.predictor.run() + output = [ + output_handle.copy_to_cpu() for output_handle in self.output_handles + ] + return output + + def predict(self, dataloader, infer_output_file, trg_idx2word, bos_id, + eos_id): + cand_list = [] + with io.open(infer_output_file, 'w', encoding='utf-8') as f: + for data in dataloader(): + finished_seq = self.predict_batch(data)[0] + finished_seq = finished_seq[:, :, np.newaxis] if len( + finished_seq.shape) == 2 else finished_seq + finished_seq = np.transpose(finished_seq, [0, 2, 1]) + for ins in finished_seq: + for beam_idx, beam in enumerate(ins): + id_list = post_process_seq(beam, bos_id, eos_id) + word_list = [trg_idx2word[id] for id in id_list] + sequence = " ".join(word_list) + "\n" + f.write(sequence) + cand_list.append(word_list) + break + + test_ds = IWSLT15.get_datasets(["test"]) + bleu = BLEU() + for i, data in enumerate(test_ds): + ref = data[1].split() + bleu.add_inst(cand_list[i], [ref]) + print("BLEU score is %s." % bleu.score()) + + +def main(): + args = parse_args() + + predictor = Predictor.create_predictor(args) + test_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_infer_loader( + args) + _, vocab = IWSLT15.get_vocab() + trg_idx2word = vocab.idx_to_token + + predictor.predict(test_loader, args.infer_output_file, trg_idx2word, bos_id, + eos_id) + + +if __name__ == "__main__": + main() diff --git a/PaddleNLP/examples/machine_translation/seq2seq/export_model.py b/PaddleNLP/examples/machine_translation/seq2seq/export_model.py new file mode 100644 index 00000000..84e012e8 --- /dev/null +++ b/PaddleNLP/examples/machine_translation/seq2seq/export_model.py @@ -0,0 +1,62 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import paddle + +from seq2seq_attn import Seq2SeqAttnInferModel +from args import parse_args +from data import create_infer_loader + + +def main(): + args = parse_args() + _, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_infer_loader( + args) + + # Build model and load trained parameters + model = Seq2SeqAttnInferModel( + src_vocab_size, + tgt_vocab_size, + args.hidden_size, + args.hidden_size, + args.num_layers, + args.dropout, + bos_id=bos_id, + eos_id=eos_id, + beam_size=args.beam_size, + max_out_len=256) + + # Load the trained model + model.set_state_dict(paddle.load(args.init_from_ckpt)) + + # Wwitch to eval model + model.eval() + # Convert to static graph with specific input description + model = paddle.jit.to_static( + model, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None], dtype="int64"), # src + paddle.static.InputSpec( + shape=[None], dtype="int64") # src length + ]) + # Save converted static graph model + paddle.jit.save(model, args.export_path) + + +if __name__ == "__main__": + main() -- GitLab