提交 ecb6d64c 编写于 作者: G guosheng

Update metric of seq2seq to adapt to latest code.

上级 79066ac6
...@@ -11,11 +11,9 @@ ...@@ -11,11 +11,9 @@
├── reader.py # 数据读入程序 ├── reader.py # 数据读入程序
├── download.py # 数据下载程序 ├── download.py # 数据下载程序
├── train.py # 训练主程序 ├── train.py # 训练主程序
├── infer.py # 预测主程序 ├── predict.py # 预测主程序
├── run.sh # 默认配置的启动脚本 ├── seq2seq_attn.py # 带注意力机制的翻译模型程序
├── infer.sh # 默认配置的解码脚本 └── seq2seq_base.py # 无注意力机制的翻译模型程序
├── attention_model.py # 带注意力机制的翻译模型程序
└── base_model.py # 无注意力机制的翻译模型程序
``` ```
## 简介 ## 简介
...@@ -40,13 +38,7 @@ python download.py ...@@ -40,13 +38,7 @@ python download.py
## 模型训练 ## 模型训练
`run.sh`包含训练程序的主函数,要使用默认参数开始训练,只需要简单地执行: 执行以下命令即可训练带有注意力机制的Seq2Seq机器翻译模型:
```
sh run.sh
```
默认使用带有注意力机制的RNN模型,可以通过修改 `attention` 参数为False来训练不带注意力机制的RNN模型。
```sh ```sh
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
...@@ -70,8 +62,7 @@ python train.py \ ...@@ -70,8 +62,7 @@ python train.py \
--model_path ./attention_models --model_path ./attention_models
``` ```
训练程序会在每个epoch训练结束之后,save一次模型。 可以通过修改 `attention` 参数为False来训练不带注意力机制的Seq2Seq模型,各参数的具体说明请参阅 `args.py` 。训练程序会在每个epoch训练结束之后,save一次模型。
默认使用动态图模式进行训练,可以通过设置 `eager_run` 参数为False来以静态图模式进行训练,如下: 默认使用动态图模式进行训练,可以通过设置 `eager_run` 参数为False来以静态图模式进行训练,如下:
...@@ -100,13 +91,7 @@ python train.py \ ...@@ -100,13 +91,7 @@ python train.py \
## 模型预测 ## 模型预测
当模型训练完成之后, 可以利用infer.sh的脚本进行预测,默认使用beam search的方法进行预测,加载第10个epoch的模型进行预测,对test的数据集进行解码 训练完成之后,可以使用保存的模型(由 `--reload_model` 指定)对test的数据集(由 `--infer_file` 指定)进行beam search解码,命令如下:
```
sh infer.sh
```
如果想预测别的数据文件,只需要将 --infer_file参数进行修改。
```sh ```sh
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
...@@ -124,13 +109,13 @@ python infer.py \ ...@@ -124,13 +109,13 @@ python infer.py \
--max_grad_norm 5.0 \ --max_grad_norm 5.0 \
--vocab_prefix data/en-vi/vocab \ --vocab_prefix data/en-vi/vocab \
--infer_file data/en-vi/tst2013.en \ --infer_file data/en-vi/tst2013.en \
--reload_model attention_models/epoch_10 \ --reload_model attention_models/10 \
--infer_output_file attention_infer_output/infer_output.txt \ --infer_output_file infer_output.txt \
--beam_size 10 \ --beam_size 10 \
--use_gpu True --use_gpu True
``` ```
和训练类似,预测时同样可以以静态图模式进行,如下: 各参数的具体说明请参阅 `args.py` ,注意预测时所用模型超参数需和训练时一致。和训练类似,预测时同样可以以静态图模式进行,如下:
```sh ```sh
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
...@@ -148,8 +133,8 @@ python infer.py \ ...@@ -148,8 +133,8 @@ python infer.py \
--max_grad_norm 5.0 \ --max_grad_norm 5.0 \
--vocab_prefix data/en-vi/vocab \ --vocab_prefix data/en-vi/vocab \
--infer_file data/en-vi/tst2013.en \ --infer_file data/en-vi/tst2013.en \
--reload_model attention_models/epoch_10 \ --reload_model attention_models/10 \
--infer_output_file attention_infer_output/infer_output.txt \ --infer_output_file infer_output.txt \
--beam_size 10 \ --beam_size 10 \
--use_gpu True \ --use_gpu True \
--eager_run False --eager_run False
......
...@@ -108,7 +108,7 @@ def do_predict(args): ...@@ -108,7 +108,7 @@ def do_predict(args):
# TODO(guosheng): use model.predict when support variant length # TODO(guosheng): use model.predict when support variant length
with io.open(args.infer_output_file, 'w', encoding='utf-8') as f: with io.open(args.infer_output_file, 'w', encoding='utf-8') as f:
for data in data_loader(): for data in data_loader():
finished_seq = model.test(inputs=flatten(data))[0] finished_seq = model.test_batch(inputs=flatten(data))[0]
finished_seq = finished_seq[:, :, np.newaxis] if len( finished_seq = finished_seq[:, :, np.newaxis] if len(
finished_seq.shape) == 2 else finished_seq finished_seq.shape) == 2 else finished_seq
finished_seq = np.transpose(finished_seq, [0, 2, 1]) finished_seq = np.transpose(finished_seq, [0, 2, 1])
......
export CUDA_VISIBLE_DEVICES=0
python train.py \
--src_lang en --tar_lang vi \
--attention True \
--num_layers 2 \
--hidden_size 512 \
--src_vocab_size 17191 \
--tar_vocab_size 7709 \
--batch_size 128 \
--dropout 0.2 \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--train_data_prefix data/en-vi/train \
--eval_data_prefix data/en-vi/tst2012 \
--test_data_prefix data/en-vi/tst2013 \
--vocab_prefix data/en-vi/vocab \
--use_gpu True \
--model_path attention_models
\ No newline at end of file
...@@ -67,7 +67,7 @@ def do_train(args): ...@@ -67,7 +67,7 @@ def do_train(args):
parameter_list=model.parameters(), parameter_list=model.parameters(),
grad_clip=grad_clip) grad_clip=grad_clip)
ppl_metric = PPL() ppl_metric = PPL(reset_freq=100) # ppl for every 100 batches
model.prepare( model.prepare(
optimizer, optimizer,
CrossEntropyCriterion(), CrossEntropyCriterion(),
......
...@@ -12,7 +12,8 @@ ...@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import math
import paddle.fluid as fluid import paddle.fluid as fluid
from hapi.metrics import Metric from hapi.metrics import Metric
...@@ -55,13 +56,12 @@ class PPL(Metric): ...@@ -55,13 +56,12 @@ class PPL(Metric):
self.reset_freq = reset_freq self.reset_freq = reset_freq
self.reset() self.reset()
def add_metric_op(self, pred, label): def add_metric_op(self, pred, seq_length, label):
seq_length = label[0]
word_num = fluid.layers.reduce_sum(seq_length) word_num = fluid.layers.reduce_sum(seq_length)
return word_num return word_num
def update(self, word_num): def update(self, word_num):
self.word_count += word_num[0] self.word_count += word_num
return word_num return word_num
def reset(self): def reset(self):
...@@ -76,5 +76,5 @@ class PPL(Metric): ...@@ -76,5 +76,5 @@ class PPL(Metric):
def cal_acc_ppl(self, batch_loss, batch_size): def cal_acc_ppl(self, batch_loss, batch_size):
self.total_loss += batch_loss * batch_size self.total_loss += batch_loss * batch_size
ppl = np.exp(self.total_loss / self.word_count) ppl = math.exp(self.total_loss / self.word_count)
return ppl return ppl
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册