未验证 提交 cd646190 编写于 作者: L LiuChiachi 提交者: GitHub

Fix seq2seq readme error (#5069)

* update __init__

* update couplet api usage, delete useless import

* update seq2seq readme

* update seq2seq readme path
上级 5f0ace0a
...@@ -31,7 +31,7 @@ Sequence to Sequence (Seq2Seq),使用编码器-解码器(Encoder-Decoder) ...@@ -31,7 +31,7 @@ Sequence to Sequence (Seq2Seq),使用编码器-解码器(Encoder-Decoder)
本教程使用[IWSLT'15 English-Vietnamese data ](https://nlp.stanford.edu/projects/nmt/)数据集中的英语到越南语的数据作为训练语料,tst2012的数据作为开发集,tst2013的数据作为测试集。 本教程使用[IWSLT'15 English-Vietnamese data ](https://nlp.stanford.edu/projects/nmt/)数据集中的英语到越南语的数据作为训练语料,tst2012的数据作为开发集,tst2013的数据作为测试集。
### 数据获取 ### 数据获取
如果用户在初始化数据集时没有提供路径,数据集会自动下载到`paddlenlp.utils.env.DATA_HOME``/machine_translation/IWSLT15/`路径下,例如在linux系统下,默认存储路径是`/root/.paddlenlp/datasets/machine_translation/IWSLT15` 如果用户在初始化数据集时没有提供路径,数据集会自动下载到`paddlenlp.utils.env.DATA_HOME``/machine_translation/IWSLT15/`路径下,例如在linux系统下,默认存储路径是`~/.paddlenlp/datasets/machine_translation/IWSLT15`
## 模型训练 ## 模型训练
...@@ -56,7 +56,7 @@ python train.py \ ...@@ -56,7 +56,7 @@ python train.py \
## 模型预测 ## 模型预测
训练完成之后,可以使用保存的模型(由 `--init_from_ckpt` 指定)对测试集的数据集进行beam search解码,其中译文数据由 `--infer_target_file` 指定),在linux系统下,默认安装路径为`/root/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi`,如果您使用的是Windows系统,需要更改下面的路径。预测命令如下: 训练完成之后,可以使用保存的模型(由 `--init_from_ckpt` 指定)对测试集的数据集进行beam search解码,其中译文数据由 `--infer_target_file` 指定),在linux系统下,默认安装路径为`~/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi`,如果您使用的是Windows系统,需要更改下面的路径。预测命令如下:
```sh ```sh
python predict.py \ python predict.py \
...@@ -67,7 +67,7 @@ python predict.py \ ...@@ -67,7 +67,7 @@ python predict.py \
--init_scale 0.1 \ --init_scale 0.1 \
--max_grad_norm 5.0 \ --max_grad_norm 5.0 \
--init_from_ckpt attention_models/9 \ --init_from_ckpt attention_models/9 \
--infer_target_file /root/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi \ --infer_target_file ~/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi \
--infer_output_file infer_output.txt \ --infer_output_file infer_output.txt \
--beam_size 10 \ --beam_size 10 \
--use_gpu True --use_gpu True
...@@ -75,11 +75,12 @@ python predict.py \ ...@@ -75,11 +75,12 @@ python predict.py \
各参数的具体说明请参阅 `args.py` ,注意预测时所用模型超参数需和训练时一致。 各参数的具体说明请参阅 `args.py` ,注意预测时所用模型超参数需和训练时一致。
## 效果评价 ## 预测效果评价
使用 [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) 工具来评价模型预测的翻译质量,使用方法如下: 使用 [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) 工具来评价模型预测的翻译质量,将该工具下载在该项目路径下,然后使用如下的命令,可以看到BLEU指标的结果
(需要注意的是,在windows系统下,可能需要更改文件路径`~/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi`):
```sh ```sh
perl mosesdecoder/scripts/generic/multi-bleu.perl data/en-vi/tst2013.vi < infer_output.txt perl mosesdecoder/scripts/generic/multi-bleu.perl ~/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi < infer_output.txt
``` ```
取第10个epoch保存的模型进行预测,取beam_size=10。效果如下: 取第10个epoch保存的模型进行预测,取beam_size=10。效果如下:
......
...@@ -30,7 +30,7 @@ Sequence to Sequence (Seq2Seq),使用编码器-解码器(Encoder-Decoder) ...@@ -30,7 +30,7 @@ Sequence to Sequence (Seq2Seq),使用编码器-解码器(Encoder-Decoder)
本教程使用[couplet数据集](https://paddlenlp.bj.bcebos.com/datasets/couplet.tar.gz)数据集作为训练语料,train_src.tsv及train_tgt.tsv为训练集,dev_src.tsv及test_tgt.tsv为开发集,test_src.tsv及test_tgt.tsv为测试集。 本教程使用[couplet数据集](https://paddlenlp.bj.bcebos.com/datasets/couplet.tar.gz)数据集作为训练语料,train_src.tsv及train_tgt.tsv为训练集,dev_src.tsv及test_tgt.tsv为开发集,test_src.tsv及test_tgt.tsv为测试集。
数据集会在`CoupletDataset`初始化时自动下载,如果用户在初始化数据集时没有提供路径,在linux系统下,数据集会自动下载到`/root/.paddlenlp/datasets/machine_translation/CoupletDataset/`目录下 数据集会在`CoupletDataset`初始化时自动下载,如果用户在初始化数据集时没有提供路径,在linux系统下,数据集会自动下载到`~/.paddlenlp/datasets/machine_translation/CoupletDataset/`目录下
## 模型训练 ## 模型训练
...@@ -72,16 +72,12 @@ python predict.py \ ...@@ -72,16 +72,12 @@ python predict.py \
上联:崖悬风雨骤 下联:月落水云寒 上联:崖悬风雨骤 下联:月落水云寒
上联:约春章柳下 下联:邀月醉花间 上联:约春章柳下 下联:邀月醉花间
上联:箬笠红尘外 下联:扁舟明月中 上联:箬笠红尘外 下联:扁舟明月中
上联:书香醉倒窗前月 下联:烛影摇红梦里人 上联:书香醉倒窗前月 下联:烛影摇红梦里人
上联:踏雪寻梅求雅趣 下联:临风把酒觅知音 上联:踏雪寻梅求雅趣 下联:临风把酒觅知音
上联:未出南阳天下论 下联:先登北斗汉中书 上联:未出南阳天下论 下联:先登北斗汉中书
......
...@@ -19,10 +19,8 @@ from functools import partial ...@@ -19,10 +19,8 @@ from functools import partial
import numpy as np import numpy as np
import paddle import paddle
from paddle.utils.download import get_path_from_url
from paddlenlp.data import Vocab, Pad from paddlenlp.data import Vocab, Pad
from paddlenlp.data import SamplerHelper from paddlenlp.data import SamplerHelper
from paddlenlp.utils.env import DATA_HOME
from paddlenlp.datasets import TranslationDataset from paddlenlp.datasets import TranslationDataset
...@@ -32,7 +30,7 @@ def create_train_loader(batch_size=128): ...@@ -32,7 +30,7 @@ def create_train_loader(batch_size=128):
pad_id = vocab[CoupletDataset.EOS_TOKEN] pad_id = vocab[CoupletDataset.EOS_TOKEN]
train_batch_sampler = SamplerHelper(train_ds).shuffle().batch( train_batch_sampler = SamplerHelper(train_ds).shuffle().batch(
batch_size=batch_size).shard() batch_size=batch_size)
train_loader = paddle.io.DataLoader( train_loader = paddle.io.DataLoader(
train_ds, train_ds,
...@@ -50,8 +48,7 @@ def create_infer_loader(batch_size=128): ...@@ -50,8 +48,7 @@ def create_infer_loader(batch_size=128):
bos_id = vocab[CoupletDataset.BOS_TOKEN] bos_id = vocab[CoupletDataset.BOS_TOKEN]
eos_id = vocab[CoupletDataset.EOS_TOKEN] eos_id = vocab[CoupletDataset.EOS_TOKEN]
test_batch_sampler = SamplerHelper(test_ds).batch( test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)
batch_size=batch_size).shard()
test_loader = paddle.io.DataLoader( test_loader = paddle.io.DataLoader(
test_ds, test_ds,
...@@ -103,7 +100,7 @@ class CoupletDataset(TranslationDataset): ...@@ -103,7 +100,7 @@ class CoupletDataset(TranslationDataset):
raise TypeError( raise TypeError(
'`train`, `dev` or `test` is supported but `{}` is passed in'. '`train`, `dev` or `test` is supported but `{}` is passed in'.
format(mode)) format(mode))
# Download data # Download and read data
self.data = self.get_data(mode=mode, root=root) self.data = self.get_data(mode=mode, root=root)
self.vocab, _ = self.get_vocab(root) self.vocab, _ = self.get_vocab(root)
self.transform() self.transform()
......
...@@ -44,7 +44,7 @@ def do_predict(args): ...@@ -44,7 +44,7 @@ def do_predict(args):
test_loader, vocab_size, pad_id, bos_id, eos_id = create_infer_loader( test_loader, vocab_size, pad_id, bos_id, eos_id = create_infer_loader(
args.batch_size) args.batch_size)
vocab, _ = CoupletDataset.get_vocab() vocab, _ = CoupletDataset.get_vocab()
trg_idx2word = vocab._idx_to_token trg_idx2word = vocab.idx_to_token
model = paddle.Model( model = paddle.Model(
Seq2SeqAttnInferModel( Seq2SeqAttnInferModel(
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
__version__ = '2.0.0a7' __version__ = '2.0.0a8'
from . import data from . import data
from . import datasets from . import datasets
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册