未验证 提交 cd646190 编写于 作者: L LiuChiachi 提交者: GitHub

Fix seq2seq readme error (#5069)

* update __init__

* update couplet api usage, delete useless import

* update seq2seq readme

* update seq2seq readme path
上级 5f0ace0a
......@@ -31,7 +31,7 @@ Sequence to Sequence (Seq2Seq),使用编码器-解码器(Encoder-Decoder)
本教程使用[IWSLT'15 English-Vietnamese data ](https://nlp.stanford.edu/projects/nmt/)数据集中的英语到越南语的数据作为训练语料,tst2012的数据作为开发集,tst2013的数据作为测试集。
### 数据获取
如果用户在初始化数据集时没有提供路径,数据集会自动下载到`paddlenlp.utils.env.DATA_HOME``/machine_translation/IWSLT15/`路径下,例如在linux系统下,默认存储路径是`/root/.paddlenlp/datasets/machine_translation/IWSLT15`
如果用户在初始化数据集时没有提供路径,数据集会自动下载到`paddlenlp.utils.env.DATA_HOME``/machine_translation/IWSLT15/`路径下,例如在linux系统下,默认存储路径是`~/.paddlenlp/datasets/machine_translation/IWSLT15`
## 模型训练
......@@ -56,7 +56,7 @@ python train.py \
## 模型预测
训练完成之后,可以使用保存的模型(由 `--init_from_ckpt` 指定)对测试集的数据集进行beam search解码,其中译文数据由 `--infer_target_file` 指定),在linux系统下,默认安装路径为`/root/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi`,如果您使用的是Windows系统,需要更改下面的路径。预测命令如下:
训练完成之后,可以使用保存的模型(由 `--init_from_ckpt` 指定)对测试集的数据集进行beam search解码,其中译文数据由 `--infer_target_file` 指定),在linux系统下,默认安装路径为`~/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi`,如果您使用的是Windows系统,需要更改下面的路径。预测命令如下:
```sh
python predict.py \
......@@ -67,7 +67,7 @@ python predict.py \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--init_from_ckpt attention_models/9 \
--infer_target_file /root/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi \
--infer_target_file ~/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi \
--infer_output_file infer_output.txt \
--beam_size 10 \
--use_gpu True
......@@ -75,11 +75,12 @@ python predict.py \
各参数的具体说明请参阅 `args.py` ,注意预测时所用模型超参数需和训练时一致。
## 效果评价
使用 [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) 工具来评价模型预测的翻译质量,使用方法如下:
## 预测效果评价
使用 [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) 工具来评价模型预测的翻译质量,将该工具下载在该项目路径下,然后使用如下的命令,可以看到BLEU指标的结果
(需要注意的是,在windows系统下,可能需要更改文件路径`~/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi`):
```sh
perl mosesdecoder/scripts/generic/multi-bleu.perl data/en-vi/tst2013.vi < infer_output.txt
perl mosesdecoder/scripts/generic/multi-bleu.perl ~/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi < infer_output.txt
```
取第10个epoch保存的模型进行预测,取beam_size=10。效果如下:
......
......@@ -30,7 +30,7 @@ Sequence to Sequence (Seq2Seq),使用编码器-解码器(Encoder-Decoder)
本教程使用[couplet数据集](https://paddlenlp.bj.bcebos.com/datasets/couplet.tar.gz)数据集作为训练语料,train_src.tsv及train_tgt.tsv为训练集,dev_src.tsv及test_tgt.tsv为开发集,test_src.tsv及test_tgt.tsv为测试集。
数据集会在`CoupletDataset`初始化时自动下载,如果用户在初始化数据集时没有提供路径,在linux系统下,数据集会自动下载到`/root/.paddlenlp/datasets/machine_translation/CoupletDataset/`目录下
数据集会在`CoupletDataset`初始化时自动下载,如果用户在初始化数据集时没有提供路径,在linux系统下,数据集会自动下载到`~/.paddlenlp/datasets/machine_translation/CoupletDataset/`目录下
## 模型训练
......@@ -72,16 +72,12 @@ python predict.py \
上联:崖悬风雨骤 下联:月落水云寒
上联:约春章柳下 下联:邀月醉花间
上联:箬笠红尘外 下联:扁舟明月中
上联:书香醉倒窗前月 下联:烛影摇红梦里人
上联:踏雪寻梅求雅趣 下联:临风把酒觅知音
上联:未出南阳天下论 下联:先登北斗汉中书
......
......@@ -19,10 +19,8 @@ from functools import partial
import numpy as np
import paddle
from paddle.utils.download import get_path_from_url
from paddlenlp.data import Vocab, Pad
from paddlenlp.data import SamplerHelper
from paddlenlp.utils.env import DATA_HOME
from paddlenlp.datasets import TranslationDataset
......@@ -32,7 +30,7 @@ def create_train_loader(batch_size=128):
pad_id = vocab[CoupletDataset.EOS_TOKEN]
train_batch_sampler = SamplerHelper(train_ds).shuffle().batch(
batch_size=batch_size).shard()
batch_size=batch_size)
train_loader = paddle.io.DataLoader(
train_ds,
......@@ -50,8 +48,7 @@ def create_infer_loader(batch_size=128):
bos_id = vocab[CoupletDataset.BOS_TOKEN]
eos_id = vocab[CoupletDataset.EOS_TOKEN]
test_batch_sampler = SamplerHelper(test_ds).batch(
batch_size=batch_size).shard()
test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)
test_loader = paddle.io.DataLoader(
test_ds,
......@@ -103,7 +100,7 @@ class CoupletDataset(TranslationDataset):
raise TypeError(
'`train`, `dev` or `test` is supported but `{}` is passed in'.
format(mode))
# Download data
# Download and read data
self.data = self.get_data(mode=mode, root=root)
self.vocab, _ = self.get_vocab(root)
self.transform()
......
......@@ -44,7 +44,7 @@ def do_predict(args):
test_loader, vocab_size, pad_id, bos_id, eos_id = create_infer_loader(
args.batch_size)
vocab, _ = CoupletDataset.get_vocab()
trg_idx2word = vocab._idx_to_token
trg_idx2word = vocab.idx_to_token
model = paddle.Model(
Seq2SeqAttnInferModel(
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = '2.0.0a7'
__version__ = '2.0.0a8'
from . import data
from . import datasets
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册