提交 3a0bb1c0 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #282 from ranqiu92/mt_with_external_memory

fix bugs of mt_with_external_memory.
......@@ -58,7 +58,7 @@
(详情请参考论文\[[1](#参考文献)\])。根据寻址情况,图灵机写入 $M$ 或从 $M$ 读出信息,供其他网络使用。神经图灵机结构示意图,见图3,引自\[[1](#参考文献)\]
<div align="center">
<img src="image/neural_turing_machine_arch.png"><br/>
<img src="image/neural_turing_machine_arch.png" width="400"><br/>
图3. 神经图灵机结构示意图
</div>
......@@ -440,7 +440,7 @@ python infer.py
或自定义部分参数, 例如:
```bash
python train.py \
python infer.py \
--dict_size 30000 \
--word_vec_dim 512 \
--hidden_size 1024 \
......@@ -448,7 +448,7 @@ python train.py \
--use_gpu False \
--trainer_count 1 \
--memory_perturb_stddev 0.1 \
--infer_num_data 10 \
--infer_data_num 10 \
--model_filepath checkpoints/params.latest.tar.gz \
--beam_size 3
```
......
......@@ -113,7 +113,7 @@ class ExternalMemory(object):
boot_layer=self.initial_weight)
interpolated_weight = paddle.layer.interpolation(
name=self.name + "_addressing_weight_" + head_name,
input=[addressing_weight, addressing_weight],
input=[last_addressing_weight, addressing_weight],
weight=paddle.layer.expand(input=gate, expand_as=addressing_weight))
return interpolated_weight
......
......@@ -4,6 +4,7 @@
import distutils.util
import argparse
import gzip
import random
import paddle.v2 as paddle
from external_memory import ExternalMemory
......@@ -118,10 +119,11 @@ def infer():
infer_data = []
random.seed(0) # for keeping consitancy for multiple runs
bounded_memory_perturbation = [[
random.gauss(0, memory_perturb_stddev) for i in xrange(args.hidden_size)
random.gauss(0, args.memory_perturb_stddev)
for i in xrange(args.hidden_size)
] for j in xrange(args.memory_slot_num)]
test_append_reader = reader_append_wrapper(
reader=paddle.dataset.wmt14.test(dict_size),
reader=paddle.dataset.wmt14.test(args.dict_size),
append_tuple=(bounded_memory_perturbation, ))
for i, item in enumerate(test_append_reader()):
if i < args.infer_data_num:
......@@ -134,8 +136,8 @@ def infer():
input=infer_data,
field=['prob', 'id'])
# parse beam result and print
source_dict, target_dict = paddle.dataset.wmt14.get_dict(dict_size)
# parse beam result and print
source_dict, target_dict = paddle.dataset.wmt14.get_dict(args.dict_size)
beam_probs, beam_sentences = parse_beam_search_result(beam_result,
target_dict)
for i in xrange(args.infer_data_num):
......@@ -147,7 +149,7 @@ def infer():
def main():
paddle.init(use_gpu=False, trainer_count=1)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
infer()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册