提交 3a0bb1c0 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #282 from ranqiu92/mt_with_external_memory

fix bugs of mt_with_external_memory.
...@@ -58,7 +58,7 @@ ...@@ -58,7 +58,7 @@
(详情请参考论文\[[1](#参考文献)\])。根据寻址情况,图灵机写入 $M$ 或从 $M$ 读出信息,供其他网络使用。神经图灵机结构示意图,见图3,引自\[[1](#参考文献)\] (详情请参考论文\[[1](#参考文献)\])。根据寻址情况,图灵机写入 $M$ 或从 $M$ 读出信息,供其他网络使用。神经图灵机结构示意图,见图3,引自\[[1](#参考文献)\]
<div align="center"> <div align="center">
<img src="image/neural_turing_machine_arch.png"><br/> <img src="image/neural_turing_machine_arch.png" width="400"><br/>
图3. 神经图灵机结构示意图 图3. 神经图灵机结构示意图
</div> </div>
...@@ -440,7 +440,7 @@ python infer.py ...@@ -440,7 +440,7 @@ python infer.py
或自定义部分参数, 例如: 或自定义部分参数, 例如:
```bash ```bash
python train.py \ python infer.py \
--dict_size 30000 \ --dict_size 30000 \
--word_vec_dim 512 \ --word_vec_dim 512 \
--hidden_size 1024 \ --hidden_size 1024 \
...@@ -448,7 +448,7 @@ python train.py \ ...@@ -448,7 +448,7 @@ python train.py \
--use_gpu False \ --use_gpu False \
--trainer_count 1 \ --trainer_count 1 \
--memory_perturb_stddev 0.1 \ --memory_perturb_stddev 0.1 \
--infer_num_data 10 \ --infer_data_num 10 \
--model_filepath checkpoints/params.latest.tar.gz \ --model_filepath checkpoints/params.latest.tar.gz \
--beam_size 3 --beam_size 3
``` ```
......
...@@ -113,7 +113,7 @@ class ExternalMemory(object): ...@@ -113,7 +113,7 @@ class ExternalMemory(object):
boot_layer=self.initial_weight) boot_layer=self.initial_weight)
interpolated_weight = paddle.layer.interpolation( interpolated_weight = paddle.layer.interpolation(
name=self.name + "_addressing_weight_" + head_name, name=self.name + "_addressing_weight_" + head_name,
input=[addressing_weight, addressing_weight], input=[last_addressing_weight, addressing_weight],
weight=paddle.layer.expand(input=gate, expand_as=addressing_weight)) weight=paddle.layer.expand(input=gate, expand_as=addressing_weight))
return interpolated_weight return interpolated_weight
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import distutils.util import distutils.util
import argparse import argparse
import gzip import gzip
import random
import paddle.v2 as paddle import paddle.v2 as paddle
from external_memory import ExternalMemory from external_memory import ExternalMemory
...@@ -118,10 +119,11 @@ def infer(): ...@@ -118,10 +119,11 @@ def infer():
infer_data = [] infer_data = []
random.seed(0) # for keeping consitancy for multiple runs random.seed(0) # for keeping consitancy for multiple runs
bounded_memory_perturbation = [[ bounded_memory_perturbation = [[
random.gauss(0, memory_perturb_stddev) for i in xrange(args.hidden_size) random.gauss(0, args.memory_perturb_stddev)
for i in xrange(args.hidden_size)
] for j in xrange(args.memory_slot_num)] ] for j in xrange(args.memory_slot_num)]
test_append_reader = reader_append_wrapper( test_append_reader = reader_append_wrapper(
reader=paddle.dataset.wmt14.test(dict_size), reader=paddle.dataset.wmt14.test(args.dict_size),
append_tuple=(bounded_memory_perturbation, )) append_tuple=(bounded_memory_perturbation, ))
for i, item in enumerate(test_append_reader()): for i, item in enumerate(test_append_reader()):
if i < args.infer_data_num: if i < args.infer_data_num:
...@@ -134,8 +136,8 @@ def infer(): ...@@ -134,8 +136,8 @@ def infer():
input=infer_data, input=infer_data,
field=['prob', 'id']) field=['prob', 'id'])
# parse beam result and print # parse beam result and print
source_dict, target_dict = paddle.dataset.wmt14.get_dict(dict_size) source_dict, target_dict = paddle.dataset.wmt14.get_dict(args.dict_size)
beam_probs, beam_sentences = parse_beam_search_result(beam_result, beam_probs, beam_sentences = parse_beam_search_result(beam_result,
target_dict) target_dict)
for i in xrange(args.infer_data_num): for i in xrange(args.infer_data_num):
...@@ -147,7 +149,7 @@ def infer(): ...@@ -147,7 +149,7 @@ def infer():
def main(): def main():
paddle.init(use_gpu=False, trainer_count=1) paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
infer() infer()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册