未验证 提交 147dce2b 编写于 作者: L LiuChiachi 提交者: GitHub

Update seq2seq using bleu (#5079)

* update __init__

* update seq2seq example, using bleu

* remove target file arg

* update readme
上级 156f29da
......@@ -56,7 +56,7 @@ python train.py \
## 模型预测
训练完成之后,可以使用保存的模型(由 `--init_from_ckpt` 指定)对测试集的数据集进行beam search解码,其中译文数据由 `--infer_target_file` 指定),在linux系统下,默认安装路径为`~/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi`,如果您使用的是Windows系统,需要更改下面的路径。预测命令如下:
训练完成之后,可以使用保存的模型(由 `--init_from_ckpt` 指定)对测试集的数据集进行beam search解码。生成的翻译结果位于`--infer_output_file`指定的路径,预测命令如下:
```sh
python predict.py \
......@@ -67,7 +67,6 @@ python predict.py \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--init_from_ckpt attention_models/9 \
--infer_target_file ~/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi \
--infer_output_file infer_output.txt \
--beam_size 10 \
--use_gpu True
......@@ -76,16 +75,4 @@ python predict.py \
各参数的具体说明请参阅 `args.py` ,注意预测时所用模型超参数需和训练时一致。
## 预测效果评价
使用 [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) 工具来评价模型预测的翻译质量,将该工具下载在该项目路径下,然后使用如下的命令,可以看到BLEU指标的结果
(需要注意的是,在windows系统下,可能需要更改文件路径`~/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi`):
```sh
perl mosesdecoder/scripts/generic/multi-bleu.perl ~/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi < infer_output.txt
```
取第10个epoch保存的模型进行预测,取beam_size=10。效果如下:
```
tst2013 BLEU: 24.40
```
取第10个epoch的结果,用取beam_size为10的beam search解码,`predict.py`脚本在生成翻译结果之后,会调用`paddlenlp.metrics.BLEU`计算翻译结果的BLEU指标,最终计算出的BLEU分数为0.24074304399683688。
......@@ -81,8 +81,6 @@ def parse_args():
default='model',
help="model path for model to save")
parser.add_argument(
"--infer_target_file", type=str, help="target file name for inference")
parser.add_argument(
"--infer_output_file",
type=str,
......
......@@ -21,6 +21,7 @@ from args import parse_args
from seq2seq_attn import Seq2SeqAttnInferModel
from data import create_infer_loader
from paddlenlp.datasets import IWSLT15
from paddlenlp.metrics import BLEU
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
......@@ -45,6 +46,7 @@ def do_predict(args):
test_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_infer_loader(
args)
_, vocab = IWSLT15.get_vocab()
trg_idx2word = vocab.idx_to_token
model = paddle.Model(
......@@ -67,6 +69,7 @@ def do_predict(args):
"Please set reload_model to load the infer model.")
model.load(args.init_from_ckpt)
cand_list = []
with io.open(args.infer_output_file, 'w', encoding='utf-8') as f:
for data in test_loader():
with paddle.no_grad():
......@@ -80,8 +83,17 @@ def do_predict(args):
word_list = [trg_idx2word[id] for id in id_list]
sequence = " ".join(word_list) + "\n"
f.write(sequence)
cand_list.append(word_list)
break
test_ds = IWSLT15.get_datasets(["test"])
bleu = BLEU()
for i, data in enumerate(test_ds):
ref = data[1].split()
bleu.add_inst(cand_list[i], [ref])
print("BLEU score is %s." % bleu.score())
if __name__ == "__main__":
args = parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册